diff --git a/README.md b/README.md
index 2c7ef27..a590402 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,8 @@ We believe that a well-developed open-source code framework can lower the thresh
 > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
 
+- **January 19, 2026**: Added support for the [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
+
 - **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
 
 - **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
@@ -321,7 +323,9 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
 
-| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation |
-|-|-|-|-|-|
-|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|-|-|-|-|-|-|-|
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
+|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
@@ -774,4 +778,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
 
 https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
 
-
diff --git a/README_zh.md b/README_zh.md
index 81a33e3..6a0e0ef 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -33,6 +33,8 @@ DiffSynth 目前包括两个开源项目:
 > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
 
+- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
+
 - **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
 
 - **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)),这个模型可以输入三张图:图A、图B、图C,模型会自行分析图A到图B的变化,并将这样的变化应用到图C,生成图D。更多细节请阅读我们的 blog([中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。
@@ -321,7 +323,9 @@ FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
 
-|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
-|-|-|-|-|-|
-|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
+|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index eed58f8..c93f5e9 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -510,6 +510,28 @@ flux2_series = [
         "model_name": "flux2_vae",
         "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
     },
+    {
+        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
+        "model_hash": "3bde7b817fec8143028b6825a63180df",
+        "model_name": "flux2_dit",
+        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
+        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
+    },
+    {
+        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
+        "model_hash": "9195f3ea256fcd0ae6d929c203470754",
+        "model_name": "z_image_text_encoder",
+        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
+        "extra_kwargs": {"model_size": "8B"},
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
+    },
+    {
+        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
+        "model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
+        "model_name": "flux2_dit",
+        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
+        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
+    },
 ]
 
 z_image_series = [
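The two new `flux2_dit` registry entries select the klein architectures by weight hash and pass `extra_kwargs` through to the model constructor. A minimal sketch of the resulting constructions, assuming the remaining `Flux2DiT` arguments keep their defaults (illustration only, not part of the patch):

```python
# Sketch: what the registry entries above amount to once a checkpoint
# matches the listed model_hash (remaining constructor args assumed default).
from diffsynth.models.flux2_dit import Flux2DiT

# klein-4B: 5 double-stream + 20 single-stream blocks, no guidance embedding
dit_4b = Flux2DiT(guidance_embeds=False, joint_attention_dim=7680,
                  num_attention_heads=24, num_layers=5, num_single_layers=20)

# klein-9B: wider text conditioning and deeper stacks
dit_9b = Flux2DiT(guidance_embeds=False, joint_attention_dim=12288,
                  num_attention_heads=32, num_layers=8, num_single_layers=24)
```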
diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py
index e3b3329..b658866 100644
--- a/diffsynth/diffusion/training_module.py
+++ b/diffsynth/diffusion/training_module.py
@@ -1,4 +1,4 @@
-import torch, json
+import torch, json, os
 from ..core import ModelConfig, load_state_dict
 from ..utils.controlnet import ControlNetInput
 from peft import LoraConfig, inject_adapter_in_model
@@ -127,15 +127,29 @@ class DiffusionTrainingModule(torch.nn.Module):
         if model_id_with_origin_paths is not None:
             model_id_with_origin_paths = model_id_with_origin_paths.split(",")
             for model_id_with_origin_path in model_id_with_origin_paths:
-                model_id, origin_file_pattern = model_id_with_origin_path.split(":")
                 vram_config = self.parse_vram_config(
                     fp8=model_id_with_origin_path in fp8_models,
                     offload=model_id_with_origin_path in offload_models,
                     device=device
                 )
-                model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
+                config = self.parse_path_or_model_id(model_id_with_origin_path)
+                model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
         return model_configs
+
+    def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
+        if model_id_with_origin_path is None:
+            return default_value
+        elif os.path.exists(model_id_with_origin_path):
+            return ModelConfig(path=model_id_with_origin_path)
+        else:
+            if ":" not in model_id_with_origin_path:
+                raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id:origin_file_pattern`.")
+            split_id = model_id_with_origin_path.rfind(":")
+            model_id = model_id_with_origin_path[:split_id]
+            origin_file_pattern = model_id_with_origin_path[split_id + 1:]
+            return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
+
     def switch_pipe_to_training_mode(
         self,
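The new helper accepts three input forms: `None` (falls back to the given default), an existing filesystem path, or a `model_id:origin_file_pattern` string split at the last colon. A standalone re-statement of the rule as a sketch (the `resolve` wrapper is only for illustration; the patch exposes this as `DiffusionTrainingModule.parse_path_or_model_id`):

```python
# Sketch of the parsing rule introduced above.
import os
from diffsynth.core import ModelConfig

def resolve(spec, default_value=None):
    if spec is None:
        return default_value                      # no spec -> caller-supplied default
    if os.path.exists(spec):
        return ModelConfig(path=spec)             # existing path -> local checkpoint
    if ":" not in spec:
        raise ValueError(f"Failed to parse model config: {spec}.")
    i = spec.rfind(":")                           # last colon separates the pattern,
    return ModelConfig(model_id=spec[:i],         # so model_id may contain colons
                       origin_file_pattern=spec[i + 1:])

cfg = resolve("black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors")
print(cfg.model_id, cfg.origin_file_pattern)
# black-forest-labs/FLUX.2-klein-4B transformer/*.safetensors
```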
diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py
index a08c579..316cf08 100644
--- a/diffsynth/models/flux2_dit.py
+++ b/diffsynth/models/flux2_dit.py
@@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module):
 
 class Flux2TimestepGuidanceEmbeddings(nn.Module):
-    def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
+    def __init__(
+        self,
+        in_channels: int = 256,
+        embedding_dim: int = 6144,
+        bias: bool = False,
+        guidance_embeds: bool = True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
             in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
         )
 
-        self.guidance_embedder = TimestepEmbedding(
-            in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
-        )
+        if guidance_embeds:
+            self.guidance_embedder = TimestepEmbedding(
+                in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+            )
+        else:
+            self.guidance_embedder = None
 
     def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))  # (N, D)
 
-        guidance_proj = self.time_proj(guidance)
-        guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
-
-        time_guidance_emb = timesteps_emb + guidance_emb
-
-        return time_guidance_emb
+        if guidance is not None and self.guidance_embedder is not None:
+            guidance_proj = self.time_proj(guidance)
+            guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
+            time_guidance_emb = timesteps_emb + guidance_emb
+            return time_guidance_emb
+        else:
+            return timesteps_emb
 
 
 class Flux2Modulation(nn.Module):
@@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module):
         axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
         rope_theta: int = 2000,
         eps: float = 1e-6,
+        guidance_embeds: bool = True,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -892,7 +903,10 @@ class Flux2DiT(torch.nn.Module):
 
         # 2. Combined timestep + guidance embedding
         self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
-            in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
+            in_channels=timestep_guidance_channels,
+            embedding_dim=self.inner_dim,
+            bias=False,
+            guidance_embeds=guidance_embeds,
        )
 
         # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
@@ -953,34 +967,9 @@ class Flux2DiT(torch.nn.Module):
         txt_ids: torch.Tensor = None,
         guidance: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
         use_gradient_checkpointing=False,
         use_gradient_checkpointing_offload=False,
-    ) -> Union[torch.Tensor]:
-        """
-        The [`FluxTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
-                Input `hidden_states`.
-            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            timestep ( `torch.LongTensor`):
-                Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                A list of tensors that if specified are added to the residuals of transformer blocks.
-            joint_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
-
-        Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
-        """
+    ):
         # 0. Handle input arguments
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
@@ -992,7 +981,9 @@ class Flux2DiT(torch.nn.Module):
 
         # 1. Calculate timestep embedding and modulation parameters
         timestep = timestep.to(hidden_states.dtype) * 1000
-        guidance = guidance.to(hidden_states.dtype) * 1000
+
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
 
         temb = self.time_guidance_embed(timestep, guidance)
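With `guidance_embeds=False` the embedder has no guidance branch, and `forward(timestep, None)` reduces to the timestep embedding alone, which is the code path the klein DiTs take. A toy check as a sketch (assuming the diffusers-style `Timesteps`/`TimestepEmbedding` modules used above):

```python
# Toy check (sketch): no guidance branch when guidance_embeds=False.
import torch
from diffsynth.models.flux2_dit import Flux2TimestepGuidanceEmbeddings

emb = Flux2TimestepGuidanceEmbeddings(embedding_dim=64, guidance_embeds=False)
assert emb.guidance_embedder is None
temb = emb(torch.rand(2) * 1000, None)  # the DiT scales timesteps by 1000 before this call
print(temb.shape)  # torch.Size([2, 64]) -- timestep embedding only
```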
diff --git a/diffsynth/models/z_image_text_encoder.py b/diffsynth/models/z_image_text_encoder.py
index 4eba636..4d6271d 100644
--- a/diffsynth/models/z_image_text_encoder.py
+++ b/diffsynth/models/z_image_text_encoder.py
@@ -3,38 +3,71 @@ import torch
 
 class ZImageTextEncoder(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, model_size="4B"):
         super().__init__()
-        config = Qwen3Config(**{
-            "architectures": [
-                "Qwen3ForCausalLM"
-            ],
-            "attention_bias": False,
-            "attention_dropout": 0.0,
-            "bos_token_id": 151643,
-            "eos_token_id": 151645,
-            "head_dim": 128,
-            "hidden_act": "silu",
-            "hidden_size": 2560,
-            "initializer_range": 0.02,
-            "intermediate_size": 9728,
-            "max_position_embeddings": 40960,
-            "max_window_layers": 36,
-            "model_type": "qwen3",
-            "num_attention_heads": 32,
-            "num_hidden_layers": 36,
-            "num_key_value_heads": 8,
-            "rms_norm_eps": 1e-06,
-            "rope_scaling": None,
-            "rope_theta": 1000000,
-            "sliding_window": None,
-            "tie_word_embeddings": True,
-            "torch_dtype": "bfloat16",
-            "transformers_version": "4.51.0",
-            "use_cache": True,
-            "use_sliding_window": False,
-            "vocab_size": 151936
-        })
+        config_dict = {
+            "4B": Qwen3Config(**{
+                "architectures": [
+                    "Qwen3ForCausalLM"
+                ],
+                "attention_bias": False,
+                "attention_dropout": 0.0,
+                "bos_token_id": 151643,
+                "eos_token_id": 151645,
+                "head_dim": 128,
+                "hidden_act": "silu",
+                "hidden_size": 2560,
+                "initializer_range": 0.02,
+                "intermediate_size": 9728,
+                "max_position_embeddings": 40960,
+                "max_window_layers": 36,
+                "model_type": "qwen3",
+                "num_attention_heads": 32,
+                "num_hidden_layers": 36,
+                "num_key_value_heads": 8,
+                "rms_norm_eps": 1e-06,
+                "rope_scaling": None,
+                "rope_theta": 1000000,
+                "sliding_window": None,
+                "tie_word_embeddings": True,
+                "torch_dtype": "bfloat16",
+                "transformers_version": "4.51.0",
+                "use_cache": True,
+                "use_sliding_window": False,
+                "vocab_size": 151936
+            }),
+            "8B": Qwen3Config(**{
+                "architectures": [
+                    "Qwen3ForCausalLM"
+                ],
+                "attention_bias": False,
+                "attention_dropout": 0.0,
+                "bos_token_id": 151643,
+                "dtype": "bfloat16",
+                "eos_token_id": 151645,
+                "head_dim": 128,
+                "hidden_act": "silu",
+                "hidden_size": 4096,
+                "initializer_range": 0.02,
+                "intermediate_size": 12288,
+                "max_position_embeddings": 40960,
+                "max_window_layers": 36,
+                "model_type": "qwen3",
+                "num_attention_heads": 32,
+                "num_hidden_layers": 36,
+                "num_key_value_heads": 8,
+                "rms_norm_eps": 1e-06,
+                "rope_scaling": None,
+                "rope_theta": 1000000,
+                "sliding_window": None,
+                "tie_word_embeddings": False,
+                "transformers_version": "4.56.1",
+                "use_cache": True,
+                "use_sliding_window": False,
+                "vocab_size": 151936
+            })
+        }
+        config = config_dict[model_size]
         self.model = Qwen3Model(config)
 
     def forward(self, *args, **kwargs):
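The two registered Qwen3 configurations share depth (36 layers) and differ mainly in width (`hidden_size` 2560 vs. 4096, `intermediate_size` 9728 vs. 12288) and in whether input/output embeddings are tied. A usage sketch (instantiates randomly initialized weights, for illustration only):

```python
# Sketch: both sizes expose the same Qwen3Model interface; only width differs.
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder

enc = ZImageTextEncoder(model_size="8B")     # klein-9B's text encoder backbone
print(enc.model.config.hidden_size)          # 4096 (vs. 2560 for "4B")
print(enc.model.config.tie_word_embeddings)  # False (True for "4B")
```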
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index 8b00469..b736625 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -10,10 +10,11 @@ from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
-from transformers import AutoProcessor
+from transformers import AutoProcessor, AutoTokenizer
 from ..models.flux2_text_encoder import Flux2TextEncoder
 from ..models.flux2_dit import Flux2DiT
 from ..models.flux2_vae import Flux2VAE
+from ..models.z_image_text_encoder import ZImageTextEncoder
 
 
 class Flux2ImagePipeline(BasePipeline):
@@ -25,6 +26,7 @@ class Flux2ImagePipeline(BasePipeline):
         )
         self.scheduler = FlowMatchScheduler("FLUX.2")
         self.text_encoder: Flux2TextEncoder = None
+        self.text_encoder_qwen3: ZImageTextEncoder = None
         self.dit: Flux2DiT = None
         self.vae: Flux2VAE = None
         self.tokenizer: AutoProcessor = None
@@ -32,6 +34,7 @@ class Flux2ImagePipeline(BasePipeline):
         self.units = [
             Flux2Unit_ShapeChecker(),
             Flux2Unit_PromptEmbedder(),
+            Flux2Unit_Qwen3PromptEmbedder(),
             Flux2Unit_NoiseInitializer(),
             Flux2Unit_InputImageEmbedder(),
             Flux2Unit_ImageIDs(),
@@ -53,11 +56,12 @@ class Flux2ImagePipeline(BasePipeline):
 
         # Fetch models
         pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
+        pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder")
         pipe.dit = model_pool.fetch_model("flux2_dit")
         pipe.vae = model_pool.fetch_model("flux2_vae")
         if tokenizer_config is not None:
             tokenizer_config.download_if_necessary()
-            pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
+            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
 
         # VRAM Management
         pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -275,6 +279,10 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
         return prompt_embeds, text_ids
 
     def process(self, pipe: Flux2ImagePipeline, prompt):
+        # Skip if the Qwen3 text encoder is available (handled by Flux2Unit_Qwen3PromptEmbedder)
+        if pipe.text_encoder_qwen3 is not None:
+            return {}
+
         pipe.load_models_to_device(self.onload_model_names)
         prompt_embeds, text_ids = self.encode_prompt(
             pipe.text_encoder, pipe.tokenizer, prompt,
@@ -283,6 +291,136 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
         return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
 
 
+class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            seperate_cfg=True,
+            input_params_posi={"prompt": "prompt"},
+            input_params_nega={"prompt": "negative_prompt"},
+            output_params=("prompt_embeds", "text_ids"),
+            onload_model_names=("text_encoder_qwen3",)
+        )
+        self.hidden_states_layers = (9, 18, 27)  # intermediate Qwen3 layers
+
+    def get_qwen3_prompt_embeds(
+        self,
+        text_encoder: ZImageTextEncoder,
+        tokenizer: AutoTokenizer,
+        prompt: Union[str, List[str]],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        max_sequence_length: int = 512,
+    ):
+        dtype = text_encoder.dtype if dtype is None else dtype
+        device = text_encoder.device if device is None else device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        all_input_ids = []
+        all_attention_masks = []
+
+        for single_prompt in prompt:
+            messages = [{"role": "user", "content": single_prompt}]
+            text = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+            inputs = tokenizer(
+                text,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=max_sequence_length,
+            )
+
+            all_input_ids.append(inputs["input_ids"])
+            all_attention_masks.append(inputs["attention_mask"])
+
+        input_ids = torch.cat(all_input_ids, dim=0).to(device)
+        attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
+
+        # Forward pass through the model
+        with torch.inference_mode():
+            output = text_encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+                use_cache=False,
+            )
+
+        # Only use outputs from intermediate layers and stack them
+        out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
+        out = out.to(dtype=dtype, device=device)
+
+        batch_size, num_channels, seq_len, hidden_dim = out.shape
+        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+        return prompt_embeds
+
+    def prepare_text_ids(
+        self,
+        x: torch.Tensor,  # (B, L, D)
+        t_coord: Optional[torch.Tensor] = None,
+    ):
+        B, L, _ = x.shape
+        out_ids = []
+
+        for i in range(B):
+            t = torch.arange(1) if t_coord is None else t_coord[i]
+            h = torch.arange(1)
+            w = torch.arange(1)
+            l = torch.arange(L)
+
+            coords = torch.cartesian_prod(t, h, w, l)
+            out_ids.append(coords)
+
+        return torch.stack(out_ids)
+
+    def encode_prompt(
+        self,
+        text_encoder: ZImageTextEncoder,
+        tokenizer: AutoTokenizer,
+        prompt: Union[str, List[str]],
+        dtype = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 512,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt_embeds is None:
+            prompt_embeds = self.get_qwen3_prompt_embeds(
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                prompt=prompt,
+                dtype=dtype,
+                device=device,
+                max_sequence_length=max_sequence_length,
+            )
+
+        batch_size, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        text_ids = self.prepare_text_ids(prompt_embeds)
+        text_ids = text_ids.to(device)
+        return prompt_embeds, text_ids
+
+    def process(self, pipe: Flux2ImagePipeline, prompt):
+        # Check if the Qwen3 text encoder is available
+        if pipe.text_encoder_qwen3 is None:
+            return {}
+
+        pipe.load_models_to_device(self.onload_model_names)
+        prompt_embeds, text_ids = self.encode_prompt(
+            pipe.text_encoder_qwen3, pipe.tokenizer, prompt,
+            dtype=pipe.torch_dtype, device=pipe.device,
+        )
+        return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
+
+
 class Flux2Unit_NoiseInitializer(PipelineUnit):
     def __init__(self):
         super().__init__(
diff --git a/diffsynth/utils/state_dict_converters/z_image_text_encoder.py b/diffsynth/utils/state_dict_converters/z_image_text_encoder.py
new file mode 100644
index 0000000..b114613
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/z_image_text_encoder.py
@@ -0,0 +1,6 @@
+def ZImageTextEncoderStateDictConverter(state_dict):
+    state_dict_ = {}
+    for name in state_dict:
+        if name != "lm_head.weight":
+            state_dict_[name] = state_dict[name]
+    return state_dict_
\ No newline at end of file
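The Qwen3 prompt embedder above stacks three intermediate hidden states and flattens them per token, which is exactly where the klein DiTs' `joint_attention_dim` values come from, assuming klein-4B pairs with the 2560-wide 4B encoder and klein-9B with the 4096-wide 8B encoder as this patch configures. A pure-shape sketch:

```python
# Shape sketch: 3 stacked layers x hidden_size = joint_attention_dim.
import torch

def stack_layers(hidden_states, layers=(9, 18, 27)):
    # hidden_states: tuple of (B, L, H) tensors, one per layer (plus embeddings)
    out = torch.stack([hidden_states[k] for k in layers], dim=1)  # (B, 3, L, H)
    B, C, L, H = out.shape
    return out.permute(0, 2, 1, 3).reshape(B, L, C * H)           # (B, L, 3*H)

fake_4b = tuple(torch.zeros(1, 512, 2560) for _ in range(37))  # 36 layers + embeddings
print(stack_layers(fake_4b).shape)  # (1, 512, 7680)  -> klein-4B joint_attention_dim
fake_8b = tuple(torch.zeros(1, 512, 4096) for _ in range(37))
print(stack_layers(fake_8b).shape)  # (1, 512, 12288) -> klein-9B joint_attention_dim
```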
diff --git a/docs/en/Model_Details/FLUX2.md b/docs/en/Model_Details/FLUX2.md
index fd5e56d..70ccf21 100644
--- a/docs/en/Model_Details/FLUX2.md
+++ b/docs/en/Model_Details/FLUX2.md
@@ -2,6 +2,15 @@
 
 FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
 
+## Model Lineage
+
+```mermaid
+graph LR;
+    FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
+    FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
+    FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
+```
+
 ## Installation
 
 Before using this project for model inference and training, please install DiffSynth-Studio first.
@@ -50,16 +59,18 @@ image.save("image.jpg")
 
 ## Model Overview
 
-| Model ID | Inference | Low VRAM Inference | LoRA Training | Validation After LoRA Training |
-| - | - | - | - | - |
-| [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) | [code](/examples/flux2/model_inference/FLUX.2-dev.py) | [code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py) | [code](/examples/flux2/model_training/lora/FLUX.2-dev.sh) | [code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py) |
+| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
+| - | - | - | - | - | - | - |
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
+|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
 
 Special Training Scripts:
 
-* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
-* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
-* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
-* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
+* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md)
+* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md)
+* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md)
+* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md)
 
 ## Model Inference
 
@@ -135,4 +146,4 @@ We have built a sample image dataset for your testing. You can download this dat
 modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
 ```
 
-We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
\ No newline at end of file
+We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
diff --git a/docs/zh/Model_Details/FLUX2.md b/docs/zh/Model_Details/FLUX2.md
index ad4df27..cf91054 100644
--- a/docs/zh/Model_Details/FLUX2.md
+++ b/docs/zh/Model_Details/FLUX2.md
@@ -2,6 +2,15 @@
 
 FLUX.2 是由 Black Forest Labs 训练并开源的图像生成模型。
 
+## 模型血缘
+
+```mermaid
+graph LR;
+    FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
+    FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
+    FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
+```
+
 ## 安装
 
 在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
@@ -50,16 +59,18 @@ image.save("image.jpg")
 
 ## 模型总览
 
-|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
-|-|-|-|-|-|
-|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
+|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
 
 特殊训练脚本:
 
-* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/)
-* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/)
-* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/)
-* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
+* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)
+* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)
+* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)
+* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)
 
 ## 模型推理
 
@@ -135,4 +146,4 @@ FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/exam
 modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
 ```
 
-我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
\ No newline at end of file
+我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
diff --git a/examples/flux2/model_inference/FLUX.2-klein-4B.py b/examples/flux2/model_inference/FLUX.2-klein-4B.py
new file mode 100644
index 0000000..fbfe33d
--- /dev/null
+++ b/examples/flux2/model_inference/FLUX.2-klein-4B.py
@@ -0,0 +1,17 @@
+from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
+import torch
+
+
+pipe = Flux2ImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
+)
+prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
+image.save("image_FLUX.2-klein-4B.jpg")
diff --git a/examples/flux2/model_inference/FLUX.2-klein-9B.py b/examples/flux2/model_inference/FLUX.2-klein-9B.py
new file mode 100644
index 0000000..2abf0e7
--- /dev/null
+++ b/examples/flux2/model_inference/FLUX.2-klein-9B.py
@@ -0,0 +1,17 @@
+from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
+import torch
+
+
+pipe = Flux2ImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
+)
+prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
+image.save("image_FLUX.2-klein-9B.jpg")
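The validation scripts later in this patch pass explicit resolution and CFG options to the same pipeline call. A variant invocation combining those options (a sketch reusing the `pipe` and `prompt` objects from the example above; the parameter names are taken from the validation scripts):

```python
# Variant call (sketch): slower, CFG-guided sampling at a fixed resolution,
# instead of the fast 4-step setting used in the example above.
image = pipe(
    prompt,
    seed=0,
    num_inference_steps=40,
    cfg_scale=4,
    height=768,
    width=768,
)
image.save("image_FLUX.2-klein-9B_768.jpg")
```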
diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py
new file mode 100644
index 0000000..019f58e
--- /dev/null
+++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py
@@ -0,0 +1,27 @@
+from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
+import torch
+
+
+vram_config = {
+    "offload_dtype": "disk",
+    "offload_device": "disk",
+    "onload_dtype": torch.float8_e4m3fn,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.float8_e4m3fn,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = Flux2ImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
+)
+prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
+image.save("image_FLUX.2-klein-4B.jpg")
diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py
new file mode 100644
index 0000000..b629c94
--- /dev/null
+++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py
@@ -0,0 +1,27 @@
+from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
+import torch
+
+
+vram_config = {
+    "offload_dtype": "disk",
+    "offload_device": "disk",
+    "onload_dtype": torch.float8_e4m3fn,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.float8_e4m3fn,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = Flux2ImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
+        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
+)
+prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
+image.save("image_FLUX.2-klein-9B.jpg")
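The `vram_config` above streams fp8 weight copies from disk with bf16 computation on the GPU. On machines with ample system RAM, a gentler variant might keep bf16 weights resident in CPU memory instead; note this combination is an assumption on my part, as the patch only demonstrates the disk/fp8 configuration shown above:

```python
# Sketch (untested assumption): CPU-resident bf16 offload instead of fp8-on-disk.
import torch

vram_config = {
    "offload_dtype": torch.bfloat16,   # resident copy stays bf16 ...
    "offload_device": "cpu",           # ... in system RAM rather than on disk
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
```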
diff --git a/examples/flux2/model_training/full/FLUX.2-klein-4B.sh b/examples/flux2/model_training/full/FLUX.2-klein-4B.sh
new file mode 100644
index 0000000..4fa46da
--- /dev/null
+++ b/examples/flux2/model_training/full/FLUX.2-klein-4B.sh
@@ -0,0 +1,13 @@
+accelerate launch examples/flux2/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
+  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/FLUX.2-klein-4B_full" \
+  --trainable_models "dit" \
+  --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/full/FLUX.2-klein-9B.sh b/examples/flux2/model_training/full/FLUX.2-klein-9B.sh
new file mode 100644
index 0000000..c89e8f0
--- /dev/null
+++ b/examples/flux2/model_training/full/FLUX.2-klein-9B.sh
@@ -0,0 +1,13 @@
+accelerate launch examples/flux2/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
+  --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/FLUX.2-klein-9B_full" \
+  --trainable_models "dit" \
+  --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh
new file mode 100644
index 0000000..8f897cc
--- /dev/null
+++ b/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh
@@ -0,0 +1,15 @@
+accelerate launch examples/flux2/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
+  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/FLUX.2-klein-4B_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
+  --lora_rank 32 \
+  --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh
new file mode 100644
index 0000000..258c5fe
--- /dev/null
+++ b/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh
@@ -0,0 +1,15 @@
+accelerate launch examples/flux2/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
+  --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/FLUX.2-klein-9B_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
+  --lora_rank 32 \
+  --use_gradient_checkpointing
\ + --output_path "./models/train/FLUX.2-klein-9B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index 30408a1..ea727b8 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -24,7 +24,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule): super().__init__() # Load models model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) - tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/")) self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py new file mode 100644 index 0000000..c5473ab --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, 
width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py new file mode 100644 index 0000000..09ac4bc --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py b/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py new file mode 100644 index 0000000..93fe2fa --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-4B_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py b/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py new file mode 100644 index 0000000..75470bc --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-9B_lora/epoch-4.safetensors") +prompt 
= "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg")