mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
Merge pull request #752 from modelscope/qwen-image-lora-fromat
remove default in qwen-image lora
This commit is contained in:
@@ -383,5 +383,20 @@ class WanLoRAConverter:
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageLoRAConverter:
    """Convert Qwen-Image LoRA state dicts between naming conventions.

    The DiffSynth/PEFT convention inserts an adapter name segment
    (".default.") into ``lora_A`` / ``lora_B`` weight keys; the open-source
    convention omits it. Both conversions only rename keys — parameter
    values are passed through untouched.
    """

    def __init__(self):
        # Stateless converter: both public methods are static.
        pass

    @staticmethod
    def align_to_opensource_format(state_dict, **kwargs):
        """Strip the PEFT adapter segment ".default." from every key.

        Args:
            state_dict: mapping of parameter name -> tensor/value.
            **kwargs: ignored; accepted for a uniform converter interface.

        Returns:
            A new dict with renamed keys and the original values.
        """
        renamed = {}
        for key, value in state_dict.items():
            renamed[key.replace(".default.", ".")] = value
        return renamed

    @staticmethod
    def align_to_diffsynth_format(state_dict, **kwargs):
        """Insert the adapter segment so keys read ".lora_A.default.weight"
        and ".lora_B.default.weight".

        Args:
            state_dict: mapping of parameter name -> tensor/value.
            **kwargs: ignored; accepted for a uniform converter interface.

        Returns:
            A new dict with renamed keys and the original values.
        """
        renamed = {}
        for key, value in state_dict.items():
            new_key = key.replace(".lora_A.weight", ".lora_A.default.weight")
            new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight")
            renamed[new_key] = value
        return renamed
|
||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
    """Return fresh instances of every supported LoRA loader.

    Callers typically try each loader in the listed order until one
    recognizes the checkpoint format.
    """
    loaders = [
        SDLoRAFromCivitai(),
        SDXLLoRAFromCivitai(),
        FluxLoRAFromCivitai(),
        HunyuanVideoLoRAFromCivitai(),
        GeneralLoRAFromPeft(),
    ]
    return loaders
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import torch, os, json
|
import torch, os, json
|
||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
||||||
|
from diffsynth.models.lora import QwenImageLoRAConverter
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@@ -108,6 +109,7 @@ if __name__ == "__main__":
|
|||||||
model_logger = ModelLogger(
|
model_logger = ModelLogger(
|
||||||
args.output_path,
|
args.output_path,
|
||||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||||
)
|
)
|
||||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||||
|
|||||||
Reference in New Issue
Block a user