diff --git a/diffsynth/utils/lora/flux.py b/diffsynth/utils/lora/flux.py index 2e1d3fd..97599b6 100644 --- a/diffsynth/utils/lora/flux.py +++ b/diffsynth/utils/lora/flux.py @@ -149,6 +149,8 @@ class FluxLoRALoader(GeneralLoRALoader): dtype=state_dict_[name].dtype) else: state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + + mlp = mlp.to(device=state_dict_[name].device) if 'lora_A' in name: param = torch.concat([ state_dict_.pop(name),