mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
88 lines · 3.1 KiB · Python
from contextlib import nullcontext

import torch

from .wan_video_dit import DiTBlock
class VaceWanAttentionBlock(DiTBlock):
    """DiT block extended for the VACE control branch.

    The control branch threads a single stacked tensor through checkpointed
    blocks. Each block unpacks that stack, runs the parent DiTBlock on the
    running control state, and pushes back a projected skip "hint" plus the
    updated state, so every block contributes one hint to the final output.
    """

    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
        # Parent DiTBlock supplies attention/FFN and is expected to set
        # `self.dim` (used below) -- presumably equal to `dim`; confirm
        # against wan_video_dit.DiTBlock.
        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
        # Position of this block within the control branch; 0 marks the entry
        # block, which fuses the control signal with the main hidden states.
        self.block_id = block_id
        if block_id == 0:
            # Only the entry block projects the raw control embedding before
            # mixing it with x in forward().
            self.before_proj = torch.nn.Linear(self.dim, self.dim)
        # Every block projects its output into a skip hint.
        # NOTE(review): source indentation was lost in extraction; after_proj
        # is placed outside the `if` because forward() calls it on every
        # block regardless of block_id -- confirm against upstream.
        self.after_proj = torch.nn.Linear(self.dim, self.dim)

    def forward(self, c, x, context, t_mod, freqs):
        """Advance the control stream one block.

        c:     for block_id == 0, the raw control embedding; otherwise a
               stacked tensor [hint_0, ..., hint_{k-1}, running_state].
        x:     main-branch hidden states (fused into c at the entry block).
        context, t_mod, freqs: passed through to the parent DiTBlock.
        Returns the stack [hint_0, ..., hint_k, new_running_state].
        """
        if self.block_id == 0:
            # Entry block: mix projected control signal into the main stream
            # and start an empty hint list.
            c = self.before_proj(c) + x
            all_c = []
        else:
            # Split the incoming stack; the running state is the last entry.
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)
        c = super().forward(c, context, t_mod, freqs)
        c_skip = self.after_proj(c)
        # Append this block's hint and the updated state, then restack so the
        # whole thing flows through checkpoint() as a single tensor.
        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c
|
|
|
|
|
|
class VaceWanModel(torch.nn.Module):
    """Side-branch controller for VACE video conditioning.

    Patchifies a control video (`vace_context`) with a Conv3d embedding and
    runs it through a stack of `VaceWanAttentionBlock`s. Each block emits one
    projected "hint" tensor; the main DiT injects hint n at the layer index
    `vace_layers[n]` (see `vace_layers_mapping`).
    """

    def __init__(
        self,
        vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
        vace_in_dim=96,
        patch_size=(1, 2, 2),
        has_image_input=False,
        dim=1536,
        num_heads=12,
        ffn_dim=8960,
        eps=1e-6,
    ):
        super().__init__()
        self.vace_layers = vace_layers
        self.vace_in_dim = vace_in_dim
        # Main-model layer index -> position of its hint in forward()'s output.
        self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}

        # vace blocks; block_id == 0 marks the entry block, so this relies on
        # vace_layers[0] == 0 -- TODO confirm for non-default configs.
        self.vace_blocks = torch.nn.ModuleList([
            VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
            for i in self.vace_layers
        ])

        # vace patch embeddings: non-overlapping 3D patchify (stride == kernel).
        self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)

    def forward(
        self, x, vace_context, context, t_mod, freqs,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ):
        """Compute per-layer control hints.

        x:            main-branch hidden states; only its sequence length
                      (x.shape[1]) is used here, for padding.
        vace_context: iterable of control clips, each patchified separately.
        context, t_mod, freqs: forwarded unchanged to each block.
        use_gradient_checkpointing: checkpoint each block to save memory.
        use_gradient_checkpointing_offload: additionally stash saved
                      activations on CPU (implies checkpointing).
        Returns a tuple of hint tensors, one per entry of `vace_layers`.
        """
        # Patchify each control clip and flatten to (1, seq, dim).
        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
        c = [u.flatten(2).transpose(1, 2) for u in c]
        # Zero-pad every clip to the main stream's sequence length, then
        # concatenate along the batch axis. Assumes u.size(1) <= x.shape[1]
        # -- TODO confirm against callers.
        c = torch.cat([
            torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
                      dim=1) for u in c
        ])

        def create_custom_forward(module):
            # checkpoint() wants a plain callable; close over the module.
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.vace_blocks:
            if use_gradient_checkpointing_offload or use_gradient_checkpointing:
                # Single checkpoint call for both modes; offload mode wraps it
                # so saved tensors live on CPU instead of GPU.
                offload_ctx = (
                    torch.autograd.graph.save_on_cpu()
                    if use_gradient_checkpointing_offload
                    else nullcontext()
                )
                with offload_ctx:
                    c = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        c, x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                c = block(c, x, context, t_mod, freqs)
        # Each block stacked [hint_0, ..., hint_k, running_state]; drop the
        # running state and return only the hints.
        hints = torch.unbind(c)[:-1]
        return hints
|