diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index f45c146..676af03 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -117,6 +117,7 @@ model_loader_configs = [ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), + (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index 784ec7a..da8302a 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -211,6 +211,8 @@ class GeneralLoRAFromPeft: def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}): device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict) + if torch_dtype == torch.float8_e4m3fn: + torch_dtype = torch.float32 state_dict_ = {} for key in state_dict: if ".lora_B." not in key: @@ -228,6 +230,8 @@ class GeneralLoRAFromPeft: keys.pop(keys.index("lora_B") + 1) keys.pop(keys.index("lora_B")) target_name = ".".join(keys) + if target_name.startswith("diffusion_model."): + target_name = target_name[len("diffusion_model."):] if target_name not in target_state_dict: return {} state_dict_[target_name] = lora_weight.cpu() @@ -240,10 +244,21 @@ class GeneralLoRAFromPeft: if len(state_dict_lora) > 0: print(f" {len(state_dict_lora)} tensors are updated.") for name in state_dict_lora: - state_dict_model[name] += state_dict_lora[name].to( - dtype=state_dict_model[name].dtype, - device=state_dict_model[name].device - ) + if state_dict_model[name].dtype == torch.float8_e4m3fn: + weight = state_dict_model[name].to(torch.float32) + lora_weight = state_dict_lora[name].to( + dtype=torch.float32, + device=state_dict_model[name].device + ) + state_dict_model[name] = (weight + lora_weight).to( + dtype=state_dict_model[name].dtype, + device=state_dict_model[name].device + ) + else: + state_dict_model[name] += state_dict_lora[name].to( + dtype=state_dict_model[name].dtype, + device=state_dict_model[name].device + ) model.load_state_dict(state_dict_model) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 3435391..fa59dc6 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -737,7 +737,80 @@ class WanModelStateDictConverter: pass def from_diffusers(self, state_dict): - return state_dict + rename_dict = {"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b": + config = { + "model_type": "t2v", + "patch_size": (1, 2, 2), + "text_len": 512, + "in_dim": 16, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "window_size": (-1, -1), + "qk_norm": True, + "cross_attn_norm": True, + "eps": 1e-6, + } + else: + config = {} + return state_dict_, config def from_civitai(self, state_dict): if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":