This commit is contained in:
Artiprocher
2025-06-11 18:48:44 +08:00
parent 7d29ee1fbb
commit 6a833c7134
15 changed files with 332 additions and 8 deletions

View File

@@ -50,7 +50,11 @@ class VaceWanModel(torch.nn.Module):
# vace patch embeddings
self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, vace_context, context, t_mod, freqs):
def forward(
self, x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
):
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
@@ -58,8 +62,27 @@ class VaceWanModel(torch.nn.Module):
dim=1) for u in c
])
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.vace_blocks:
c = block(c, x, context, t_mod, freqs)
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
else:
c = block(c, x, context, t_mod, freqs)
hints = torch.unbind(c)[:-1]
return hints

View File

@@ -1,4 +1,4 @@
import torch, warnings, glob, os
import torch, warnings, glob, os, types
import numpy as np
from PIL import Image
from einops import repeat, reduce
@@ -373,6 +373,17 @@ class WanVideoPipeline(BasePipeline):
),
vram_limit=vram_limit,
)
def enable_usp(self):
from xfuser.core.distributed import get_sequence_parallel_world_size
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
for block in self.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@staticmethod
@@ -384,6 +395,7 @@ class WanVideoPipeline(BasePipeline):
local_model_path: str = "./models",
skip_download: bool = False,
redirect_common_files: bool = True,
use_usp=False,
):
# Redirect model path
if redirect_common_files:
@@ -616,16 +628,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_reference_image is not None:
vace_reference_image = pipe.preprocess_video([vace_reference_image])
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else: