wans2v inference

mi804
2025-08-27 11:51:56 +08:00
parent 04e39f7de5
commit b541b9bed2
7 changed files with 1134 additions and 4 deletions


@@ -15,6 +15,7 @@ from typing_extensions import Literal
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
from ..models import ModelManager, load_state_dict
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_dit_s2v import rope_precompute
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
@@ -49,6 +50,7 @@ class WanVideoPipeline(BasePipeline):
self.units = [
WanVideoUnit_ShapeChecker(),
WanVideoUnit_NoiseInitializer(),
WanVideoUnit_S2V(),
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedderVAE(),
@@ -127,6 +129,8 @@ class WanVideoPipeline(BasePipeline):
torch.nn.LayerNorm: WanAutoCastLayerNorm,
RMSNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.Conv1d: AutoWrappedModule,
torch.nn.Embedding: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -254,6 +258,24 @@ class WanVideoPipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.audio_encoder is not None:
            # TODO: needs verification
dtype = next(iter(self.audio_encoder.parameters())).dtype
enable_vram_management(
self.audio_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=dtype,
computation_device=self.device,
),
)
def initialize_usp(self):
@@ -290,6 +312,7 @@ class WanVideoPipeline(BasePipeline):
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
audio_processor_config: ModelConfig = None,
redirect_common_files: bool = True,
use_usp=False,
):
@@ -332,7 +355,8 @@ class WanVideoPipeline(BasePipeline):
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
pipe.vace = model_manager.fetch_model("wan_video_vace")
pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
# Size division factor
if pipe.vae is not None:
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
@@ -342,7 +366,11 @@ class WanVideoPipeline(BasePipeline):
tokenizer_config.download_if_necessary(use_usp=use_usp)
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
if audio_processor_config is not None:
audio_processor_config.download_if_necessary(use_usp=use_usp)
from transformers import Wav2Vec2Processor
pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
# Unified Sequence Parallel
if use_usp: pipe.enable_usp()
return pipe
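
With the new `audio_processor_config` argument, building a speech-to-video pipeline means passing the S2V checkpoints plus a Wav2Vec2 processor directory. A minimal construction sketch follows; the import path, model IDs, and file patterns are illustrative assumptions, not values taken from this commit:

import torch
# Assumed import path; adjust to wherever WanVideoPipeline and ModelConfig live in your checkout.
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        # Hypothetical model id / file patterns for the Wan2.2 S2V release.
        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/*"),
    ],
    # Loaded via transformers.Wav2Vec2Processor.from_pretrained, so this pattern should point
    # at a directory containing the wav2vec2 preprocessor/tokenizer files.
    audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/*"),
)
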
@@ -361,6 +389,10 @@ class WanVideoPipeline(BasePipeline):
# Video-to-video
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
# Speech-to-video
input_audio: Optional[str] = None,
audio_sample_rate: Optional[int] = 16000,
s2v_pose_video: Optional[list[Image.Image]] = None,
# ControlNet
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
@@ -429,6 +461,7 @@ class WanVideoPipeline(BasePipeline):
"motion_bucket_id": motion_bucket_id,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
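
Downstream, WanVideoUnit_S2V reads `input_audio`, `audio_sample_rate`, and `s2v_pose_video` out of `inputs_shared`, so a speech-to-video call only needs a path to the driving audio plus an optional pose video. A hedged call sketch, reusing `pipe` from the previous example; the prompt, resolution, frame count, and seed are placeholder values, and any keyword not shown in the signature above (the reference-image argument in particular) is an assumption about the rest of the pipeline interface:

from PIL import Image

ref_image = Image.open("reference_portrait.png")  # identity/portrait frame for the first latent
video = pipe(
    prompt="a person speaking to the camera",
    negative_prompt="low quality, distorted face",
    input_image=ref_image,        # assumed keyword for the reference image; not part of this hunk
    input_audio="speech.wav",     # path to the driving audio clip
    audio_sample_rate=16000,      # default from the signature above
    s2v_pose_video=None,          # or a list of PIL.Image frames for pose guidance
    num_frames=81, height=448, width=832,
    seed=0,
)
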
@@ -868,6 +901,67 @@ class WanVideoUnit_CfgMerger(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_S2V(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("audio_encoder", "vae", )
)
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
return {}
pipe.load_models_to_device(["audio_encoder"])
z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True)
audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
z, fps=16, batch_frames=num_frames - 1, m=0
)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).to(pipe.device, pipe.torch_dtype)
if len(audio_embed_bucket.shape) == 3:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
elif len(audio_embed_bucket.shape) == 4:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
audio_embed_bucket = audio_embed_bucket[..., 0:num_frames-1]
return {"audio_input": audio_embed_bucket}
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
pipe.load_models_to_device(["vae"])
        # TODO: may support user-provided motion latents in the future
motion_frames = 73
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
lat_motion_frames = (motion_frames + 3) // 4
motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]}
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
pipe.load_models_to_device(["vae"])
if s2v_pose_video is None:
input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype)
else:
input_video = pipe.preprocess_video(s2v_pose_video)
            # keep at most num_frames frames
input_video = input_video[:, :, :num_frames]
            # pad with blank (-1) frames if the clip is shorter than num_frames
padding_frames = num_frames - input_video.shape[2]
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
# encode to latents
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"pose_cond": input_latents[:,:,1:]}
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("input_audio") is None or pipe.audio_encoder is None or pipe.audio_processor is None:
return inputs_shared, inputs_posi, inputs_nega
        input_audio, audio_sample_rate = inputs_shared.get("input_audio"), inputs_shared.get("audio_sample_rate")
        s2v_pose_video = inputs_shared.get("s2v_pose_video")
        num_frames, height, width = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width")
tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames)
inputs_posi.update(audio_input_positive)
inputs_nega.update({"audio_input": 0.0 * audio_input_positive["audio_input"]})
inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride))
inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride))
return inputs_shared, inputs_posi, inputs_nega
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
@@ -987,6 +1081,10 @@ def model_fn_wan_video(
reference_latents = None,
vace_context = None,
vace_scale = 1.0,
audio_input: Optional[torch.Tensor] = None,
motion_latents: Optional[torch.Tensor] = None,
motion_frames: Optional[list] = None,
pose_cond: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None,
@@ -1024,7 +1122,21 @@ def model_fn_wan_video(
tensor_names=["latents", "y"],
batch_size=2 if cfg_merge else 1
)
    # Wan2.2-S2V: route audio-conditioned inference to the dedicated forward pass
if audio_input is not None:
return model_fn_wans2v(
dit=dit,
latents=latents,
timestep=timestep,
context=context,
audio_input=audio_input,
motion_latents=motion_latents,
motion_frames=motion_frames,
pose_cond=pose_cond,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
use_gradient_checkpointing=use_gradient_checkpointing,
)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
@@ -1143,3 +1255,107 @@ def model_fn_wan_video(
f -= 1
x = dit.unpatchify(x, (f, h, w))
return x
def model_fn_wans2v(
dit,
latents,
timestep,
context,
audio_input,
motion_latents,
motion_frames,
pose_cond,
use_gradient_checkpointing_offload=False,
use_gradient_checkpointing=False
):
    # The first latent frame carries the reference image; the remaining frames are the noisy video latents.
    origin_ref_latents = latents[:, :, 0:1]
    latents = latents[:, :, 1:]
    # Prepend repeated first-frame audio features as padding for the motion frames, then encode.
    audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
    audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input)
    # Drop the portion that corresponds to the motion (latent) frames, keeping only the generated frames.
    audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
    merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
    # Patchify the video latents together with the pose condition; reference-image tokens are appended below.
x = latents
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
seq_lens = torch.tensor([x.size(1)], dtype=torch.long)
grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
ref_grid_sizes = [[
torch.tensor([30, 0, 0]).unsqueeze(0),
torch.tensor([31, rh, rw]).unsqueeze(0),
torch.tensor([1, rh, rw]).unsqueeze(0),
]]
original_seq_len = seq_lens[0]
seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long)
grid_sizes = grid_sizes + ref_grid_sizes
x = torch.cat([x, ref_latents], dim=1)
mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device)
mask[:, -ref_latents.shape[1]:] = 1
b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads
pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None)
x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
x = x + dit.trainable_cond_mask(mask).to(x.dtype)
    # Timestep modulation (t_mod)
if dit.zero_timestep:
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
e0 = dit.time_projection(e).unflatten(1, (6, dit.dim))
if dit.zero_timestep:
e = e[:-1]
zero_e0 = e0[-1:]
e0 = e0[:-1]
e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2)
e0 = [e0, original_seq_len]
    # Text context embedding
context = dit.text_embedding(context)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, e0, pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, e0, pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
x,
use_reentrant=False,
)
else:
x = block(x, context, e0, pre_compute_freqs)
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)
x = x[:, :original_seq_len]
x = dit.head(x, e)
x = dit.unpatchify(x, (f, h, w))
x = torch.cat([origin_ref_latents, x], dim=2)
return x