mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
wans2v inference
This commit is contained in:
@@ -15,6 +15,7 @@ from typing_extensions import Literal
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_dit_s2v import rope_precompute
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
@@ -49,6 +50,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
WanVideoUnit_NoiseInitializer(),
|
||||
WanVideoUnit_S2V(),
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_PromptEmbedder(),
|
||||
WanVideoUnit_ImageEmbedderVAE(),
|
||||
@@ -127,6 +129,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.Conv1d: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -254,6 +258,24 @@ class WanVideoPipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.audio_encoder is not None:
|
||||
# TODO: need check
|
||||
dtype = next(iter(self.audio_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.audio_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def initialize_usp(self):
|
||||
@@ -290,6 +312,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
|
||||
audio_processor_config: ModelConfig = None,
|
||||
redirect_common_files: bool = True,
|
||||
use_usp=False,
|
||||
):
|
||||
@@ -332,7 +355,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
||||
|
||||
pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
|
||||
|
||||
# Size division factor
|
||||
if pipe.vae is not None:
|
||||
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
||||
@@ -342,7 +366,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
||||
pipe.prompter.fetch_models(pipe.text_encoder)
|
||||
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
||||
|
||||
|
||||
if audio_processor_config is not None:
|
||||
audio_processor_config.download_if_necessary(use_usp=use_usp)
|
||||
from transformers import Wav2Vec2Processor
|
||||
pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
|
||||
# Unified Sequence Parallel
|
||||
if use_usp: pipe.enable_usp()
|
||||
return pipe
|
||||
@@ -361,6 +389,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Video-to-video
|
||||
input_video: Optional[list[Image.Image]] = None,
|
||||
denoising_strength: Optional[float] = 1.0,
|
||||
# Speech-to-video
|
||||
input_audio: Optional[str] = None,
|
||||
audio_sample_rate: Optional[int] = 16000,
|
||||
s2v_pose_video: Optional[list[Image.Image]] = None,
|
||||
# ControlNet
|
||||
control_video: Optional[list[Image.Image]] = None,
|
||||
reference_image: Optional[Image.Image] = None,
|
||||
@@ -429,6 +461,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
"motion_bucket_id": motion_bucket_id,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
||||
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -868,6 +901,67 @@ class WanVideoUnit_CfgMerger(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class WanVideoUnit_S2V(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("audio_encoder", "vae", )
|
||||
)
|
||||
|
||||
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
|
||||
if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(["audio_encoder"])
|
||||
z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True)
|
||||
audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
|
||||
z, fps=16, batch_frames=num_frames - 1, m=0
|
||||
)
|
||||
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).to(pipe.device, pipe.torch_dtype)
|
||||
if len(audio_embed_bucket.shape) == 3:
|
||||
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
|
||||
elif len(audio_embed_bucket.shape) == 4:
|
||||
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
|
||||
audio_embed_bucket = audio_embed_bucket[..., 0:num_frames-1]
|
||||
return {"audio_input": audio_embed_bucket}
|
||||
|
||||
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
|
||||
pipe.load_models_to_device(["vae"])
|
||||
# TODO: may support input motion latents
|
||||
motion_frames = 73
|
||||
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
|
||||
lat_motion_frames = (motion_frames + 3) // 4
|
||||
motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]}
|
||||
|
||||
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
|
||||
pipe.load_models_to_device(["vae"])
|
||||
if s2v_pose_video is None:
|
||||
input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype)
|
||||
else:
|
||||
input_video = pipe.preprocess_video(s2v_pose_video)
|
||||
# get num_frames-1 frames
|
||||
input_video = input_video[:, :, :num_frames]
|
||||
# pad if not enough frames
|
||||
padding_frames = num_frames - input_video.shape[2]
|
||||
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
|
||||
# encode to latents
|
||||
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"pose_cond": input_latents[:,:,1:]}
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("input_audio") is None or pipe.audio_encoder is None or pipe.audio_processor is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
input_audio, audio_sample_rate, s2v_pose_video, num_frames, height, width = inputs_shared.get("input_audio"), inputs_shared.get("audio_sample_rate"), inputs_shared.get("s2v_pose_video"), inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width")
|
||||
tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
|
||||
|
||||
audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames)
|
||||
inputs_posi.update(audio_input_positive)
|
||||
inputs_nega.update({"audio_input": 0.0 * audio_input_positive["audio_input"]})
|
||||
|
||||
inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride))
|
||||
inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride))
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
@@ -987,6 +1081,10 @@ def model_fn_wan_video(
|
||||
reference_latents = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
audio_input: Optional[torch.Tensor] = None,
|
||||
motion_latents: Optional[torch.Tensor] = None,
|
||||
motion_frames: Optional[list] = None,
|
||||
pose_cond: Optional[torch.Tensor] = None,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
@@ -1024,7 +1122,21 @@ def model_fn_wan_video(
|
||||
tensor_names=["latents", "y"],
|
||||
batch_size=2 if cfg_merge else 1
|
||||
)
|
||||
|
||||
# wan2.2 s2v
|
||||
if audio_input is not None:
|
||||
return model_fn_wans2v(
|
||||
dit=dit,
|
||||
latents=latents,
|
||||
timestep=timestep,
|
||||
context=context,
|
||||
audio_input=audio_input,
|
||||
motion_latents=motion_latents,
|
||||
motion_frames=motion_frames,
|
||||
pose_cond=pose_cond,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
)
|
||||
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
@@ -1143,3 +1255,107 @@ def model_fn_wan_video(
|
||||
f -= 1
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
|
||||
def model_fn_wans2v(
|
||||
dit,
|
||||
latents,
|
||||
timestep,
|
||||
context,
|
||||
audio_input,
|
||||
motion_latents,
|
||||
motion_frames,
|
||||
pose_cond,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
use_gradient_checkpointing=False
|
||||
):
|
||||
origin_ref_latents = latents[:, :, 0:1]
|
||||
latents = latents[:, :, 1:]
|
||||
|
||||
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
|
||||
audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input)
|
||||
audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
|
||||
merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
|
||||
|
||||
# reference image
|
||||
x = latents
|
||||
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
|
||||
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
|
||||
|
||||
grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
|
||||
seq_lens = torch.tensor([x.size(1)], dtype=torch.long)
|
||||
grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
|
||||
|
||||
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
|
||||
|
||||
ref_grid_sizes = [[
|
||||
torch.tensor([30, 0, 0]).unsqueeze(0),
|
||||
torch.tensor([31, rh, rw]).unsqueeze(0),
|
||||
torch.tensor([1, rh, rw]).unsqueeze(0),
|
||||
]]
|
||||
original_seq_len = seq_lens[0]
|
||||
seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long)
|
||||
grid_sizes = grid_sizes + ref_grid_sizes
|
||||
|
||||
x = torch.cat([x, ref_latents], dim=1)
|
||||
mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device)
|
||||
mask[:, -ref_latents.shape[1]:] = 1
|
||||
|
||||
b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads
|
||||
pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None)
|
||||
|
||||
x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
|
||||
x = x + dit.trainable_cond_mask(mask).to(x.dtype)
|
||||
|
||||
# t_mod
|
||||
if dit.zero_timestep:
|
||||
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
|
||||
e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
e0 = dit.time_projection(e).unflatten(1, (6, dit.dim))
|
||||
if dit.zero_timestep:
|
||||
e = e[:-1]
|
||||
zero_e0 = e0[-1:]
|
||||
e0 = e0[:-1]
|
||||
e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2)
|
||||
e0 = [e0, original_seq_len]
|
||||
# context
|
||||
context = dit.text_embedding(context)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, e0, pre_compute_freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, e0, pre_compute_freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, e0, pre_compute_freqs)
|
||||
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)
|
||||
|
||||
x = x[:, :original_seq_len]
|
||||
x = dit.head(x, e)
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
x = torch.cat([origin_ref_latents, x], dim=2)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user