support klein 4b models

This commit is contained in:
Artiprocher
2026-01-16 13:09:41 +08:00
parent 55e8346da3
commit ae52d93694
3 changed files with 39 additions and 40 deletions

View File

@@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module):
class Flux2TimestepGuidanceEmbeddings(nn.Module):
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
def __init__(
self,
in_channels: int = 256,
embedding_dim: int = 6144,
bias: bool = False,
guidance_embeds: bool = True,
):
super().__init__()
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
self.guidance_embedder = TimestepEmbedding(
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
else:
self.guidance_embedder = None
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
return time_guidance_emb
if guidance is not None and self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
return time_guidance_emb
else:
return timesteps_emb
class Flux2Modulation(nn.Module):
@@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module):
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
rope_theta: int = 2000,
eps: float = 1e-6,
guidance_embeds: bool = True,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -892,7 +903,10 @@ class Flux2DiT(torch.nn.Module):
# 2. Combined timestep + guidance embedding
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
in_channels=timestep_guidance_channels,
embedding_dim=self.inner_dim,
bias=False,
guidance_embeds=guidance_embeds,
)
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
@@ -953,34 +967,9 @@ class Flux2DiT(torch.nn.Module):
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
) -> Union[torch.Tensor]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
):
# 0. Handle input arguments
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
@@ -992,7 +981,9 @@ class Flux2DiT(torch.nn.Module):
# 1. Calculate timestep embedding and modulation parameters
timestep = timestep.to(hidden_states.dtype) * 1000
guidance = guidance.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_guidance_embed(timestep, guidance)