wan-refactor

This commit is contained in:
Artiprocher
2025-06-13 13:46:17 +08:00
parent 436a91e0c9
commit 830b1b7202
125 changed files with 5232 additions and 1341 deletions

View File

@@ -1,6 +1,6 @@
import torch
from .wan_video_dit import DiTBlock
from .utils import hash_state_dict_keys
class VaceWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
@@ -50,7 +50,11 @@ class VaceWanModel(torch.nn.Module):
# vace patch embeddings
self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, vace_context, context, t_mod, freqs):
def forward(
self, x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
):
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
@@ -58,8 +62,27 @@ class VaceWanModel(torch.nn.Module):
dim=1) for u in c
])
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.vace_blocks:
c = block(c, x, context, t_mod, freqs)
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
else:
c = block(c, x, context, t_mod, freqs)
hints = torch.unbind(c)[:-1]
return hints
@@ -74,4 +97,17 @@ class VaceWanModelDictConverter:
def from_civitai(self, state_dict):
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")}
return state_dict_
if hash_state_dict_keys(state_dict_) == '3b2726384e4f64837bdf216eea3f310d': # vace 14B
config = {
"vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
"vace_in_dim": 96,
"patch_size": (1, 2, 2),
"has_image_input": False,
"dim": 5120,
"num_heads": 40,
"ffn_dim": 13824,
"eps": 1e-06,
}
else:
config = {}
return state_dict_, config