From b6ccb362b9e9ee3a14303b82494ae1d6e14e989f Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 19 Jan 2026 16:56:14 +0800 Subject: [PATCH] support flux.2 klein --- diffsynth/configs/model_configs.py | 15 ++ diffsynth/diffusion/training_module.py | 20 ++- diffsynth/models/z_image_text_encoder.py | 95 ++++++++---- diffsynth/pipelines/flux2_image.py | 137 ++++++++++++++++++ .../z_image_text_encoder.py | 6 + .../flux2/model_inference/FLUX.2-klein-4B.py | 17 +++ .../flux2/model_inference/FLUX.2-klein-9B.py | 17 +++ .../FLUX.2-klein-4B.py | 27 ++++ .../FLUX.2-klein-9B.py | 27 ++++ .../model_training/full/FLUX.2-klein-4B.sh | 13 ++ .../model_training/full/FLUX.2-klein-9B.sh | 13 ++ .../model_training/lora/FLUX.2-klein-4B.sh | 15 ++ .../model_training/lora/FLUX.2-klein-9B.sh | 15 ++ examples/flux2/model_training/train.py | 2 +- .../validate_full/FLUX.2-klein-4B.py | 20 +++ .../validate_full/FLUX.2-klein-9B.py | 20 +++ .../validate_lora/FLUX.2-klein-4B.py | 18 +++ .../validate_lora/FLUX.2-klein-9B.py | 18 +++ 18 files changed, 460 insertions(+), 35 deletions(-) create mode 100644 diffsynth/utils/state_dict_converters/z_image_text_encoder.py create mode 100644 examples/flux2/model_inference/FLUX.2-klein-4B.py create mode 100644 examples/flux2/model_inference/FLUX.2-klein-9B.py create mode 100644 examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py create mode 100644 examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py create mode 100644 examples/flux2/model_training/full/FLUX.2-klein-4B.sh create mode 100644 examples/flux2/model_training/full/FLUX.2-klein-9B.sh create mode 100644 examples/flux2/model_training/lora/FLUX.2-klein-4B.sh create mode 100644 examples/flux2/model_training/lora/FLUX.2-klein-9B.sh create mode 100644 examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py create mode 100644 examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py create mode 100644 examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py create mode 
100644 examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index cc23fb9..c93f5e9 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -517,6 +517,21 @@ flux2_series = [ "model_class": "diffsynth.models.flux2_dit.Flux2DiT", "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20} }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "9195f3ea256fcd0ae6d929c203470754", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "8B"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors") + "model_hash": "39c6fc48f07bebecedbbaa971ff466c8", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24} + }, ] z_image_series = [ diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py index e3b3329..b658866 100644 --- a/diffsynth/diffusion/training_module.py +++ b/diffsynth/diffusion/training_module.py @@ -1,4 +1,4 @@ -import torch, json +import torch, json, os from ..core import ModelConfig, load_state_dict from ..utils.controlnet import ControlNetInput from peft import LoraConfig, inject_adapter_in_model @@ -127,15 +127,29 @@ class DiffusionTrainingModule(torch.nn.Module): if model_id_with_origin_paths is not None: model_id_with_origin_paths = model_id_with_origin_paths.split(",") 
for model_id_with_origin_path in model_id_with_origin_paths: - model_id, origin_file_pattern = model_id_with_origin_path.split(":") vram_config = self.parse_vram_config( fp8=model_id_with_origin_path in fp8_models, offload=model_id_with_origin_path in offload_models, device=device ) - model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config)) + config = self.parse_path_or_model_id(model_id_with_origin_path) + model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config)) return model_configs + + def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None): + if model_id_with_origin_path is None: + return default_value + elif os.path.exists(model_id_with_origin_path): + return ModelConfig(path=model_id_with_origin_path) + else: + if ":" not in model_id_with_origin_path: + raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.") + split_id = model_id_with_origin_path.rfind(":") + model_id = model_id_with_origin_path[:split_id] + origin_file_pattern = model_id_with_origin_path[split_id + 1:] + return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern) + def switch_pipe_to_training_mode( self, diff --git a/diffsynth/models/z_image_text_encoder.py b/diffsynth/models/z_image_text_encoder.py index 4eba636..4d6271d 100644 --- a/diffsynth/models/z_image_text_encoder.py +++ b/diffsynth/models/z_image_text_encoder.py @@ -3,38 +3,71 @@ import torch class ZImageTextEncoder(torch.nn.Module): - def __init__(self): + def __init__(self, model_size="4B"): super().__init__() - config = Qwen3Config(**{ - "architectures": [ - "Qwen3ForCausalLM" - ], - "attention_bias": False, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "head_dim": 128, - "hidden_act": "silu", - "hidden_size": 2560, - 
"initializer_range": 0.02, - "intermediate_size": 9728, - "max_position_embeddings": 40960, - "max_window_layers": 36, - "model_type": "qwen3", - "num_attention_heads": 32, - "num_hidden_layers": 36, - "num_key_value_heads": 8, - "rms_norm_eps": 1e-06, - "rope_scaling": None, - "rope_theta": 1000000, - "sliding_window": None, - "tie_word_embeddings": True, - "torch_dtype": "bfloat16", - "transformers_version": "4.51.0", - "use_cache": True, - "use_sliding_window": False, - "vocab_size": 151936 - }) + config_dict = { + "4B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }), + "8B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": False, + "transformers_version": "4.56.1", + "use_cache": True, + "use_sliding_window": False, + 
"vocab_size": 151936 + }) + } + config = config_dict[model_size] self.model = Qwen3Model(config) def forward(self, *args, **kwargs): diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index e94d2c3..b736625 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -14,6 +14,7 @@ from transformers import AutoProcessor, AutoTokenizer from ..models.flux2_text_encoder import Flux2TextEncoder from ..models.flux2_dit import Flux2DiT from ..models.flux2_vae import Flux2VAE +from ..models.z_image_text_encoder import ZImageTextEncoder class Flux2ImagePipeline(BasePipeline): @@ -25,6 +26,7 @@ class Flux2ImagePipeline(BasePipeline): ) self.scheduler = FlowMatchScheduler("FLUX.2") self.text_encoder: Flux2TextEncoder = None + self.text_encoder_qwen3: ZImageTextEncoder = None self.dit: Flux2DiT = None self.vae: Flux2VAE = None self.tokenizer: AutoProcessor = None @@ -32,6 +34,7 @@ class Flux2ImagePipeline(BasePipeline): self.units = [ Flux2Unit_ShapeChecker(), Flux2Unit_PromptEmbedder(), + Flux2Unit_Qwen3PromptEmbedder(), Flux2Unit_NoiseInitializer(), Flux2Unit_InputImageEmbedder(), Flux2Unit_ImageIDs(), @@ -276,6 +279,10 @@ class Flux2Unit_PromptEmbedder(PipelineUnit): return prompt_embeds, text_ids def process(self, pipe: Flux2ImagePipeline, prompt): + # Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder) + if pipe.text_encoder_qwen3 is not None: + return {} + pipe.load_models_to_device(self.onload_model_names) prompt_embeds, text_ids = self.encode_prompt( pipe.text_encoder, pipe.tokenizer, prompt, @@ -284,6 +291,136 @@ class Flux2Unit_PromptEmbedder(PipelineUnit): return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} +class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb", "prompt_emb_mask"), + 
onload_model_names=("text_encoder_qwen3",) + ) + self.hidden_states_layers = (9, 18, 27) # Qwen3 layers + + def get_qwen3_prompt_embeds( + self, + text_encoder: ZImageTextEncoder, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + with torch.inference_mode(): + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + return prompt_embeds + + def prepare_text_ids( + self, + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = 
torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + text_encoder: ZImageTextEncoder, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + dtype = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self.get_qwen3_prompt_embeds( + text_encoder=text_encoder, + tokenizer=tokenizer, + prompt=prompt, + dtype=dtype, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self.prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def process(self, pipe: Flux2ImagePipeline, prompt): + # Check if Qwen3 text encoder is available + if pipe.text_encoder_qwen3 is None: + return {} + + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds, text_ids = self.encode_prompt( + pipe.text_encoder_qwen3, pipe.tokenizer, prompt, + dtype=pipe.torch_dtype, device=pipe.device, + ) + return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} + + class Flux2Unit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( diff --git a/diffsynth/utils/state_dict_converters/z_image_text_encoder.py b/diffsynth/utils/state_dict_converters/z_image_text_encoder.py new file mode 100644 index 0000000..b114613 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/z_image_text_encoder.py @@ -0,0 +1,6 @@ +def ZImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name != "lm_head.weight": + 
state_dict_[name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/examples/flux2/model_inference/FLUX.2-klein-4B.py b/examples/flux2/model_inference/FLUX.2-klein-4B.py new file mode 100644 index 0000000..fbfe33d --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-klein-4B.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) +image.save("image_FLUX.2-klein-4B.jpg") diff --git a/examples/flux2/model_inference/FLUX.2-klein-9B.py b/examples/flux2/model_inference/FLUX.2-klein-9B.py new file mode 100644 index 0000000..2abf0e7 --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-klein-9B.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) +image.save("image_FLUX.2-klein-9B.jpg") diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py new file mode 100644 index 0000000..019f58e --- /dev/null +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) +image.save("image_FLUX.2-klein-4B.jpg") diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py new file mode 100644 index 0000000..b629c94 --- /dev/null +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) +image.save("image_FLUX.2-klein-9B.jpg") diff --git a/examples/flux2/model_training/full/FLUX.2-klein-4B.sh b/examples/flux2/model_training/full/FLUX.2-klein-4B.sh new file mode 100644 index 0000000..4fa46da --- /dev/null +++ b/examples/flux2/model_training/full/FLUX.2-klein-4B.sh @@ -0,0 +1,13 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-klein-4B_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/full/FLUX.2-klein-9B.sh b/examples/flux2/model_training/full/FLUX.2-klein-9B.sh new file mode 100644 index 0000000..c89e8f0 --- /dev/null +++ b/examples/flux2/model_training/full/FLUX.2-klein-9B.sh @@ -0,0 +1,13 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-klein-9B_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh new file mode 100644 index 0000000..8f897cc --- /dev/null +++ b/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-klein-4B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules 
"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh new file mode 100644 index 0000000..258c5fe --- /dev/null +++ b/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-klein-9B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index 30408a1..ea727b8 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -24,7 +24,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule): super().__init__() # Load models model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) - tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + tokenizer_config = self.parse_path_or_model_id(tokenizer_path, 
default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/")) self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py new file mode 100644 index 0000000..c5473ab --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py new file mode 100644 index 0000000..09ac4bc --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, 
ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py b/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py new file mode 100644 index 0000000..93fe2fa --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-4B_lora/epoch-4.safetensors") +prompt = "a dog" 
+image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py b/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py new file mode 100644 index 0000000..75470bc --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-9B_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg")