mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
* support mova inference * mova media_io * add unified audio_video api & fix bug of mono audio input for ltx * support mova train * mova docs * fix bug
461 lines
20 KiB
Python
461 lines
20 KiB
Python
import sys
|
|
import torch, types
|
|
from PIL import Image
|
|
from typing import Optional, Union
|
|
from einops import rearrange
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
from typing import Optional
|
|
|
|
from ..core.device.npu_compatible_device import get_device_type
|
|
from ..diffusion import FlowMatchScheduler
|
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
|
|
|
from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm
|
|
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
|
|
from ..models.wan_video_vae import WanVideoVAE
|
|
from ..models.mova_audio_dit import MovaAudioDit
|
|
from ..models.mova_audio_vae import DacVAE
|
|
from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge
|
|
from ..utils.data.audio import convert_to_mono, resample_waveform
|
|
|
|
|
|
class MovaAudioVideoPipeline(BasePipeline):
    """Pipeline that denoises video and audio latents jointly and decodes both.

    The pipeline holds a text encoder, one or two video DiTs (a high-noise
    model and an optional low-noise model), an audio DiT, a bridge module
    connecting the two towers, and one VAE per modality.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        """Create an empty pipeline; models are attached later by `from_pretrained`.

        Args:
            device: Device forwarded to `BasePipeline`.
            torch_dtype: Computation dtype forwarded to `BasePipeline`.
        """
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
        )
        self.scheduler = FlowMatchScheduler("Wan")
        self.tokenizer: HuggingfaceTokenizer = None
        self.text_encoder: WanTextEncoder = None
        self.video_dit: WanModel = None # high noise model
        self.video_dit2: WanModel = None # low noise model
        self.audio_dit: MovaAudioDit = None
        self.dual_tower_bridge: DualTowerConditionalBridge = None
        self.video_vae: WanVideoVAE = None
        self.audio_vae: DacVAE = None

        # Attribute names of the models needed inside the denoising loop,
        # before and after the switch to the second video DiT.
        self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge")
        self.in_iteration_models_2 = ("video_dit2", "audio_dit", "dual_tower_bridge")

        # Units are run in this order by `__call__` to prepare model inputs.
        self.units = [
            MovaAudioVideoUnit_ShapeChecker(),
            MovaAudioVideoUnit_NoiseInitializer(),
            MovaAudioVideoUnit_InputVideoEmbedder(),
            MovaAudioVideoUnit_InputAudioEmbedder(),
            MovaAudioVideoUnit_PromptEmbedder(),
            MovaAudioVideoUnit_ImageEmbedderVAE(),
            MovaAudioVideoUnit_UnifiedSequenceParallel(),
        ]
        self.model_fn = model_fn_mova_audio_video

    def enable_usp(self):
        """Patch self-attention of every loaded DiT block for unified sequence parallel.

        Towers that are not loaded (for example `video_dit2` when only one
        video DiT was provided) are skipped instead of raising
        `AttributeError` on `None.blocks`.
        """
        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
        for model in (self.video_dit, self.audio_dit, self.video_dit2):
            if model is None:
                continue
            for block in model.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
        use_usp: bool = False,
        vram_limit: float = None,
    ):
        """Build a pipeline and attach the models described by `model_configs`.

        Args:
            torch_dtype: Computation dtype given to the pipeline constructor.
            device: Target device. When `use_usp` is set and
                `torch.distributed` is initialized, it is replaced by the
                value of `get_device_name()`.
            model_configs: Model descriptions forwarded to
                `download_and_load_models`. The list is not modified here.
            tokenizer_config: Location of the tokenizer files; `None` skips
                tokenizer creation.
            use_usp: Initialize and enable unified sequence parallel.
            vram_limit: Forwarded to `download_and_load_models`.

        Returns:
            The configured `MovaAudioVideoPipeline`.
        """
        if use_usp:
            from ..utils.xfuser import initialize_usp
            initialize_usp(device)
            import torch.distributed as dist
            from ..core.device.npu_compatible_device import get_device_name
            if dist.is_available() and dist.is_initialized():
                device = get_device_name()
        # Initialize pipeline
        pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
        # The fetch may yield a list of two DiTs or a single model.
        dit = model_pool.fetch_model("wan_video_dit", index=2)
        if isinstance(dit, list):
            pipe.video_dit, pipe.video_dit2 = dit
        else:
            pipe.video_dit = dit
        pipe.audio_dit = model_pool.fetch_model("mova_audio_dit")
        pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge")
        pipe.video_vae = model_pool.fetch_model("wan_video_vae")
        pipe.audio_vae = model_pool.fetch_model("mova_audio_vae")
        set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))

        # Size division factor
        if pipe.video_vae is not None:
            pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2

        # Initialize tokenizer and processor
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')

        # Unified Sequence Parallel
        if use_usp:
            pipe.enable_usp()

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Video-to-video
        denoising_strength: Optional[float] = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 352,
        width: Optional[int] = 640,
        num_frames: Optional[int] = 81,
        frame_rate: Optional[int] = 24,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.9,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # progress_bar
        progress_bar_cmd=tqdm,
    ):
        """Generate a video and its audio track from a text prompt.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Prompt used for the negative guidance branch.
            input_image: Optional first frame for image-to-video.
            end_image: Optional last frame for first-last-frame-to-video.
            denoising_strength: Forwarded to the scheduler's `set_timesteps`.
            seed: Seed used for both video and audio noise.
            rand_device: Device on which the noise is generated.
            height: Requested frame height; may be adjusted by the shape checker.
            width: Requested frame width; may be adjusted by the shape checker.
            num_frames: Requested frame count; may be adjusted by the shape checker.
            frame_rate: Video frame rate, used to size the audio latents and
                passed to the model function.
            cfg_scale: Classifier-free guidance scale.
            switch_DiT_boundary: Once a timestep falls below
                `switch_DiT_boundary * 1000` and a second video DiT is
                loaded, the loop switches to it.
            num_inference_steps: Number of scheduler steps.
            sigma_shift: Shift forwarded to the scheduler's `set_timesteps`.
            tiled: Whether the video VAE uses tiling.
            tile_size: Tile size forwarded to the video VAE.
            tile_stride: Tile stride forwarded to the video VAE.
            progress_bar_cmd: Callable wrapping the timestep iterable.

        Returns:
            A `(video, audio)` tuple: the outputs of `vae_output_to_video`
            and `output_audio_format_check` respectively.
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Inputs
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "denoising_strength": denoising_strength,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
            "cfg_scale": cfg_scale,
            "sigma_shift": sigma_shift,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch DiT if necessary
            if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and models["video_dit"] is not self.video_dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                # The model function always receives the active DiT as `video_dit`.
                models["video_dit"] = self.video_dit2
            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Scheduler
            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)

        # Decode
        self.load_models_to_device(['video_vae'])
        video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        video = self.vae_output_to_video(video)
        self.load_models_to_device(["audio_vae"])
        audio = self.audio_vae.decode(inputs_shared["audio_latents"])
        audio = self.output_audio_format_check(audio)
        self.load_models_to_device([])
        return video, audio
