From cb8de6be1b248c286d5803c7f2895846c3fb2b82 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Wed, 3 Sep 2025 12:03:49 +0800
Subject: [PATCH] move training code to base trainer

---
 diffsynth/trainers/utils.py                 | 50 ++++++++++++++++++-
 examples/flux/model_training/train.py       | 35 +++----------
 examples/qwen_image/model_training/train.py | 44 +++-------------
 .../model_training/train_data_process.py    | 44 +++-------------
 examples/wanvideo/model_training/train.py   | 35 +++----------
 5 files changed, 79 insertions(+), 129 deletions(-)

diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py
index f0577a2..f133983 100644
--- a/diffsynth/trainers/utils.py
+++ b/diffsynth/trainers/utils.py
@@ -1,4 +1,6 @@
 import imageio, os, torch, warnings, torchvision, argparse, json
+from ..utils import ModelConfig
+from ..models.utils import load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 from PIL import Image
 import pandas as pd
@@ -424,7 +426,53 @@ class DiffusionTrainingModule(torch.nn.Module):
             if isinstance(data[key], torch.Tensor):
                 data[key] = data[key].to(device)
         return data
-    
+
+
+    def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
+        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
+        model_configs = []
+        if model_paths is not None:
+            model_paths = json.loads(model_paths)
+            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
+        if model_id_with_origin_paths is not None:
+            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
+            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
+        return model_configs
+
+
+    def switch_pipe_to_training_mode(
+        self,
+        pipe,
+        trainable_models,
+        lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
+        enable_fp8_training=False,
+    ):
+        # Scheduler
+        pipe.scheduler.set_timesteps(1000, training=True)
+
+        # Freeze untrainable models
+        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+        # Enable FP8 if the pipeline supports it
+        if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
+            pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
+
+        # Add LoRA to the base models
+        if lora_base_model is not None:
+            model = self.add_lora_to_model(
+                getattr(pipe, lora_base_model),
+                target_modules=lora_target_modules.split(","),
+                lora_rank=lora_rank,
+                upcast_dtype=pipe.torch_dtype,
+            )
+            if lora_checkpoint is not None:
+                state_dict = load_state_dict(lora_checkpoint)
+                state_dict = self.mapping_lora_state_dict(state_dict)
+                load_result = model.load_state_dict(state_dict, strict=False)
+                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+                if len(load_result[1]) > 0:
+                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+            setattr(pipe, lora_base_model, model)
 
 
 class ModelLogger:
diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py
index a0db5c4..11383a5 100644
--- a/examples/flux/model_training/train.py
+++ b/examples/flux/model_training/train.py
@@ -20,37 +20,16 @@ class FluxTrainingModule(DiffusionTrainingModule):
     ):
         super().__init__()
         # Load models
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
         self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
 
-        # Reset training scheduler
-        self.pipe.scheduler.set_timesteps(1000, training=True)
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=False,
+        )
 
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
-
-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py
index b89e679..553a25c 100644
--- a/examples/qwen_image/model_training/train.py
+++ b/examples/qwen_image/model_training/train.py
@@ -22,46 +22,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
     ):
         super().__init__()
         # Load models
-        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
-
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
         tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
         processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
         self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
+
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=enable_fp8_training,
+        )
 
-        # Enable FP8
-        if enable_fp8_training:
-            self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
-
-        # Reset training scheduler (do it in each training step)
-        self.pipe.scheduler.set_timesteps(1000, training=True)
-
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
-
-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank,
-                upcast_dtype=self.pipe.torch_dtype,
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
diff --git a/examples/qwen_image/model_training/train_data_process.py b/examples/qwen_image/model_training/train_data_process.py
index 0f4f4fb..9adf968 100644
--- a/examples/qwen_image/model_training/train_data_process.py
+++ b/examples/qwen_image/model_training/train_data_process.py
@@ -22,46 +22,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
     ):
         super().__init__()
         # Load models
-        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
-
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
         tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
         processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
         self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
+
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=enable_fp8_training,
+        )
 
-        # Enable FP8
-        if enable_fp8_training:
-            self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
-
-        # Reset training scheduler (do it in each training step)
-        self.pipe.scheduler.set_timesteps(1000, training=True)
-
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
-
-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank,
-                upcast_dtype=self.pipe.torch_dtype,
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index 7df70da..f0052dc 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -21,37 +21,16 @@ class WanTrainingModule(DiffusionTrainingModule):
     ):
         super().__init__()
         # Load models
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
         self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
 
-        # Reset training scheduler
-        self.pipe.scheduler.set_timesteps(1000, training=True)
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=False,
+        )
 
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
-
-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
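
Note for reviewers: below is a minimal sketch of what a pipeline-specific training module looks like on top of the refactored base class. DiffusionTrainingModule, parse_model_configs, and switch_pipe_to_training_mode come from this patch; MyTrainingModule, the pipeline_cls parameter, and the default argument values are hypothetical illustrations, not code in this repository.

    # Hypothetical subclass of the refactored base trainer. pipeline_cls is a
    # stand-in for FluxImagePipeline / QwenImagePipeline / WanVideoPipeline.
    import torch
    from diffsynth.trainers.utils import DiffusionTrainingModule

    class MyTrainingModule(DiffusionTrainingModule):
        def __init__(
            self, pipeline_cls,
            model_paths=None, model_id_with_origin_paths=None,
            trainable_models=None,
            lora_base_model=None, lora_target_modules="to_q,to_k,to_v",
            lora_rank=32, lora_checkpoint=None, enable_fp8_training=False,
        ):
            super().__init__()
            # Build ModelConfig objects from a JSON list of local paths and/or
            # comma-separated "model_id:origin_file_pattern" pairs.
            model_configs = self.parse_model_configs(
                model_paths, model_id_with_origin_paths,
                enable_fp8_training=enable_fp8_training,
            )
            self.pipe = pipeline_cls.from_pretrained(
                torch_dtype=torch.bfloat16, device="cpu",
                model_configs=model_configs,
            )
            # One call now covers the scheduler reset, model freezing, the
            # optional FP8 switch, and LoRA injection that each example
            # previously duplicated inline.
            self.switch_pipe_to_training_mode(
                self.pipe, trainable_models,
                lora_base_model, lora_target_modules, lora_rank,
                lora_checkpoint=lora_checkpoint,
                enable_fp8_training=enable_fp8_training,
            )

Passing the pipeline class in rather than hard-coding it is only for the sketch; the real example scripts keep a fixed pipeline per file, and the FP8 path is guarded by hasattr in the base class, so pipelines without _enable_fp8_lora_training are unaffected.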