mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:23:43 +00:00
Revert "Wan refactor"
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device=device) as f:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -5,10 +5,6 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
from dchen.camera_adapter import SimpleAdapter
|
||||
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
@@ -276,9 +272,6 @@ class WanModel(torch.nn.Module):
|
||||
num_layers: int,
|
||||
has_image_input: bool,
|
||||
has_image_pos_emb: bool = False,
|
||||
has_ref_conv: bool = False,
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -310,22 +303,10 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
||||
if has_ref_conv:
|
||||
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
|
||||
self.has_image_pos_emb = has_image_pos_emb
|
||||
self.has_ref_conv = has_ref_conv
|
||||
|
||||
if add_control_adapter:
|
||||
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||
else:
|
||||
self.control_adapter = None
|
||||
|
||||
def patchify(self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None):
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
if self.control_adapter is not None and control_camera_latents_input is not None:
|
||||
y_camera = self.control_adapter(control_camera_latents_input)
|
||||
x = [u + v for u, v in zip(x, y_camera)]
|
||||
x = x[0].unsqueeze(0)
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
@@ -551,7 +532,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
# 1.3B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
@@ -566,7 +546,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||
# 14B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
@@ -595,74 +574,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6,
|
||||
"has_image_pos_emb": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
|
||||
# 1.3B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
|
||||
# 14B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
|
||||
# 1.3B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
|
||||
# 14B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
@@ -774,11 +774,18 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_states, device)
|
||||
return video
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
videos = torch.stack(videos)
|
||||
return videos
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user