|
|
|
|
|
|
class MovaAudioVideoUnit_ShapeChecker(PipelineUnit):
    """Pipeline unit that passes the requested shape through the pipeline's size check."""

    def __init__(self):
        shape_keys = ("height", "width", "num_frames")
        super().__init__(
            input_params=shape_keys,
            output_params=shape_keys,
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames):
        """Return the shape values produced by `pipe.check_resize_height_width`."""
        checked = pipe.check_resize_height_width(height, width, num_frames)
        height, width, num_frames = checked
        return {"height": height, "width": width, "num_frames": num_frames}
|
|
|
|
|
|
class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit):
    """Pipeline unit that draws the initial noise for the video and audio latents."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
            output_params=("video_noise", "audio_noise")
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate):
        """Generate noise tensors sized from the two VAEs' attributes.

        Both tensors are generated with the same `seed` and `rand_device`.
        """
        video_vae, audio_vae = pipe.video_vae, pipe.audio_vae

        # Video latent shape: temporal length shrinks by 4, spatial size by
        # the VAE upsampling factor.
        latent_frames = (num_frames - 1) // 4 + 1
        latent_height = height // video_vae.upsampling_factor
        latent_width = width // video_vae.upsampling_factor
        video_shape = (1, video_vae.model.z_dim, latent_frames, latent_height, latent_width)
        video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device)

        # Audio latent length: waveform samples covering the clip, divided
        # (rounding up) by the VAE hop length.
        waveform_samples = int(audio_vae.sample_rate * num_frames / frame_rate)
        audio_steps = (waveform_samples - 1) // int(audio_vae.hop_length) + 1
        audio_shape = (1, audio_vae.latent_dim, audio_steps)
        audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device)

        return {"video_noise": video_noise, "audio_noise": audio_noise}
|
|
|
|
|
|
class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    """Pipeline unit that supplies video latents: plain noise, or an encoded input video when training."""

    def __init__(self):
        super().__init__(
            input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"),
            output_params=("video_latents", "input_latents"),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):
        """Return `video_latents` (the noise) unless a video is given while the scheduler is in training mode.

        In the training case only `input_latents` is returned.
        """
        encode_for_training = input_video is not None and pipe.scheduler.training
        if not encode_for_training:
            return {"video_latents": video_noise}

        pipe.load_models_to_device(self.onload_model_names)
        frames = pipe.preprocess_video(input_video)
        encoded = pipe.video_vae.encode(
            frames, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
        )
        input_latents = encoded.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"input_latents": input_latents}
|
|
|
|
|
|
class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):
    """Pipeline unit that supplies audio latents: plain noise, or an encoded input waveform when training."""

    def __init__(self):
        super().__init__(
            input_params=("input_audio", "audio_noise"),
            output_params=("audio_latents", "audio_input_latents"),
            onload_model_names=("audio_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):
        """Return `audio_latents` (the noise) unless audio is given while the scheduler is in training mode.

        `input_audio` is unpacked as a `(waveform, sample_rate)` pair. In the
        training case only `audio_input_latents` is returned.
        """
        encode_for_training = input_audio is not None and pipe.scheduler.training
        if not encode_for_training:
            return {"audio_latents": audio_noise}

        pipe.load_models_to_device(self.onload_model_names)
        audio_vae = pipe.audio_vae
        waveform, source_rate = input_audio
        # Downmix, then bring the waveform to the VAE's sample rate.
        waveform = convert_to_mono(waveform)
        waveform = resample_waveform(waveform, source_rate, audio_vae.sample_rate)
        waveform = audio_vae.preprocess(waveform.unsqueeze(0), audio_vae.sample_rate)
        # Only the first of the five values returned by `encode` is used.
        posterior, _, _, _, _ = audio_vae.encode(waveform)
        return {"audio_input_latents": posterior.mode()}
|
|
|
|
|
|
class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):
    """Pipeline unit that encodes the positive and negative prompts into `context`."""

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("context",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt):
        """Tokenize `prompt`, run the text encoder and zero the padded positions.

        Args:
            pipe: Pipeline providing `tokenizer`, `text_encoder` and `device`.
            prompt: Prompt text passed to the tokenizer.

        Returns:
            The text encoder output, with every position at or beyond each
            sample's token count set to zero.
        """
        ids, mask = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            add_special_tokens=True,
            return_mask=True,
            return_tensors="pt",
        )
        ids = ids.to(pipe.device)
        mask = mask.to(pipe.device)
        # Number of real (non-padding) tokens per sample.
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = pipe.text_encoder(ids, mask)
        # Zero each sample's own padding only. Writing `[:, v:]` would cut
        # every sample in a batch down to the shortest prompt; for a single
        # prompt both forms are identical.
        for i, v in enumerate(seq_lens):
            prompt_emb[i, v:] = 0
        return prompt_emb

    def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict:
        """Load the text encoder and return the prompt embedding under `context`."""
        pipe.load_models_to_device(self.onload_model_names)
        prompt_emb = self.encode_prompt(pipe, prompt)
        return {"context": prompt_emb}
|
|
|
|
|
|
class MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit):
    """Pipeline unit that builds the conditioning tensor `y` from a first (and optional last) frame."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        """Encode the conditioning frames and prepend a mask marking which frames are given.

        Returns an empty dict when no input image is supplied or the video DiT
        does not require a VAE embedding.
        """
        if input_image is None or not pipe.video_dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)

        device = pipe.device
        mask_h, mask_w = height // 8, width // 8

        first_frame = pipe.preprocess_image(input_image.resize((width, height))).to(device)

        # Per-frame mask: 1 for frames that are provided, 0 otherwise.
        mask = torch.zeros(1, num_frames, mask_h, mask_w, device=device)
        mask[:, :1] = 1

        if end_image is None:
            padding = torch.zeros(3, num_frames - 1, height, width).to(first_frame.device)
            clip_parts = [first_frame.transpose(0, 1), padding]
        else:
            last_frame = pipe.preprocess_image(end_image.resize((width, height))).to(device)
            padding = torch.zeros(3, num_frames - 2, height, width).to(first_frame.device)
            clip_parts = [first_frame.transpose(0, 1), padding, last_frame.transpose(0, 1)]
            mask[:, -1:] = 1
        vae_input = torch.cat(clip_parts, dim=1)

        # Repeat the first frame's mask four times, then fold the frame axis
        # into groups of four.
        leading = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
        mask = torch.cat([leading, mask[:, 1:]], dim=1)
        mask = mask.view(1, mask.shape[1] // 4, 4, mask_h, mask_w)
        mask = mask.transpose(1, 2)[0]

        vae_input = vae_input.to(dtype=pipe.torch_dtype, device=device)
        encoded = pipe.video_vae.encode([vae_input], device=device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        encoded = encoded.to(dtype=pipe.torch_dtype, device=device)

        # Mask goes in front of the encoded frames along the first axis.
        y = torch.cat([mask, encoded]).unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=device)
        return {"y": y}
|
|
|
|
|
|
class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    """Pipeline unit that exposes the pipeline's sequence-parallel flag to the model function."""

    def __init__(self):
        super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))

    def process(self, pipe: MovaAudioVideoPipeline):
        """Report `True` only when the pipeline has the flag set to a truthy value."""
        # The attribute exists only after `enable_usp` has been called.
        enabled = bool(getattr(pipe, "use_unified_sequence_parallel", False))
        return {"use_unified_sequence_parallel": enabled}
