support qwen-image-edit

This commit is contained in:
mi804
2025-08-18 16:07:45 +08:00
parent 7dc49bd036
commit 9f6922bba9
14 changed files with 212 additions and 19 deletions

View File

@@ -11,7 +11,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None,
tokenizer_path=None, processor_path=None, edit_model=False,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
use_gradient_checkpointing=True,
@@ -27,11 +27,15 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
if tokenizer_path is not None:
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path))
else:
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
if edit_model:
tokenizer_config = None
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
else:
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
processor_config = None
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
# Reset training scheduler (do it in each training step)
self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -115,6 +119,8 @@ if __name__ == "__main__":
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path,
processor_path=args.processor_path,
edit_model=args.edit_model,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,