Revert "Wan refactor"

This commit is contained in:
Zhongjie Duan
2025-06-11 17:29:27 +08:00
committed by GitHub
parent 8badd63a2d
commit 40760ab88b
216 changed files with 1332 additions and 4567 deletions

View File

@@ -131,10 +131,6 @@ model_loader_configs = [
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),

View File

@@ -1,45 +0,0 @@
import torch
class GeneralLoRALoader:
def __init__(self, device="cpu", torch_dtype=torch.float32):
self.device = device
self.torch_dtype = torch_dtype
def get_name_dict(self, lora_state_dict):
lora_name_dict = {}
for key in lora_state_dict:
if ".lora_B." not in key:
continue
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
if keys[0] == "diffusion_model":
keys.pop(0)
keys.pop(-1)
target_name = ".".join(keys)
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
return lora_name_dict
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
updated_num = 0
lora_name_dict = self.get_name_dict(state_dict_lora)
for name, module in model.named_modules():
if name in lora_name_dict:
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_lora = alpha * torch.mm(weight_up, weight_down)
state_dict = module.state_dict()
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
module.load_state_dict(state_dict)
updated_num += 1
print(f"{updated_num} tensors are updated by LoRA.")

View File

@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
return state_dict
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device=device) as f:
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):

View File

@@ -5,10 +5,6 @@ import math
from typing import Tuple, Optional
from einops import rearrange
from .utils import hash_state_dict_keys
from dchen.camera_adapter import SimpleAdapter
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
@@ -276,9 +272,6 @@ class WanModel(torch.nn.Module):
num_layers: int,
has_image_input: bool,
has_image_pos_emb: bool = False,
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
):
super().__init__()
self.dim = dim
@@ -310,22 +303,10 @@ class WanModel(torch.nn.Module):
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if has_ref_conv:
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
self.has_image_pos_emb = has_image_pos_emb
self.has_ref_conv = has_ref_conv
if add_control_adapter:
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
else:
self.control_adapter = None
def patchify(self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None):
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)
x = [u + v for u, v in zip(x, y_camera)]
x = x[0].unsqueeze(0)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
@@ -551,7 +532,6 @@ class WanModelStateDictConverter:
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
# 1.3B PAI control
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
@@ -566,7 +546,6 @@ class WanModelStateDictConverter:
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
# 14B PAI control
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
@@ -595,74 +574,6 @@ class WanModelStateDictConverter:
"eps": 1e-6,
"has_image_pos_emb": True
}
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
# 1.3B PAI control v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
"has_ref_conv": True
}
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
# 14B PAI control v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_ref_conv": True
}
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
# 1.3B PAI control-camera v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 32,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
"has_ref_conv": False,
"add_control_adapter": True,
"in_dim_control_adapter": 24,
}
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
# 14B PAI control-camera v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 32,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_ref_conv": False,
"add_control_adapter": True,
"in_dim_control_adapter": 24,
}
else:
config = {}
return state_dict, config

View File

@@ -774,11 +774,18 @@ class WanVideoVAE(nn.Module):
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
if tiled:
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_states, device)
return video
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
videos = torch.stack(videos)
return videos
@staticmethod

View File

@@ -68,7 +68,6 @@ class WanVideoPipeline(BasePipeline):
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -238,18 +237,6 @@ class WanVideoPipeline(BasePipeline):
return latents
def prepare_reference_image(self, reference_image, height, width):
if reference_image is not None:
self.load_models_to_device(["vae"])
reference_image = reference_image.resize((width, height))
reference_image = self.preprocess_images([reference_image])
reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
reference_latents = self.vae.encode(reference_image, device=self.device)
return {"reference_latents": reference_latents}
else:
return {}
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
if control_video is not None:
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
@@ -352,7 +339,6 @@ class WanVideoPipeline(BasePipeline):
end_image=None,
input_video=None,
control_video=None,
reference_image=None,
vace_video=None,
vace_video_mask=None,
vace_reference_image=None,
@@ -412,9 +398,6 @@ class WanVideoPipeline(BasePipeline):
else:
image_emb = {}
# Reference image
reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
# ControlNet
if control_video is not None:
self.load_models_to_device(["image_encoder", "vae"])
@@ -452,14 +435,14 @@ class WanVideoPipeline(BasePipeline):
self.dit, motion_controller=self.motion_controller, vace=self.vace,
x=latents, timestep=timestep,
**prompt_emb_posi, **image_emb, **extra_input,
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
)
if cfg_scale != 1.0:
noise_pred_nega = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller, vace=self.vace,
x=latents, timestep=timestep,
**prompt_emb_nega, **image_emb, **extra_input,
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -543,7 +526,6 @@ def model_fn_wan_video(
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
reference_latents = None,
vace_context = None,
vace_scale = 1.0,
tea_cache: TeaCache = None,
@@ -570,12 +552,6 @@ def model_fn_wan_video(
x, (f, h, w) = dit.patchify(x)
# Reference image
if reference_latents is not None:
reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
x = torch.concat([reference_latents, x], dim=1)
f += 1
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
@@ -604,10 +580,6 @@ def model_fn_wan_video(
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
if tea_cache is not None:
tea_cache.store(x)
if reference_latents is not None:
x = x[:, reference_latents.shape[1]:]
f -= 1
x = dit.head(x, t)
if use_unified_sequence_parallel:

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More