mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Revert "Wan refactor"
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -131,10 +131,6 @@ model_loader_configs = [
|
||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,45 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class GeneralLoRALoader:
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
keys.pop(-1)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
updated_num = 0
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name, module in model.named_modules():
|
||||
if name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
state_dict = module.state_dict()
|
||||
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||
module.load_state_dict(state_dict)
|
||||
updated_num += 1
|
||||
print(f"{updated_num} tensors are updated by LoRA.")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device=device) as f:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -5,10 +5,6 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
from dchen.camera_adapter import SimpleAdapter
|
||||
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
@@ -276,9 +272,6 @@ class WanModel(torch.nn.Module):
|
||||
num_layers: int,
|
||||
has_image_input: bool,
|
||||
has_image_pos_emb: bool = False,
|
||||
has_ref_conv: bool = False,
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -310,22 +303,10 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
||||
if has_ref_conv:
|
||||
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
|
||||
self.has_image_pos_emb = has_image_pos_emb
|
||||
self.has_ref_conv = has_ref_conv
|
||||
|
||||
if add_control_adapter:
|
||||
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||
else:
|
||||
self.control_adapter = None
|
||||
|
||||
def patchify(self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None):
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
if self.control_adapter is not None and control_camera_latents_input is not None:
|
||||
y_camera = self.control_adapter(control_camera_latents_input)
|
||||
x = [u + v for u, v in zip(x, y_camera)]
|
||||
x = x[0].unsqueeze(0)
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
@@ -551,7 +532,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
# 1.3B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
@@ -566,7 +546,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||
# 14B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
@@ -595,74 +574,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6,
|
||||
"has_image_pos_emb": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
|
||||
# 1.3B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
|
||||
# 14B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
|
||||
# 1.3B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
|
||||
# 14B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
@@ -774,11 +774,18 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_states, device)
|
||||
return video
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
videos = torch.stack(videos)
|
||||
return videos
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -68,7 +68,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -238,18 +237,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
return latents
|
||||
|
||||
|
||||
def prepare_reference_image(self, reference_image, height, width):
|
||||
if reference_image is not None:
|
||||
self.load_models_to_device(["vae"])
|
||||
reference_image = reference_image.resize((width, height))
|
||||
reference_image = self.preprocess_images([reference_image])
|
||||
reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
reference_latents = self.vae.encode(reference_image, device=self.device)
|
||||
return {"reference_latents": reference_latents}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if control_video is not None:
|
||||
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -352,7 +339,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
end_image=None,
|
||||
input_video=None,
|
||||
control_video=None,
|
||||
reference_image=None,
|
||||
vace_video=None,
|
||||
vace_video_mask=None,
|
||||
vace_reference_image=None,
|
||||
@@ -412,9 +398,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
else:
|
||||
image_emb = {}
|
||||
|
||||
# Reference image
|
||||
reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
|
||||
|
||||
# ControlNet
|
||||
if control_video is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
@@ -452,14 +435,14 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **image_emb, **extra_input,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **image_emb, **extra_input,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -543,7 +526,6 @@ def model_fn_wan_video(
|
||||
context: torch.Tensor = None,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
reference_latents = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
tea_cache: TeaCache = None,
|
||||
@@ -570,12 +552,6 @@ def model_fn_wan_video(
|
||||
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
# Reference image
|
||||
if reference_latents is not None:
|
||||
reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
|
||||
x = torch.concat([reference_latents, x], dim=1)
|
||||
f += 1
|
||||
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
@@ -604,10 +580,6 @@ def model_fn_wan_video(
|
||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
if reference_latents is not None:
|
||||
x = x[:, reference_latents.shape[1]:]
|
||||
f -= 1
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user