From cb8de6be1b248c286d5803c7f2895846c3fb2b82 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 3 Sep 2025 12:03:49 +0800 Subject: [PATCH 1/4] move training code to base trainer --- diffsynth/trainers/utils.py | 50 ++++++++++++++++++- examples/flux/model_training/train.py | 35 +++---------- examples/qwen_image/model_training/train.py | 44 +++------------- .../model_training/train_data_process.py | 44 +++------------- examples/wanvideo/model_training/train.py | 35 +++---------- 5 files changed, 79 insertions(+), 129 deletions(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index f0577a2..f133983 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -1,4 +1,6 @@ import imageio, os, torch, warnings, torchvision, argparse, json +from ..utils import ModelConfig +from ..models.utils import load_state_dict from peft import LoraConfig, inject_adapter_in_model from PIL import Image import pandas as pd @@ -424,7 +426,53 @@ class DiffusionTrainingModule(torch.nn.Module): if isinstance(data[key], torch.Tensor): data[key] = data[key].to(device) return data - + + + def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False): + offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None + model_configs = [] + if model_paths is not None: + model_paths = json.loads(model_paths) + model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths] + if model_id_with_origin_paths is not None: + model_id_with_origin_paths = model_id_with_origin_paths.split(",") + model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths] + return model_configs + + + def switch_pipe_to_training_mode( + self, + pipe, + trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None, + enable_fp8_training=False, + ): + # Scheduler + pipe.scheduler.set_timesteps(1000, training=True) + + # Freeze untrainable models + pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) + + # Enable FP8 if pipeline supports + if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"): + pipe._enable_fp8_lora_training(torch.float8_e4m3fn) + + # Add LoRA to the base models + if lora_base_model is not None: + model = self.add_lora_to_model( + getattr(pipe, lora_base_model), + target_modules=lora_target_modules.split(","), + lora_rank=lora_rank, + upcast_dtype=pipe.torch_dtype, + ) + if lora_checkpoint is not None: + state_dict = load_state_dict(lora_checkpoint) + state_dict = self.mapping_lora_state_dict(state_dict) + load_result = model.load_state_dict(state_dict, strict=False) + print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") + if len(load_result[1]) > 0: + print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") + setattr(pipe, lora_base_model, model) class ModelLogger: diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index a0db5c4..11383a5 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -20,37 +20,16 @@ class FluxTrainingModule(DiffusionTrainingModule): ): super().__init__() # Load models - model_configs = [] - if model_paths is not None: - model_paths = json.loads(model_paths) - model_configs += [ModelConfig(path=path) for path in model_paths] - if model_id_with_origin_paths is not None: - model_id_with_origin_paths = model_id_with_origin_paths.split(",") - model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False) self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) - # Reset training scheduler - self.pipe.scheduler.set_timesteps(1000, training=True) + # Training mode + self.switch_pipe_to_training_mode( + self, self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, + enable_fp8_training=False, + ) - # Freeze untrainable models - self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) - - # Add LoRA to the base models - if lora_base_model is not None: - model = self.add_lora_to_model( - getattr(self.pipe, lora_base_model), - target_modules=lora_target_modules.split(","), - lora_rank=lora_rank - ) - if lora_checkpoint is not None: - state_dict = load_state_dict(lora_checkpoint) - state_dict = self.mapping_lora_state_dict(state_dict) - load_result = model.load_state_dict(state_dict, strict=False) - print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") - if len(load_result[1]) > 0: - print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") - setattr(self.pipe, lora_base_model, model) - # Store other configs self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index b89e679..553a25c 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -22,46 +22,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule): ): super().__init__() # Load models - offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None - model_configs = [] - if model_paths is not None: - model_paths = json.loads(model_paths) - model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths] - if model_id_with_origin_paths is not None: - model_id_with_origin_paths = model_id_with_origin_paths.split(",") - model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths] - + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training) tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) + + # Training mode + self.switch_pipe_to_training_mode( + self, self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, + enable_fp8_training=enable_fp8_training, + ) - # Enable FP8 - if enable_fp8_training: - self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn) - - # Reset training scheduler (do it in each training step) - self.pipe.scheduler.set_timesteps(1000, training=True) - - # Freeze untrainable models - self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) - - # Add LoRA to the base models - if lora_base_model is not None: - model = self.add_lora_to_model( - getattr(self.pipe, lora_base_model), - target_modules=lora_target_modules.split(","), - lora_rank=lora_rank, - upcast_dtype=self.pipe.torch_dtype, - ) - if lora_checkpoint is not None: - state_dict = load_state_dict(lora_checkpoint) - state_dict = self.mapping_lora_state_dict(state_dict) - load_result = model.load_state_dict(state_dict, strict=False) - print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") - if len(load_result[1]) > 0: - print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") - setattr(self.pipe, lora_base_model, model) - # Store other configs self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload diff --git a/examples/qwen_image/model_training/train_data_process.py b/examples/qwen_image/model_training/train_data_process.py index 0f4f4fb..9adf968 100644 --- a/examples/qwen_image/model_training/train_data_process.py +++ b/examples/qwen_image/model_training/train_data_process.py @@ -22,46 +22,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule): ): super().__init__() # Load models - offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None - model_configs = [] - if model_paths is not None: - model_paths = json.loads(model_paths) - model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths] - if model_id_with_origin_paths is not None: - model_id_with_origin_paths = model_id_with_origin_paths.split(",") - model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths] - + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training) tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) + + # Training mode + self.switch_pipe_to_training_mode( + self, self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, + enable_fp8_training=enable_fp8_training, + ) - # Enable FP8 - if enable_fp8_training: - self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn) - - # Reset training scheduler (do it in each training step) - self.pipe.scheduler.set_timesteps(1000, training=True) - - # Freeze untrainable models - self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) - - # Add LoRA to the base models - if lora_base_model is not None: - model = self.add_lora_to_model( - getattr(self.pipe, lora_base_model), - target_modules=lora_target_modules.split(","), - lora_rank=lora_rank, - upcast_dtype=self.pipe.torch_dtype, - ) - if lora_checkpoint is not None: - state_dict = load_state_dict(lora_checkpoint) - state_dict = self.mapping_lora_state_dict(state_dict) - load_result = model.load_state_dict(state_dict, strict=False) - print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") - if len(load_result[1]) > 0: - print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") - setattr(self.pipe, lora_base_model, model) - # Store other configs self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 7df70da..f0052dc 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -21,37 +21,16 @@ class WanTrainingModule(DiffusionTrainingModule): ): super().__init__() # Load models - model_configs = [] - if model_paths is not None: - model_paths = json.loads(model_paths) - model_configs += [ModelConfig(path=path) for path in model_paths] - if model_id_with_origin_paths is not None: - model_id_with_origin_paths = model_id_with_origin_paths.split(",") - model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False) self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) - # Reset training scheduler - self.pipe.scheduler.set_timesteps(1000, training=True) + # Training mode + self.switch_pipe_to_training_mode( + self, self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, + enable_fp8_training=False, + ) - # Freeze untrainable models - self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) - - # Add LoRA to the base models - if lora_base_model is not None: - model = self.add_lora_to_model( - getattr(self.pipe, lora_base_model), - target_modules=lora_target_modules.split(","), - lora_rank=lora_rank - ) - if lora_checkpoint is not None: - state_dict = load_state_dict(lora_checkpoint) - state_dict = self.mapping_lora_state_dict(state_dict) - load_result = model.load_state_dict(state_dict, strict=False) - print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") - if len(load_result[1]) > 0: - print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") - setattr(self.pipe, lora_base_model, model) - # Store other configs self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload From 144365b07d76bb25c84ecdf9fc39098fc053fb7e Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 4 Sep 2025 15:18:56 +0800 Subject: [PATCH 2/4] merge data process to training script --- diffsynth/trainers/utils.py | 25 +++- examples/flux/model_training/train.py | 2 +- .../model_training/lora/Qwen-Image-Splited.sh | 5 +- examples/qwen_image/model_training/train.py | 22 ++- .../model_training/train_data_process.py | 126 ------------------ examples/wanvideo/model_training/train.py | 2 +- 6 files changed, 35 insertions(+), 147 deletions(-) delete mode 100644 examples/qwen_image/model_training/train_data_process.py diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index f133983..0711176 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -520,14 +520,26 @@ def launch_training_task( dataset: torch.utils.data.Dataset, model: DiffusionTrainingModule, model_logger: ModelLogger, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, + learning_rate: float = 1e-5, + weight_decay: float = 1e-2, num_workers: int = 8, save_steps: int = None, num_epochs: int = 1, gradient_accumulation_steps: int = 1, find_unused_parameters: bool = False, + args = None, ): + if args is not None: + learning_rate = args.learning_rate + weight_decay = args.weight_decay + num_workers = args.dataset_num_workers + save_steps = args.save_steps + num_epochs = args.num_epochs + gradient_accumulation_steps = args.gradient_accumulation_steps + find_unused_parameters = args.find_unused_parameters + + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) accelerator = Accelerator( gradient_accumulation_steps=gradient_accumulation_steps, @@ -557,8 +569,12 @@ def launch_data_process_task( model: DiffusionTrainingModule, model_logger: ModelLogger, num_workers: int = 8, + args = None, ): - dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + if args is not None: + num_workers = args.dataset_num_workers + + dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers) accelerator = Accelerator() model, dataloader = accelerator.prepare(model, dataloader) @@ -568,7 +584,7 @@ def launch_data_process_task( folder = os.path.join(model_logger.output_path, str(accelerator.process_index)) os.makedirs(folder, exist_ok=True) save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth") - data = model(data) + data = model(data, return_inputs=True) torch.save(data, save_path) @@ -671,4 +687,5 @@ def qwen_image_parser(): parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.") parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.") + parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.") return parser diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index 11383a5..e2a60a7 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -25,7 +25,7 @@ class FluxTrainingModule(DiffusionTrainingModule): # Training mode self.switch_pipe_to_training_mode( - self, self.pipe, trainable_models, + self.pipe, trainable_models, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, enable_fp8_training=False, ) diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh index b456ca1..a5d7cde 100644 --- a/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh @@ -1,11 +1,12 @@ -accelerate launch examples/qwen_image/model_training/train_data_process.py \ +accelerate launch examples/qwen_image/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata.csv \ --max_pixels 1048576 \ --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ --output_path "./models/train/Qwen-Image_lora_cache" \ --use_gradient_checkpointing \ - --dataset_num_workers 8 + --dataset_num_workers 8 \ + --task data_process accelerate launch examples/qwen_image/model_training/train.py \ --dataset_base_path models/train/Qwen-Image_lora_cache \ diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 553a25c..1370b32 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -2,7 +2,7 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.pipelines.flux_image_new import ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -29,7 +29,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): # Training mode self.switch_pipe_to_training_mode( - self, self.pipe, trainable_models, + self.pipe, trainable_models, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, enable_fp8_training=enable_fp8_training, ) @@ -81,9 +81,10 @@ class QwenImageTrainingModule(DiffusionTrainingModule): return {**inputs_shared, **inputs_posi} - def forward(self, data, inputs=None): + def forward(self, data, inputs=None, return_inputs=False): if inputs is None: inputs = self.forward_preprocess(data) else: inputs = self.transfer_data_to_device(inputs, self.pipe.device) + if return_inputs: return inputs models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} loss = self.pipe.training_loss(**models, **inputs) return loss @@ -123,13 +124,8 @@ if __name__ == "__main__": enable_fp8_training=args.enable_fp8_training, ) model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) - scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) - launch_training_task( - dataset, model, model_logger, optimizer, scheduler, - num_epochs=args.num_epochs, - gradient_accumulation_steps=args.gradient_accumulation_steps, - save_steps=args.save_steps, - find_unused_parameters=args.find_unused_parameters, - num_workers=args.dataset_num_workers, - ) + launcher_map = { + "sft": launch_training_task, + "data_process": launch_data_process_task + } + launcher_map[args.task](dataset, model, model_logger, args=args) diff --git a/examples/qwen_image/model_training/train_data_process.py b/examples/qwen_image/model_training/train_data_process.py deleted file mode 100644 index 9adf968..0000000 --- a/examples/qwen_image/model_training/train_data_process.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch, os, json -from diffsynth import load_state_dict -from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig -from diffsynth.pipelines.flux_image_new import ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_data_process_task, qwen_image_parser -from diffsynth.trainers.unified_dataset import UnifiedDataset -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - - -class QwenImageTrainingModule(DiffusionTrainingModule): - def __init__( - self, - model_paths=None, model_id_with_origin_paths=None, - tokenizer_path=None, processor_path=None, - trainable_models=None, - lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, - use_gradient_checkpointing=True, - use_gradient_checkpointing_offload=False, - extra_inputs=None, - enable_fp8_training=False, - ): - super().__init__() - # Load models - model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training) - tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) - processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) - self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) - - # Training mode - self.switch_pipe_to_training_mode( - self, self.pipe, trainable_models, - lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, - enable_fp8_training=enable_fp8_training, - ) - - # Store other configs - self.use_gradient_checkpointing = use_gradient_checkpointing - self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload - self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] - - - def forward_preprocess(self, data): - # CFG-sensitive parameters - inputs_posi = {"prompt": data["prompt"]} - inputs_nega = {"negative_prompt": ""} - - # CFG-unsensitive parameters - inputs_shared = { - # Assume you are using this pipeline for inference, - # please fill in the input parameters. - "input_image": data["image"], - "height": data["image"].size[1], - "width": data["image"].size[0], - # Please do not modify the following parameters - # unless you clearly know what this will cause. - "cfg_scale": 1, - "rand_device": self.pipe.device, - "use_gradient_checkpointing": self.use_gradient_checkpointing, - "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, - "edit_image_auto_resize": True, - } - - # Extra inputs - controlnet_input, blockwise_controlnet_input = {}, {} - for extra_input in self.extra_inputs: - if extra_input.startswith("blockwise_controlnet_"): - blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input] - elif extra_input.startswith("controlnet_"): - controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input] - else: - inputs_shared[extra_input] = data[extra_input] - if len(controlnet_input) > 0: - inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)] - if len(blockwise_controlnet_input) > 0: - inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)] - - # Pipeline units will automatically process the input parameters. - for unit in self.pipe.units: - inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) - return {**inputs_shared, **inputs_posi} - - - def forward(self, data, inputs=None): - if inputs is None: inputs = self.forward_preprocess(data) - return inputs - - - -if __name__ == "__main__": - parser = qwen_image_parser() - args = parser.parse_args() - dataset = UnifiedDataset( - base_path=args.dataset_base_path, - metadata_path=args.dataset_metadata_path, - repeat=1, # Set repeat = 1 - data_file_keys=args.data_file_keys.split(","), - main_data_operator=UnifiedDataset.default_image_operator( - base_path=args.dataset_base_path, - max_pixels=args.max_pixels, - height=args.height, - width=args.width, - height_division_factor=16, - width_division_factor=16, - ) - ) - model = QwenImageTrainingModule( - model_paths=args.model_paths, - model_id_with_origin_paths=args.model_id_with_origin_paths, - tokenizer_path=args.tokenizer_path, - processor_path=args.processor_path, - trainable_models=args.trainable_models, - lora_base_model=args.lora_base_model, - lora_target_modules=args.lora_target_modules, - lora_rank=args.lora_rank, - lora_checkpoint=args.lora_checkpoint, - use_gradient_checkpointing=args.use_gradient_checkpointing, - use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, - extra_inputs=args.extra_inputs, - enable_fp8_training=args.enable_fp8_training, - ) - model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) - launch_data_process_task( - dataset, model, model_logger, - num_workers=args.dataset_num_workers, - ) diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index f0052dc..b811f67 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -26,7 +26,7 @@ class WanTrainingModule(DiffusionTrainingModule): # Training mode self.switch_pipe_to_training_mode( - self, self.pipe, trainable_models, + self.pipe, trainable_models, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, enable_fp8_training=False, ) From d049fb6d1dd682964f55e449874f9780ffdb58cb Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 4 Sep 2025 15:44:37 +0800 Subject: [PATCH 3/4] bugfix --- examples/flux/model_training/train.py | 9 +-------- examples/wanvideo/model_training/train.py | 9 +-------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index e2a60a7..568c77a 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -119,11 +119,4 @@ if __name__ == "__main__": ) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) - launch_training_task( - dataset, model, model_logger, optimizer, scheduler, - num_epochs=args.num_epochs, - gradient_accumulation_steps=args.gradient_accumulation_steps, - save_steps=args.save_steps, - find_unused_parameters=args.find_unused_parameters, - num_workers=args.dataset_num_workers, - ) + launch_training_task(dataset, model, model_logger, args=args) diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index b811f67..b9b7d8c 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -128,11 +128,4 @@ if __name__ == "__main__": ) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) - launch_training_task( - dataset, model, model_logger, optimizer, scheduler, - num_epochs=args.num_epochs, - gradient_accumulation_steps=args.gradient_accumulation_steps, - save_steps=args.save_steps, - find_unused_parameters=args.find_unused_parameters, - num_workers=args.dataset_num_workers, - ) + launch_training_task(dataset, model, model_logger, args=args) From 42ec7b08eb1a895c5c0a3528cca8b2b5504d158e Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 4 Sep 2025 15:45:39 +0800 Subject: [PATCH 4/4] bugfix --- examples/flux/model_training/train.py | 2 -- examples/wanvideo/model_training/train.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index 568c77a..4a82228 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -117,6 +117,4 @@ if __name__ == "__main__": remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x, ) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) - scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) launch_training_task(dataset, model, model_logger, args=args) diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index b9b7d8c..37494e7 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -126,6 +126,4 @@ if __name__ == "__main__": args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt ) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) - scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) launch_training_task(dataset, model, model_logger, args=args)