This commit is contained in:
Artiprocher
2025-08-08 18:55:13 +08:00
parent 051ebf3439
commit 89c4e3bdb6
4 changed files with 1 additions and 11 deletions

View File

@@ -1,7 +1,6 @@
import torch, os, json
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
from diffsynth.models.lora import QwenImageLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -106,11 +105,7 @@ if __name__ == "__main__":
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
)
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
launch_training_task(