diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index f8e7599..aa21034 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -77,6 +77,8 @@ class LoRAFromCivitai: state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora) elif model_resource == "civitai": state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora) + if isinstance(state_dict_lora, tuple): + state_dict_lora = state_dict_lora[0] if len(state_dict_lora) > 0: print(f" {len(state_dict_lora)} tensors are updated.") for name in state_dict_lora: @@ -96,6 +98,8 @@ class LoRAFromCivitai: converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \ else model.__class__.state_dict_converter().from_civitai state_dict_lora_ = converter_fn(state_dict_lora_) + if isinstance(state_dict_lora_, tuple): + state_dict_lora_ = state_dict_lora_[0] if len(state_dict_lora_) == 0: continue for name in state_dict_lora_: diff --git a/examples/train/README.md b/examples/train/README.md index 50c2c1a..7848fdc 100644 --- a/examples/train/README.md +++ b/examples/train/README.md @@ -156,7 +156,7 @@ After training, use `model_manager.load_lora` to load the LoRA for inference. from diffsynth import ModelManager, FluxImagePipeline import torch -model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", file_path_list=[ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2", @@ -164,7 +164,7 @@ model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" ]) model_manager.load_lora("models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0) -pipe = SDXLImagePipeline.from_model_manager(model_manager) +pipe = FluxImagePipeline.from_model_manager(model_manager) torch.manual_seed(0) image = pipe(