mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
support diffusers format wan and other lora
This commit is contained in:
@@ -211,6 +211,8 @@ class GeneralLoRAFromPeft:
|
||||
|
||||
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
|
||||
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
|
||||
if torch_dtype == torch.float8_e4m3fn:
|
||||
torch_dtype = torch.float32
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_B." not in key:
|
||||
@@ -228,6 +230,8 @@ class GeneralLoRAFromPeft:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
target_name = ".".join(keys)
|
||||
if target_name.startswith("diffusion_model."):
|
||||
target_name = target_name[len("diffusion_model."):]
|
||||
if target_name not in target_state_dict:
|
||||
return {}
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
@@ -240,10 +244,21 @@ class GeneralLoRAFromPeft:
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
if state_dict_model[name].dtype == torch.float8_e4m3fn:
|
||||
weight = state_dict_model[name].to(torch.float32)
|
||||
lora_weight = state_dict_lora[name].to(
|
||||
dtype=torch.float32,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
state_dict_model[name] = (weight + lora_weight).to(
|
||||
dtype=state_dict_model[name].dtype,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
else:
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user