def _enable_fp8_lora_training(self, dtype):
    """Enable FP8-weight / bf16-compute training for LoRA fine-tuning.

    Wraps every supported submodule of the text encoder, DiT and VAE with the
    VRAM-management proxies so weights are stored in `dtype` (e.g.
    torch.float8_e4m3fn) and upcast to `self.torch_dtype` on the fly for
    computation.  Both offload and onload device are "cuda", so this path
    assumes single-GPU training with no CPU offloading.

    Args:
        dtype: storage dtype for the wrapped weights (typically an FP8 dtype).
    """
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
        Qwen2_5_VLRotaryEmbedding,
        Qwen2RMSNorm,
        Qwen2_5_VisionPatchEmbed,
        Qwen2_5_VisionRotaryEmbedding,
    )
    from ..models.qwen_image_dit import RMSNorm
    from ..models.qwen_image_vae import QwenImageRMS_norm

    # Linear layers get the LoRA-aware wrapper; everything else is wrapped
    # generically.  Insertion order is preserved from the original mapping.
    module_map = {
        RMSNorm: AutoWrappedModule,
        torch.nn.Linear: AutoWrappedLinear,
    }
    generically_wrapped = (
        torch.nn.Conv3d,
        torch.nn.Conv2d,
        torch.nn.Embedding,
        Qwen2_5_VLRotaryEmbedding,
        Qwen2RMSNorm,
        Qwen2_5_VisionPatchEmbed,
        Qwen2_5_VisionRotaryEmbedding,
        QwenImageRMS_norm,
    )
    module_map.update({cls: AutoWrappedModule for cls in generically_wrapped})

    # Store in FP8, compute in the pipeline dtype; everything stays on GPU.
    shared_config = dict(
        offload_dtype=dtype,
        offload_device="cuda",
        onload_dtype=dtype,
        onload_device="cuda",
        computation_dtype=self.torch_dtype,
        computation_device="cuda",
    )
    for component in (self.text_encoder, self.dit, self.vae):
        enable_vram_management(component, module_map=module_map, module_config=shared_config)
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
    """Inject LoRA adapters into `model` and return it.

    Args:
        model: the module to receive LoRA adapters (modified in place and returned).
        target_modules: module-name patterns to adapt (passed to LoraConfig).
        lora_rank: LoRA rank `r`.
        lora_alpha: LoRA scaling alpha; defaults to `lora_rank` when None.
        upcast_dtype: when given, trainable (LoRA) parameters are converted to
            this dtype — used to keep LoRA weights in a high-precision dtype
            while the frozen base weights stay in FP8.
    """
    alpha = lora_rank if lora_alpha is None else lora_alpha
    config = LoraConfig(r=lora_rank, lora_alpha=alpha, target_modules=target_modules)
    model = inject_adapter_in_model(config, model)
    if upcast_dtype is not None:
        # Only LoRA parameters require grad after injection; upcast just those.
        for param in (p for p in model.parameters() if p.requires_grad):
            param.data = param.to(upcast_dtype)
    return model
Only available for LoRA training on a single GPU.") return parser diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 7418661..ee6752d 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -17,21 +17,27 @@ class QwenImageTrainingModule(DiffusionTrainingModule): use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, + enable_fp8_training=False, ): super().__init__() # Load models + offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None model_configs = [] if model_paths is not None: model_paths = json.loads(model_paths) - model_configs += [ModelConfig(path=path) for path in model_paths] + model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths] if model_id_with_origin_paths is not None: model_id_with_origin_paths = model_id_with_origin_paths.split(",") - model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] + model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths] tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) + # Enable FP8 + if enable_fp8_training: + self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn) + # Reset training scheduler (do it in each training step) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -43,7 +49,8 @@ class 
QwenImageTrainingModule(DiffusionTrainingModule): model = self.add_lora_to_model( getattr(self.pipe, lora_base_model), target_modules=lora_target_modules.split(","), - lora_rank=lora_rank + lora_rank=lora_rank, + upcast_dtype=self.pipe.torch_dtype, ) if lora_checkpoint is not None: state_dict = load_state_dict(lora_checkpoint) @@ -126,6 +133,7 @@ if __name__ == "__main__": use_gradient_checkpointing=args.use_gradient_checkpointing, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, + enable_fp8_training=args.enable_fp8_training, ) model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)