From 144365b07d76bb25c84ecdf9fc39098fc053fb7e Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 4 Sep 2025 15:18:56 +0800 Subject: [PATCH] merge data process to training script --- diffsynth/trainers/utils.py | 25 +++- examples/flux/model_training/train.py | 2 +- .../model_training/lora/Qwen-Image-Splited.sh | 5 +- examples/qwen_image/model_training/train.py | 22 ++- .../model_training/train_data_process.py | 126 ------------------ examples/wanvideo/model_training/train.py | 2 +- 6 files changed, 35 insertions(+), 147 deletions(-) delete mode 100644 examples/qwen_image/model_training/train_data_process.py diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index f133983..0711176 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -520,14 +520,26 @@ def launch_training_task( dataset: torch.utils.data.Dataset, model: DiffusionTrainingModule, model_logger: ModelLogger, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, + learning_rate: float = 1e-5, + weight_decay: float = 1e-2, num_workers: int = 8, save_steps: int = None, num_epochs: int = 1, gradient_accumulation_steps: int = 1, find_unused_parameters: bool = False, + args = None, ): + if args is not None: + learning_rate = args.learning_rate + weight_decay = args.weight_decay + num_workers = args.dataset_num_workers + save_steps = args.save_steps + num_epochs = args.num_epochs + gradient_accumulation_steps = args.gradient_accumulation_steps + find_unused_parameters = args.find_unused_parameters + + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) accelerator = Accelerator( gradient_accumulation_steps=gradient_accumulation_steps, @@ -557,8 +569,12 @@ def launch_data_process_task( model: DiffusionTrainingModule, model_logger: ModelLogger, num_workers: int = 8, + args = None, ): - dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + if args is not None: + num_workers = args.dataset_num_workers + + dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers) accelerator = Accelerator() model, dataloader = accelerator.prepare(model, dataloader) @@ -568,7 +584,7 @@ def launch_data_process_task( folder = os.path.join(model_logger.output_path, str(accelerator.process_index)) os.makedirs(folder, exist_ok=True) save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth") - data = model(data) + data = model(data, return_inputs=True) torch.save(data, save_path) @@ -671,4 +687,5 @@ def qwen_image_parser(): parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.") parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.") + parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.") return parser diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index 11383a5..e2a60a7 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -25,7 +25,7 @@ class FluxTrainingModule(DiffusionTrainingModule): # Training mode self.switch_pipe_to_training_mode( - self, self.pipe, trainable_models, + self.pipe, trainable_models, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, enable_fp8_training=False, ) diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh index b456ca1..a5d7cde 100644 --- a/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh @@ -1,11 +1,12 @@ -accelerate launch examples/qwen_image/model_training/train_data_process.py \ +accelerate launch examples/qwen_image/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata.csv \ --max_pixels 1048576 \ --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ --output_path "./models/train/Qwen-Image_lora_cache" \ --use_gradient_checkpointing \ - --dataset_num_workers 8 + --dataset_num_workers 8 \ + --task data_process accelerate launch examples/qwen_image/model_training/train.py \ --dataset_base_path models/train/Qwen-Image_lora_cache \ diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 553a25c..1370b32 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -2,7 +2,7 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.pipelines.flux_image_new import ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -29,7 +29,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): # Training mode self.switch_pipe_to_training_mode( - self, self.pipe, trainable_models, + self.pipe, trainable_models, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, enable_fp8_training=enable_fp8_training, ) @@ -81,9 +81,10 @@ class QwenImageTrainingModule(DiffusionTrainingModule): return {**inputs_shared, **inputs_posi} - def forward(self, data, inputs=None): + def forward(self, data, inputs=None, return_inputs=False): if inputs is None: inputs = self.forward_preprocess(data) else: inputs = self.transfer_data_to_device(inputs, self.pipe.device) + if return_inputs: return inputs models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} loss = self.pipe.training_loss(**models, **inputs) return loss @@ -123,13 +124,8 @@ if __name__ == "__main__": enable_fp8_training=args.enable_fp8_training, ) model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) - scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) - launch_training_task( - dataset, model, model_logger, optimizer, scheduler, - num_epochs=args.num_epochs, - gradient_accumulation_steps=args.gradient_accumulation_steps, - save_steps=args.save_steps, - find_unused_parameters=args.find_unused_parameters, - num_workers=args.dataset_num_workers, - ) + launcher_map = { + "sft": launch_training_task, + "data_process": launch_data_process_task + } + launcher_map[args.task](dataset, model, model_logger, args=args) diff --git a/examples/qwen_image/model_training/train_data_process.py b/examples/qwen_image/model_training/train_data_process.py deleted file mode 100644 index 9adf968..0000000 --- a/examples/qwen_image/model_training/train_data_process.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch, os, json -from diffsynth import load_state_dict -from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig -from diffsynth.pipelines.flux_image_new import ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_data_process_task, qwen_image_parser -from diffsynth.trainers.unified_dataset import UnifiedDataset -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - - -class QwenImageTrainingModule(DiffusionTrainingModule): - def __init__( - self, - model_paths=None, model_id_with_origin_paths=None, - tokenizer_path=None, processor_path=None, - trainable_models=None, - lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, - use_gradient_checkpointing=True, - use_gradient_checkpointing_offload=False, - extra_inputs=None, - enable_fp8_training=False, - ): - super().__init__() - # Load models - model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training) - tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) - processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) - self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) - - # Training mode - self.switch_pipe_to_training_mode( - self, self.pipe, trainable_models, - lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, - enable_fp8_training=enable_fp8_training, - ) - - # Store other configs - self.use_gradient_checkpointing = use_gradient_checkpointing - self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload - self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] - - - def forward_preprocess(self, data): - # CFG-sensitive parameters - inputs_posi = {"prompt": data["prompt"]} - inputs_nega = {"negative_prompt": ""} - - # CFG-unsensitive parameters - inputs_shared = { - # Assume you are using this pipeline for inference, - # please fill in the input parameters. - "input_image": data["image"], - "height": data["image"].size[1], - "width": data["image"].size[0], - # Please do not modify the following parameters - # unless you clearly know what this will cause. - "cfg_scale": 1, - "rand_device": self.pipe.device, - "use_gradient_checkpointing": self.use_gradient_checkpointing, - "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, - "edit_image_auto_resize": True, - } - - # Extra inputs - controlnet_input, blockwise_controlnet_input = {}, {} - for extra_input in self.extra_inputs: - if extra_input.startswith("blockwise_controlnet_"): - blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input] - elif extra_input.startswith("controlnet_"): - controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input] - else: - inputs_shared[extra_input] = data[extra_input] - if len(controlnet_input) > 0: - inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)] - if len(blockwise_controlnet_input) > 0: - inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)] - - # Pipeline units will automatically process the input parameters. - for unit in self.pipe.units: - inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) - return {**inputs_shared, **inputs_posi} - - - def forward(self, data, inputs=None): - if inputs is None: inputs = self.forward_preprocess(data) - return inputs - - - -if __name__ == "__main__": - parser = qwen_image_parser() - args = parser.parse_args() - dataset = UnifiedDataset( - base_path=args.dataset_base_path, - metadata_path=args.dataset_metadata_path, - repeat=1, # Set repeat = 1 - data_file_keys=args.data_file_keys.split(","), - main_data_operator=UnifiedDataset.default_image_operator( - base_path=args.dataset_base_path, - max_pixels=args.max_pixels, - height=args.height, - width=args.width, - height_division_factor=16, - width_division_factor=16, - ) - ) - model = QwenImageTrainingModule( - model_paths=args.model_paths, - model_id_with_origin_paths=args.model_id_with_origin_paths, - tokenizer_path=args.tokenizer_path, - processor_path=args.processor_path, - trainable_models=args.trainable_models, - lora_base_model=args.lora_base_model, - lora_target_modules=args.lora_target_modules, - lora_rank=args.lora_rank, - lora_checkpoint=args.lora_checkpoint, - use_gradient_checkpointing=args.use_gradient_checkpointing, - use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, - extra_inputs=args.extra_inputs, - enable_fp8_training=args.enable_fp8_training, - ) - model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) - launch_data_process_task( - dataset, model, model_logger, - num_workers=args.dataset_num_workers, - ) diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index f0052dc..b811f67 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -26,7 +26,7 @@ class WanTrainingModule(DiffusionTrainingModule): # Training mode self.switch_pipe_to_training_mode( - self, self.pipe, trainable_models, + self.pipe, trainable_models, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, enable_fp8_training=False, )