From ae52d9369468889a5f536254a9b4131321cff720 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Fri, 16 Jan 2026 13:09:41 +0800
Subject: [PATCH] support klein 4b models

---
 diffsynth/configs/model_configs.py |  7 ++++
 diffsynth/models/flux2_dit.py      | 67 +++++++++++++-----------------
 diffsynth/pipelines/flux2_image.py |  5 ++-
 3 files changed, 39 insertions(+), 40 deletions(-)

diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index eed58f8..cc23fb9 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -510,6 +510,13 @@ flux2_series = [
         "model_name": "flux2_vae",
         "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
     },
+    {
+        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
+        "model_hash": "3bde7b817fec8143028b6825a63180df",
+        "model_name": "flux2_dit",
+        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
+        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
+    },
 ]
 
 z_image_series = [
diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py
index a08c579..316cf08 100644
--- a/diffsynth/models/flux2_dit.py
+++ b/diffsynth/models/flux2_dit.py
@@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module):
 
 
 class Flux2TimestepGuidanceEmbeddings(nn.Module):
-    def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
+    def __init__(
+        self,
+        in_channels: int = 256,
+        embedding_dim: int = 6144,
+        bias: bool = False,
+        guidance_embeds: bool = True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
             in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
         )
 
-        self.guidance_embedder = TimestepEmbedding(
-            in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
-        )
+        if guidance_embeds:
+            self.guidance_embedder = TimestepEmbedding(
+                in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+            )
+        else:
+            self.guidance_embedder = None
 
     def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))  # (N, D)
 
-        guidance_proj = self.time_proj(guidance)
-        guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
-
-        time_guidance_emb = timesteps_emb + guidance_emb
-
-        return time_guidance_emb
+        if guidance is not None and self.guidance_embedder is not None:
+            guidance_proj = self.time_proj(guidance)
+            guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
+            time_guidance_emb = timesteps_emb + guidance_emb
+            return time_guidance_emb
+        else:
+            return timesteps_emb
 
 
 class Flux2Modulation(nn.Module):
@@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module):
         axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
         rope_theta: int = 2000,
         eps: float = 1e-6,
+        guidance_embeds: bool = True,
    ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -892,7 +903,10 @@
 
         # 2. Combined timestep + guidance embedding
         self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
-            in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
+            in_channels=timestep_guidance_channels,
+            embedding_dim=self.inner_dim,
+            bias=False,
+            guidance_embeds=guidance_embeds,
         )
 
         # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
@@ -953,34 +967,9 @@
         txt_ids: torch.Tensor = None,
         guidance: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
         use_gradient_checkpointing=False,
         use_gradient_checkpointing_offload=False,
-    ) -> Union[torch.Tensor]:
-        """
-        The [`FluxTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
-                Input `hidden_states`.
-            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            timestep ( `torch.LongTensor`):
-                Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                A list of tensors that if specified are added to the residuals of transformer blocks.
-            joint_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
-
-        Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
-        """
+    ):
         # 0. Handle input arguments
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
@@ -992,7 +981,9 @@
 
         # 1. Calculate timestep embedding and modulation parameters
         timestep = timestep.to(hidden_states.dtype) * 1000
-        guidance = guidance.to(hidden_states.dtype) * 1000
+
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
 
         temb = self.time_guidance_embed(timestep, guidance)
 
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index 8b00469..e94d2c3 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -10,7 +10,7 @@ from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
 
-from transformers import AutoProcessor
+from transformers import AutoProcessor, AutoTokenizer
 from ..models.flux2_text_encoder import Flux2TextEncoder
 from ..models.flux2_dit import Flux2DiT
 from ..models.flux2_vae import Flux2VAE
@@ -53,11 +53,12 @@ class Flux2ImagePipeline(BasePipeline):
 
         # Fetch models
         pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
+        pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder")
         pipe.dit = model_pool.fetch_model("flux2_dit")
         pipe.vae = model_pool.fetch_model("flux2_vae")
         if tokenizer_config is not None:
             tokenizer_config.download_if_necessary()
-            pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
+            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
 
         # VRAM Management
         pipe.vram_management_enabled = pipe.check_vram_management_state()
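
Note (not part of the patch): with guidance_embeds=False, the new
Flux2TimestepGuidanceEmbeddings skips the guidance embedder entirely, so temb in
Flux2DiT.forward reduces to the timestep embedding alone. That is the code path the
Klein 4B config entry above exercises, since its extra_kwargs set
"guidance_embeds": False. A minimal smoke test of this path is sketched below; it
uses only names introduced or touched by this patch, while the batch size and the
default embedding_dim=6144 are illustrative rather than the Klein 4B dimensions.

    import torch
    from diffsynth.models.flux2_dit import Flux2TimestepGuidanceEmbeddings

    # Guidance-free variant, as constructed by Flux2DiT when guidance_embeds=False.
    emb = Flux2TimestepGuidanceEmbeddings(guidance_embeds=False)
    assert emb.guidance_embedder is None

    timestep = torch.rand(2) * 1000  # Flux2DiT.forward scales timesteps by 1000
    out = emb(timestep, None)        # guidance=None -> timestep embedding only
    print(out.shape)                 # torch.Size([2, 6144])

For loading the checkpoint itself, the comment embedded in the new model_configs.py
entry gives the intended ModelConfig call:
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors").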