|
|
|
|
|
|
def model_fn_mova_audio_video(
    video_dit: WanModel,
    audio_dit: MovaAudioDit,
    dual_tower_bridge: DualTowerConditionalBridge,
    video_latents: torch.Tensor = None,
    audio_latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    y: Optional[torch.Tensor] = None,
    frame_rate: Optional[int] = 24,
    use_unified_sequence_parallel: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    **kwargs,
):
    """Run one forward pass of the video and audio towers with bridge interaction.

    Args:
        video_dit: Video DiT used for this step.
        audio_dit: Audio DiT.
        dual_tower_bridge: Module that exchanges information between towers.
        video_latents: Video latents to denoise.
        audio_latents: Audio latents to denoise.
        timestep: Current timestep tensor, shared by both towers.
        context: Prompt embedding, shared by both towers.
        y: Optional conditioning tensor concatenated to the video latents
            along dim 1.
        frame_rate: Video frame rate handed to the bridge for aligned
            positional frequencies.
        use_unified_sequence_parallel: Split sequences across ranks around
            each block.
        use_gradient_checkpointing: Forwarded to `gradient_checkpoint_forward`.
        use_gradient_checkpointing_offload: Forwarded to
            `gradient_checkpoint_forward`.
        **kwargs: Ignored.

    Returns:
        A `(video_output, audio_output)` tuple after head and unpatchify.
    """
    video_x, audio_x = video_latents, audio_latents
    # First-Last Frame
    if y is not None:
        video_x = torch.cat([video_x, y], dim=1)

    # Timestep
    video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep))
    video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim))
    audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep))
    audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim))

    # Context
    video_context = video_dit.text_embedding(context)
    audio_context = audio_dit.text_embedding(context)

    # Patchify
    video_x = video_dit.patch_embedding(video_x)
    f_v, h, w = video_x.shape[2:]
    video_grid = (f_v, h, w)
    video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous()
    seq_len_video = video_x.shape[1]

    audio_x = audio_dit.patch_embedding(audio_x)
    f_a = audio_x.shape[2]
    audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous()
    seq_len_audio = audio_x.shape[1]

    # Freqs
    video_freq_parts = [
        video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1),
    ]
    video_freqs = torch.cat(video_freq_parts, dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device)
    # All three audio frequency tables are indexed by the same audio position.
    audio_freq_parts = [
        audio_dit.freqs[axis][:f_a].view(f_a, -1).expand(f_a, -1)
        for axis in range(3)
    ]
    audio_freqs = torch.cat(audio_freq_parts, dim=-1).reshape(f_a, 1, -1).to(audio_x.device)

    video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs(
        video_fps=frame_rate,
        grid_size=video_grid,
        audio_steps=seq_len_audio,
        device=video_x.device,
        dtype=video_x.dtype,
    )

    # usp func
    if use_unified_sequence_parallel:
        from ..utils.xfuser import get_current_chunk, gather_all_chunks
    else:
        # Without sequence parallel both helpers are identity functions.
        def get_current_chunk(x, dim=1):
            return x

        def gather_all_chunks(x, seq_len, dim=1):
            return x

    # Forward blocks
    num_paired_blocks = len(audio_dit.blocks)
    for block_id in range(num_paired_blocks):
        # The bridge runs on the full sequences before they are chunked.
        if dual_tower_bridge.should_interact(block_id, "a2v"):
            video_x, audio_x = dual_tower_bridge(
                block_id,
                video_x,
                audio_x,
                x_freqs=video_rope,
                y_freqs=audio_rope,
                condition_scale=1.0,
                video_grid_size=video_grid,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            )

        video_x = get_current_chunk(video_x, dim=1)
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
        video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)

        audio_x = get_current_chunk(audio_x, dim=1)
        audio_x = gradient_checkpoint_forward(
            audio_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            audio_x, audio_context, audio_t_mod, audio_freqs
        )
        audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1)

    # Remaining video-only blocks: chunk once, run them all, gather once.
    video_x = get_current_chunk(video_x, dim=1)
    for block_id in range(num_paired_blocks, len(video_dit.blocks)):
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
    video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)

    # Head
    video_x = video_dit.head(video_x, video_t)
    video_x = video_dit.unpatchify(video_x, video_grid)

    audio_x = audio_dit.head(audio_x, audio_t)
    audio_x = audio_dit.unpatchify(audio_x, (f_a,))
    return video_x, audio_x
|