refine code

This commit is contained in:
Artiprocher
2025-07-28 11:00:54 +08:00
parent bba44173d2
commit 5f68727ad3
22 changed files with 474 additions and 86 deletions

View File

@@ -288,6 +288,9 @@ class WanModel(torch.nn.Module):
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
seperated_timestep: bool = False,
require_vae_embedding: bool = True,
require_clip_embedding: bool = True,
fuse_vae_embedding_in_latents: bool = False,
):
super().__init__()
self.dim = dim
@@ -295,6 +298,9 @@ class WanModel(torch.nn.Module):
self.has_image_input = has_image_input
self.patch_size = patch_size
self.seperated_timestep = seperated_timestep
self.require_vae_embedding = require_vae_embedding
self.require_clip_embedding = require_clip_embedding
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
@@ -352,7 +358,6 @@ class WanModel(torch.nn.Module):
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
fused_y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
@@ -366,8 +371,6 @@ class WanModel(torch.nn.Module):
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
if fused_y is not None:
x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
x, (f, h, w) = self.patchify(x)
@@ -690,6 +693,9 @@ class WanModelStateDictConverter:
"num_layers": 30,
"eps": 1e-6,
"seperated_timestep": True,
"require_clip_embedding": False,
"require_vae_embedding": False,
"fuse_vae_embedding_in_latents": True,
}
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
# Wan-AI/Wan2.2-I2V-A14B
@@ -705,6 +711,7 @@ class WanModelStateDictConverter:
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"require_clip_embedding": False,
}
else:
config = {}