Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-24 01:48:13 +00:00
refine code
@@ -288,6 +288,9 @@ class WanModel(torch.nn.Module):
         add_control_adapter: bool = False,
         in_dim_control_adapter: int = 24,
         seperated_timestep: bool = False,
+        require_vae_embedding: bool = True,
+        require_clip_embedding: bool = True,
+        fuse_vae_embedding_in_latents: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -295,6 +298,9 @@ class WanModel(torch.nn.Module):
         self.has_image_input = has_image_input
         self.patch_size = patch_size
         self.seperated_timestep = seperated_timestep
+        self.require_vae_embedding = require_vae_embedding
+        self.require_clip_embedding = require_clip_embedding
+        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
 
         self.patch_embedding = nn.Conv3d(
             in_dim, dim, kernel_size=patch_size, stride=patch_size)
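For context, here is a hedged sketch of how a caller might branch on the new constructor flags when assembling model inputs. The function prepare_model_inputs and its argument names are hypothetical illustrations, not part of DiffSynth-Studio.

# Hypothetical sketch (not the actual DiffSynth-Studio pipeline code): how a caller
# might branch on the new flags when assembling inputs for the model.
from types import SimpleNamespace
import torch

def prepare_model_inputs(model, latents, vae_embedding=None, clip_embedding=None):
    inputs = {"x": latents}
    # Only pass the VAE-encoded conditioning frames when the model expects them.
    if getattr(model, "require_vae_embedding", False) and vae_embedding is not None:
        inputs["y"] = vae_embedding
    # Only pass the CLIP image feature when the model expects it.
    if getattr(model, "require_clip_embedding", False) and clip_embedding is not None:
        inputs["clip_feature"] = clip_embedding
    # Fuse the conditioning into the latent tensor instead of passing it separately.
    if getattr(model, "fuse_vae_embedding_in_latents", False) and vae_embedding is not None:
        inputs["x"] = torch.cat([latents, vae_embedding], dim=1)
        inputs.pop("y", None)
    return inputs

# Toy usage with a stand-in object carrying the new attributes.
toy_model = SimpleNamespace(require_vae_embedding=False, require_clip_embedding=False,
                            fuse_vae_embedding_in_latents=True)
inputs = prepare_model_inputs(toy_model, torch.randn(1, 16, 1, 8, 8),
                              vae_embedding=torch.randn(1, 4, 1, 8, 8))
print(inputs["x"].shape)  # torch.Size([1, 20, 1, 8, 8])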
@@ -352,7 +358,6 @@ class WanModel(torch.nn.Module):
         context: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
-        fused_y: Optional[torch.Tensor] = None,
         use_gradient_checkpointing: bool = False,
         use_gradient_checkpointing_offload: bool = False,
         **kwargs,
@@ -366,8 +371,6 @@ class WanModel(torch.nn.Module):
             x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
             clip_embdding = self.img_emb(clip_feature)
             context = torch.cat([clip_embdding, context], dim=1)
-        if fused_y is not None:
-            x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
 
         x, (f, h, w) = self.patchify(x)
 
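The removed fused_y branch was a plain channel concatenation before patchification. Below is a minimal sketch of that operation with illustrative shapes only, assuming any such fusion now happens before forward() is called (for example under fuse_vae_embedding_in_latents).

# Minimal sketch of the channel concatenation the removed branch performed.
# Shapes are illustrative, not taken from the repository.
import torch

b, c_x, c_y, f, h, w = 1, 16, 4, 9, 30, 52
x = torch.randn(b, c_x, f, h, w)   # noisy video latents
y = torch.randn(b, c_y, f, h, w)   # VAE-encoded conditioning frames
x = torch.cat([x, y], dim=1)       # (b, c_x + c_y, f, h, w), as in the kept branch
print(x.shape)                     # torch.Size([1, 20, 9, 30, 52])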
@@ -690,6 +693,9 @@ class WanModelStateDictConverter:
                 "num_layers": 30,
                 "eps": 1e-6,
                 "seperated_timestep": True,
+                "require_clip_embedding": False,
+                "require_vae_embedding": False,
+                "fuse_vae_embedding_in_latents": True,
             }
         elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
             # Wan-AI/Wan2.2-I2V-A14B
@@ -705,6 +711,7 @@ class WanModelStateDictConverter:
                 "num_heads": 40,
                 "num_layers": 40,
                 "eps": 1e-6,
+                "require_clip_embedding": False,
             }
         else:
             config = {}
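The converter dispatches on a hash of the checkpoint's parameter names to pick a model config. Below is a hedged sketch of that pattern; hash_state_dict_keys here is a simplified stand-in for the diffsynth helper (not its actual implementation), and select_config is a hypothetical wrapper.

# Sketch of config dispatch keyed on a hash of the state dict's keys.
import hashlib

def hash_state_dict_keys(state_dict):
    # Simplified stand-in: hash the sorted parameter names.
    joined = ",".join(sorted(state_dict.keys()))
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

def select_config(state_dict):
    if hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
        # Wan-AI/Wan2.2-I2V-A14B: the new flag tells the model to skip the CLIP image branch.
        return {"num_heads": 40, "num_layers": 40, "eps": 1e-6, "require_clip_embedding": False}
    return {}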