Merge branch 'wan-refactor' into wan-refactor

Zhongjie Duan, 2025-06-12 18:50:50 +08:00, committed by GitHub
28 changed files with 678 additions and 90 deletions

View File

@@ -50,7 +50,11 @@ class VaceWanModel(torch.nn.Module):
         # vace patch embeddings
         self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
 
-    def forward(self, x, vace_context, context, t_mod, freqs):
+    def forward(
+        self, x, vace_context, context, t_mod, freqs,
+        use_gradient_checkpointing: bool = False,
+        use_gradient_checkpointing_offload: bool = False,
+    ):
         c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
         c = [u.flatten(2).transpose(1, 2) for u in c]
         c = torch.cat([
@@ -58,8 +62,27 @@ class VaceWanModel(torch.nn.Module):
             dim=1) for u in c
         ])
 
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
+
         for block in self.vace_blocks:
-            c = block(c, x, context, t_mod, freqs)
+            if use_gradient_checkpointing_offload:
+                with torch.autograd.graph.save_on_cpu():
+                    c = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        c, x, context, t_mod, freqs,
+                        use_reentrant=False,
+                    )
+            elif use_gradient_checkpointing:
+                c = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    c, x, context, t_mod, freqs,
+                    use_reentrant=False,
+                )
+            else:
+                c = block(c, x, context, t_mod, freqs)
 
         hints = torch.unbind(c)[:-1]
         return hints
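
Note: this is standard PyTorch activation checkpointing. During forward, each block's intermediate activations are discarded and recomputed in backward, and the offload variant additionally parks the tensors that must be saved on the CPU via save_on_cpu(). With use_reentrant=False the module could be passed to checkpoint directly; the create_custom_forward wrapper is an idiom carried over from the reentrant API. A minimal self-contained sketch of the same pattern (ToyBlock is a hypothetical stand-in for a transformer block):

import torch

class ToyBlock(torch.nn.Module):
    # Hypothetical stand-in for one transformer block.
    def __init__(self, dim=64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x))

blocks = torch.nn.ModuleList(ToyBlock() for _ in range(4))
x = torch.randn(2, 16, 64, requires_grad=True)

for block in blocks:
    # Activations inside the block are freed after forward and recomputed
    # during backward, trading extra compute for lower peak memory.
    with torch.autograd.graph.save_on_cpu():  # offload variant; drop this line for plain checkpointing
        x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)

x.sum().backward()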

View File

@@ -1,4 +1,4 @@
-import torch, warnings, glob, os
+import torch, warnings, glob, os, types
 import numpy as np
 from PIL import Image
 from einops import repeat, reduce
@@ -213,6 +213,7 @@ class WanVideoPipeline(BasePipeline):
             WanVideoUnit_FunReference(),
             WanVideoUnit_SpeedControl(),
             WanVideoUnit_VACE(),
+            WanVideoUnit_UnifiedSequenceParallel(),
             WanVideoUnit_TeaCache(),
             WanVideoUnit_CfgMerger(),
         ]
@@ -374,6 +375,30 @@ class WanVideoPipeline(BasePipeline):
             ),
             vram_limit=vram_limit,
         )
+
+    def initialize_usp(self):
+        import torch.distributed as dist
+        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
+        dist.init_process_group(backend="nccl", init_method="env://")
+        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
+        initialize_model_parallel(
+            sequence_parallel_degree=dist.get_world_size(),
+            ring_degree=1,
+            ulysses_degree=dist.get_world_size(),
+        )
+        torch.cuda.set_device(dist.get_rank())
+
+    def enable_usp(self):
+        from xfuser.core.distributed import get_sequence_parallel_world_size
+        from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+        for block in self.dit.blocks:
+            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+        self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
+        self.sp_size = get_sequence_parallel_world_size()
+        self.use_unified_sequence_parallel = True
+
     @staticmethod
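
Note: enable_usp patches the loaded model in place rather than subclassing it: types.MethodType binds a free function as an instance method, so only this pipeline's modules are affected, and the replacement forward still reaches the module's parameters through self. A minimal sketch of the pattern (doubled_forward is hypothetical):

import types
import torch

def doubled_forward(self, x):
    # Hypothetical replacement forward; `self` is the patched module
    # instance, so its parameters remain accessible as usual.
    return self.linear(x) * 2

module = torch.nn.Module()
module.linear = torch.nn.Linear(8, 8)

# Rebind forward on this one instance only; other instances of the
# class keep the original method.
module.forward = types.MethodType(doubled_forward, module)

out = module(torch.randn(1, 8))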
@@ -385,6 +410,7 @@ class WanVideoPipeline(BasePipeline):
         local_model_path: str = "./models",
         skip_download: bool = False,
         redirect_common_files: bool = True,
+        use_usp=False,
     ):
         # Redirect model path
         if redirect_common_files:
@@ -412,6 +438,7 @@ class WanVideoPipeline(BasePipeline):
         # Initialize pipeline
         pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
+        if use_usp: pipe.initialize_usp()
         pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
         pipe.dit = model_manager.fetch_model("wan_video_dit")
         pipe.vae = model_manager.fetch_model("wan_video_vae")
@@ -423,6 +450,9 @@ class WanVideoPipeline(BasePipeline):
         tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
         pipe.prompter.fetch_models(pipe.text_encoder)
         pipe.prompter.fetch_tokenizer(tokenizer_config.path)
+
+        # Unified Sequence Parallel
+        if use_usp: pipe.enable_usp()
         return pipe
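
Note: with this flag, multi-GPU inference is opt-in at load time: initialize_usp sets up NCCL and xFuser's sequence parallelism before any weights load, and enable_usp patches the DiT afterwards. A hedged usage sketch, one process per GPU (the import path, prompt, and any arguments other than use_usp are illustrative; assumes a launcher such as torchrun that sets the env:// variables read during initialization):

# torchrun --nproc_per_node=8 inference.py
import torch
from diffsynth.pipelines.wan_video_new import WanVideoPipeline  # assumed module path

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    use_usp=True,  # initialize NCCL + xFuser before models load, patch the DiT after
)
video = pipe(prompt="a cat surfing a wave")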
@@ -483,11 +513,11 @@ class WanVideoPipeline(BasePipeline):
         # Inputs
         inputs_posi = {
             "prompt": prompt,
-            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
+            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
         }
         inputs_nega = {
             "negative_prompt": negative_prompt,
-            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
+            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
         }
         inputs_shared = {
             "input_image": input_image,
@@ -499,7 +529,7 @@ class WanVideoPipeline(BasePipeline):
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
"num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift,
"sigma_shift": sigma_shift,
"motion_bucket_id": motion_bucket_id,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
@@ -620,16 +650,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
 class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
+            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
             onload_model_names=("vae",)
         )
 
-    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
+    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
         if input_video is None:
             return {"latents": noise}
         pipe.load_models_to_device(["vae"])
         input_video = pipe.preprocess_video(input_video)
         input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+        if vace_reference_image is not None:
+            vace_reference_image = pipe.preprocess_video([vace_reference_image])
+            vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
+            input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
         if pipe.scheduler.training:
             return {"latents": noise, "input_latents": input_latents}
         else:
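
Note: the video latents here are laid out as [batch, channels, frames, height, width], so torch.concat along dim=2 prepends the encoded reference image as one extra leading latent frame. A shape-level sketch (sizes are illustrative):

import torch

vace_reference_latents = torch.randn(1, 16, 1, 60, 104)  # one reference "frame"
input_latents = torch.randn(1, 16, 21, 60, 104)          # encoded input video

latents = torch.concat([vace_reference_latents, input_latents], dim=2)
print(latents.shape)  # torch.Size([1, 16, 22, 60, 104])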
@@ -829,6 +863,18 @@ class WanVideoUnit_VACE(PipelineUnit):
+class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
+    def __init__(self):
+        super().__init__(input_params=())
+
+    def process(self, pipe: WanVideoPipeline):
+        if hasattr(pipe, "use_unified_sequence_parallel"):
+            if pipe.use_unified_sequence_parallel:
+                return {"use_unified_sequence_parallel": True}
+        return {}
+
+
 class WanVideoUnit_TeaCache(PipelineUnit):
     def __init__(self):
         super().__init__(