Merge branch 'wan-refactor' into wan-refactor
@@ -50,7 +50,11 @@ class VaceWanModel(torch.nn.Module):
         # vace patch embeddings
         self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
 
-    def forward(self, x, vace_context, context, t_mod, freqs):
+    def forward(
+        self, x, vace_context, context, t_mod, freqs,
+        use_gradient_checkpointing: bool = False,
+        use_gradient_checkpointing_offload: bool = False,
+    ):
         c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
         c = [u.flatten(2).transpose(1, 2) for u in c]
         c = torch.cat([
@@ -58,8 +62,27 @@ class VaceWanModel(torch.nn.Module):
                       dim=1) for u in c
         ])
 
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
+
         for block in self.vace_blocks:
-            c = block(c, x, context, t_mod, freqs)
+            if use_gradient_checkpointing_offload:
+                with torch.autograd.graph.save_on_cpu():
+                    c = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        c, x, context, t_mod, freqs,
+                        use_reentrant=False,
+                    )
+            elif use_gradient_checkpointing:
+                c = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    c, x, context, t_mod, freqs,
+                    use_reentrant=False,
+                )
+            else:
+                c = block(c, x, context, t_mod, freqs)
         hints = torch.unbind(c)[:-1]
         return hints
 
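
The new flags trade compute for memory: with checkpointing enabled, each VACE block's activations are dropped during the forward pass and recomputed during backward, and the offload variant additionally parks the tensors saved for backward in CPU memory. A minimal standalone sketch of the same pattern, with toy linear layers standing in for the VACE transformer blocks:

    import torch
    import torch.nn as nn

    blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])
    x = torch.randn(2, 64, requires_grad=True)

    h = x
    for block in blocks:
        with torch.autograd.graph.save_on_cpu():      # offload saved tensors to CPU
            h = torch.utils.checkpoint.checkpoint(
                block, h,
                use_reentrant=False,                  # recompute this block in backward
            )
    h.sum().backward()                                # activations are rebuilt here
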
@@ -1,4 +1,4 @@
-import torch, warnings, glob, os
+import torch, warnings, glob, os, types
 import numpy as np
 from PIL import Image
 from einops import repeat, reduce
@@ -213,6 +213,7 @@ class WanVideoPipeline(BasePipeline):
             WanVideoUnit_FunReference(),
             WanVideoUnit_SpeedControl(),
             WanVideoUnit_VACE(),
+            WanVideoUnit_UnifiedSequenceParallel(),
             WanVideoUnit_TeaCache(),
             WanVideoUnit_CfgMerger(),
         ]
@@ -374,6 +375,30 @@ class WanVideoPipeline(BasePipeline):
             ),
             vram_limit=vram_limit,
         )
+
+
+    def initialize_usp(self):
+        import torch.distributed as dist
+        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
+        dist.init_process_group(backend="nccl", init_method="env://")
+        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
+        initialize_model_parallel(
+            sequence_parallel_degree=dist.get_world_size(),
+            ring_degree=1,
+            ulysses_degree=dist.get_world_size(),
+        )
+        torch.cuda.set_device(dist.get_rank())
+
+
+    def enable_usp(self):
+        from xfuser.core.distributed import get_sequence_parallel_world_size
+        from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+        for block in self.dit.blocks:
+            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+        self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
+        self.sp_size = get_sequence_parallel_world_size()
+        self.use_unified_sequence_parallel = True
 
 
     @staticmethod
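
For context on the types.MethodType calls above: rebinding forward this way swaps in the sequence-parallel implementation per instance, leaving the class untouched, so any other pipeline in the same process keeps the stock forward. A self-contained sketch of the pattern (the Block and sp_forward names are illustrative, not from this commit):

    import types

    class Block:
        def forward(self, x):
            return x + 1

    def sp_forward(self, x):
        # stand-in for usp_attn_forward: same signature, parallel-aware body
        return x + 100

    blk = Block()
    blk.forward = types.MethodType(sp_forward, blk)   # rebind on this instance only
    assert blk.forward(1) == 101
    assert Block().forward(1) == 2                    # other instances unaffected
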
@@ -385,6 +410,7 @@ class WanVideoPipeline(BasePipeline):
         local_model_path: str = "./models",
         skip_download: bool = False,
         redirect_common_files: bool = True,
+        use_usp=False,
     ):
         # Redirect model path
         if redirect_common_files:
@@ -412,6 +438,7 @@ class WanVideoPipeline(BasePipeline):
 
         # Initialize pipeline
         pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
+        if use_usp: pipe.initialize_usp()
         pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
         pipe.dit = model_manager.fetch_model("wan_video_dit")
         pipe.vae = model_manager.fetch_model("wan_video_vae")
@@ -423,6 +450,9 @@ class WanVideoPipeline(BasePipeline):
         tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
         pipe.prompter.fetch_models(pipe.text_encoder)
         pipe.prompter.fetch_tokenizer(tokenizer_config.path)
+
+        # Unified Sequence Parallel
+        if use_usp: pipe.enable_usp()
         return pipe
 
 
@@ -483,11 +513,11 @@ class WanVideoPipeline(BasePipeline):
         # Inputs
         inputs_posi = {
             "prompt": prompt,
-            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
+            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
         }
         inputs_nega = {
             "negative_prompt": negative_prompt,
-            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
+            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
         }
         inputs_shared = {
             "input_image": input_image,
@@ -499,7 +529,7 @@ class WanVideoPipeline(BasePipeline):
             "seed": seed, "rand_device": rand_device,
             "height": height, "width": width, "num_frames": num_frames,
             "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
-            "num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift,
+            "sigma_shift": sigma_shift,
             "motion_bucket_id": motion_bucket_id,
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
             "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
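
Taken together with the previous hunk, this moves num_inference_steps out of inputs_shared and into both per-branch dicts, presumably so that TeaCache, whose other parameters already live there, can read the step count separately for the positive and negative CFG branches. A toy sketch of the merge pattern this implies (values are placeholders):

    inputs_shared = {"height": 480, "width": 832}
    inputs_posi = {"prompt": "a cat", "num_inference_steps": 50}
    inputs_nega = {"negative_prompt": "blurry", "num_inference_steps": 50}

    for branch in (inputs_posi, inputs_nega):
        merged = {**inputs_shared, **branch}   # branch-specific keys win
        assert merged["num_inference_steps"] == 50
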
@@ -620,16 +650,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
 class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
+            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
             onload_model_names=("vae",)
         )
 
-    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
+    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
         if input_video is None:
             return {"latents": noise}
         pipe.load_models_to_device(["vae"])
         input_video = pipe.preprocess_video(input_video)
         input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+        if vace_reference_image is not None:
+            vace_reference_image = pipe.preprocess_video([vace_reference_image])
+            vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
+            input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
         if pipe.scheduler.training:
             return {"latents": noise, "input_latents": input_latents}
         else:
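
With this change, a VACE reference image supplied alongside an input video is encoded on its own and its latents are prepended along the temporal axis (dim=2 of the [B, C, T, H, W] latent). A hedged usage sketch; the keyword names mirror the input_params above, but the pipeline call signature itself is not shown in this diff:

    from PIL import Image

    reference = Image.open("reference.png")    # identity/style anchor for VACE
    video = pipe(
        prompt="a cat running on the grass",
        input_video=frames,                    # assumed: a list of PIL frames
        vace_reference_image=reference,        # its latents are prepended on dim=2
        num_inference_steps=50,
    )
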
@@ -829,6 +863,18 @@ class WanVideoUnit_VACE(PipelineUnit):
 
 
 
+class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
+    def __init__(self):
+        super().__init__(input_params=())
+
+    def process(self, pipe: WanVideoPipeline):
+        if hasattr(pipe, "use_unified_sequence_parallel"):
+            if pipe.use_unified_sequence_parallel:
+                return {"use_unified_sequence_parallel": True}
+        return {}
+
+
+
 class WanVideoUnit_TeaCache(PipelineUnit):
     def __init__(self):
         super().__init__(
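
The two units in this commit suggest the general shape of a pipeline unit: declare the inputs you consume in __init__, and return a dict of values to merge back into the pipeline state from process. A sketch under that assumption (a hypothetical unit, not part of this commit):

    class WanVideoUnit_MyFlag(PipelineUnit):       # hypothetical example unit
        def __init__(self):
            super().__init__(input_params=())      # consumes no pipeline inputs

        def process(self, pipe: WanVideoPipeline):
            # Same probe-and-surface pattern as WanVideoUnit_UnifiedSequenceParallel:
            # expose a flag for downstream units, or return {} to change nothing.
            if getattr(pipe, "my_flag", False):
                return {"my_flag": True}
            return {}

Incidentally, the nested hasattr/if in WanVideoUnit_UnifiedSequenceParallel collapses to a single getattr(pipe, "use_unified_sequence_parallel", False).
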
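
End to end, USP is driven by the new use_usp flag under a multi-process launcher. A hedged sketch; the @staticmethod that builds the pipeline is truncated in this diff, so load_wan_pipeline below is a placeholder for it:

    # launch with: torchrun --nproc_per_node=8 generate.py
    import torch

    pipe = load_wan_pipeline(                  # placeholder for the elided loader
        torch_dtype=torch.bfloat16,
        local_model_path="./models",
        use_usp=True,                          # initialize_usp() runs before models load,
    )                                          # enable_usp() after they are fetched
    video = pipe(prompt="a cat running on the grass", num_inference_steps=50)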