support qwen-image fp8 lora training

This commit is contained in:
Artiprocher
2025-08-25 20:32:36 +08:00
parent c795e35142
commit ce0b948655
3 changed files with 46 additions and 4 deletions

View File

@@ -338,11 +338,15 @@ class DiffusionTrainingModule(torch.nn.Module):
return trainable_param_names
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
if lora_alpha is None:
lora_alpha = lora_rank
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
model = inject_adapter_in_model(lora_config, model)
if upcast_dtype is not None:
for param in model.parameters():
if param.requires_grad:
param.data = param.to(upcast_dtype)
return model
@@ -555,4 +559,5 @@ def qwen_image_parser():
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
return parser