mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Mova (#1337)
* support mova inference * mova media_io * add unified audio_video api & fix bug of mono audio input for ltx * support mova train * mova docs * fix bug
This commit is contained in:
@@ -22,6 +22,7 @@ from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Voco
|
||||
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
||||
from ..utils.data.audio import convert_to_stereo
|
||||
|
||||
|
||||
class LTX2AudioVideoPipeline(BasePipeline):
|
||||
@@ -389,6 +390,7 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
||||
return {"audio_latents": audio_noise}
|
||||
else:
|
||||
input_audio, sample_rate = input_audio
|
||||
input_audio = convert_to_stereo(input_audio)
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
|
||||
audio_input_latents = pipe.audio_vae_encoder(input_audio)
|
||||
@@ -441,6 +443,7 @@ class LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit):
|
||||
return {}
|
||||
else:
|
||||
input_audio, sample_rate = retake_audio
|
||||
input_audio = convert_to_stereo(input_audio)
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
input_latents_audio = pipe.audio_vae_encoder(input_audio)
|
||||
|
||||
460
diffsynth/pipelines/mova_audio_video.py
Normal file
460
diffsynth/pipelines/mova_audio_video.py
Normal file
@@ -0,0 +1,460 @@
|
||||
import sys
|
||||
import torch, types
|
||||
from PIL import Image
|
||||
from typing import Optional, Union
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
|
||||
from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.mova_audio_dit import MovaAudioDit
|
||||
from ..models.mova_audio_vae import DacVAE
|
||||
from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge
|
||||
from ..utils.data.audio import convert_to_mono, resample_waveform
|
||||
|
||||
|
||||
class MovaAudioVideoPipeline(BasePipeline):
    """Joint audio+video generation pipeline for the MOVA dual-tower model.

    Combines a Wan video DiT (optionally a high-noise/low-noise pair), a MOVA
    audio DiT, and a dual-tower bridge that lets the two towers exchange
    information during denoising. Video is decoded with the Wan video VAE and
    audio with a DAC VAE.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        # NOTE(review): `device=get_device_type()` is evaluated once at import
        # time; confirm this is intended if devices can change per-process.
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
        )
        self.scheduler = FlowMatchScheduler("Wan")
        # Model slots; populated by `from_pretrained`.
        self.tokenizer: HuggingfaceTokenizer = None
        self.text_encoder: WanTextEncoder = None
        self.video_dit: WanModel = None # high noise model
        self.video_dit2: WanModel = None # low noise model
        self.audio_dit: MovaAudioDit = None
        self.dual_tower_bridge: DualTowerConditionalBridge = None
        self.video_vae: WanVideoVAE = None
        self.audio_vae: DacVAE = None

        # Model groups kept on-device during the denoising loop; the second
        # group swaps in the low-noise video DiT after the boundary timestep.
        self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge")
        self.in_iteration_models_2 = ("video_dit2", "audio_dit", "dual_tower_bridge")

        # Preprocessing units executed in order before denoising.
        self.units = [
            MovaAudioVideoUnit_ShapeChecker(),
            MovaAudioVideoUnit_NoiseInitializer(),
            MovaAudioVideoUnit_InputVideoEmbedder(),
            MovaAudioVideoUnit_InputAudioEmbedder(),
            MovaAudioVideoUnit_PromptEmbedder(),
            MovaAudioVideoUnit_ImageEmbedderVAE(),
            MovaAudioVideoUnit_UnifiedSequenceParallel(),
        ]
        self.model_fn = model_fn_mova_audio_video

    def enable_usp(self):
        """Enable unified sequence parallelism by patching every block's
        self-attention forward with the xfuser USP implementation."""
        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
        # NOTE(review): assumes video_dit2 is present when USP is enabled;
        # `self.video_dit2.blocks` raises AttributeError otherwise — confirm.
        for block in self.video_dit.blocks + self.audio_dit.blocks + self.video_dit2.blocks:
            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        # NOTE(review): mutable default argument; safe only if
        # download_and_load_models never mutates it — verify.
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
        use_usp: bool = False,
        vram_limit: float = None,
    ):
        """Download/load all MOVA sub-models and return a ready pipeline.

        Args:
            torch_dtype: compute dtype for all loaded models.
            device: target device; overridden by the distributed device name
                when USP is initialized under torch.distributed.
            model_configs: ModelConfig list describing which checkpoints to load.
            tokenizer_config: where to fetch the T5 tokenizer; None to skip.
            use_usp: enable unified sequence parallelism (multi-GPU).
            vram_limit: optional VRAM budget forwarded to model loading.
        """
        if use_usp:
            from ..utils.xfuser import initialize_usp
            initialize_usp(device)
            import torch.distributed as dist
            from ..core.device.npu_compatible_device import get_device_name
            if dist.is_available() and dist.is_initialized():
                device = get_device_name()
        # Initialize pipeline
        pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
        # index=2 requests up to two video DiTs (high-noise + low-noise pair).
        dit = model_pool.fetch_model("wan_video_dit", index=2)
        if isinstance(dit, list):
            pipe.video_dit, pipe.video_dit2 = dit
        else:
            pipe.video_dit = dit
        pipe.audio_dit = model_pool.fetch_model("mova_audio_dit")
        pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge")
        pipe.video_vae = model_pool.fetch_model("wan_video_vae")
        pipe.audio_vae = model_pool.fetch_model("mova_audio_vae")
        set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))

        # Size division factor: output H/W must be divisible by twice the
        # VAE's spatial upsampling factor (patch embedding halves again).
        if pipe.video_vae is not None:
            pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2

        # Initialize tokenizer and processor
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')

        # Unified Sequence Parallel
        if use_usp: pipe.enable_usp()

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Video-to-video
        denoising_strength: Optional[float] = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 352,
        width: Optional[int] = 640,
        num_frames: Optional[int] = 81,
        frame_rate: Optional[int] = 24,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.9,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # progress_bar
        progress_bar_cmd=tqdm,
    ):
        """Generate a synchronized (video, audio) pair from a text prompt.

        Returns:
            Tuple of (decoded video frames, decoded audio waveform) as produced
            by `vae_output_to_video` and `output_audio_format_check`.
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Inputs: split into shared / positive-only / negative-only dicts so
        # CFG can run the model with swapped prompt conditioning.
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "denoising_strength": denoising_strength,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
            "cfg_scale": cfg_scale,
            "sigma_shift": sigma_shift,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch DiT if necessary: below the boundary timestep, swap the
            # high-noise video DiT for the low-noise one (done at most once).
            if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and not models["video_dit"] is self.video_dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["video_dit"] = self.video_dit2
            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Scheduler: step the video and audio latents independently.
            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)

        # Decode
        self.load_models_to_device(['video_vae'])
        video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        video = self.vae_output_to_video(video)
        self.load_models_to_device(["audio_vae"])
        audio = self.audio_vae.decode(inputs_shared["audio_latents"])
        audio = self.output_audio_format_check(audio)
        self.load_models_to_device([])
        return video, audio
class MovaAudioVideoUnit_ShapeChecker(PipelineUnit):
    """Validate the requested output shape against the pipeline's division
    factors, returning possibly adjusted height/width/num_frames."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames"),
            output_params=("height", "width", "num_frames"),
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames):
        # Delegate divisibility checks / rounding to the pipeline helper.
        checked = pipe.check_resize_height_width(height, width, num_frames)
        return dict(zip(("height", "width", "num_frames"), checked))
class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit):
    """Draw the initial Gaussian noise for both video and audio latents."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
            output_params=("video_noise", "audio_noise")
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate):
        # Video latent grid: 4x temporal compression (plus the first frame),
        # spatial compression by the VAE upsampling factor.
        latent_frames = (num_frames - 1) // 4 + 1
        spatial_div = pipe.video_vae.upsampling_factor
        video_shape = (1, pipe.video_vae.model.z_dim, latent_frames, height // spatial_div, width // spatial_div)
        video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device)

        # Audio latent length: waveform samples covering the video duration,
        # compressed by the audio VAE hop length.
        waveform_samples = int(pipe.audio_vae.sample_rate * num_frames / frame_rate)
        audio_steps = (waveform_samples - 1) // int(pipe.audio_vae.hop_length) + 1
        audio_shape = (1, pipe.audio_vae.latent_dim, audio_steps)
        audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device)

        return {"video_noise": video_noise, "audio_noise": audio_noise}
class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    """During training, encode the reference video to latents; during
    inference (or with no input video) pass the noise through as latents."""

    def __init__(self):
        super().__init__(
            input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"),
            output_params=("video_latents", "input_latents"),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):
        # Inference path: start denoising from pure noise.
        if input_video is None or not pipe.scheduler.training:
            return {"video_latents": video_noise}
        # Training path: VAE-encode the ground-truth video.
        pipe.load_models_to_device(self.onload_model_names)
        frames = pipe.preprocess_video(input_video)
        encoded = pipe.video_vae.encode(frames, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        encoded = encoded.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"input_latents": encoded}
class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):
    """During training, encode the reference audio to DAC-VAE latents;
    during inference (or with no input audio) pass the noise through."""

    def __init__(self):
        super().__init__(
            input_params=("input_audio", "audio_noise"),
            output_params=("audio_latents", "audio_input_latents"),
            onload_model_names=("audio_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):
        # Inference path: start denoising from pure noise.
        if input_audio is None or not pipe.scheduler.training:
            return {"audio_latents": audio_noise}
        # Training path: mono-ize, resample to the VAE rate, then encode.
        pipe.load_models_to_device(self.onload_model_names)
        waveform, source_rate = input_audio
        waveform = convert_to_mono(waveform)
        target_rate = pipe.audio_vae.sample_rate
        waveform = resample_waveform(waveform, source_rate, target_rate)
        batch = pipe.audio_vae.preprocess(waveform.unsqueeze(0), target_rate)
        # encode() returns (posterior, ...); only the posterior is needed.
        posterior = pipe.audio_vae.encode(batch)[0]
        return {"audio_input_latents": posterior.mode()}
class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):
    """Encode the positive or negative prompt with the Wan T5 text encoder,
    producing the `context` tensor shared by both DiT towers."""

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("context",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt):
        """Tokenize `prompt` (padded/truncated to 512 tokens) and return T5
        embeddings with every padding position zeroed out per batch row."""
        ids, mask = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            add_special_tokens=True,
            return_mask=True,
            return_tensors="pt",
        )
        ids = ids.to(pipe.device)
        mask = mask.to(pipe.device)
        # Per-row valid token counts derived from the attention mask.
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = pipe.text_encoder(ids, mask)
        # BUGFIX: the original wrote `prompt_emb[:, v:] = 0`, which zeroes
        # EVERY batch row with each row's length — with batched prompts of
        # unequal lengths all rows end up truncated to the shortest prompt.
        # Mask padding per row instead. (Identical behavior for batch size 1.)
        for i, v in enumerate(seq_lens):
            prompt_emb[i, v:] = 0
        return prompt_emb

    def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict:
        pipe.load_models_to_device(self.onload_model_names)
        prompt_emb = self.encode_prompt(pipe, prompt)
        return {"context": prompt_emb}
class MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit):
    """Build the `y` conditioning tensor for image-to-video (and optional
    first-last-frame) generation: a frame mask concatenated with the VAE
    encoding of the anchor frame(s) padded with zero frames."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        # Skip for pure text-to-video, or when the DiT takes no VAE embedding.
        if input_image is None or not pipe.video_dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)

        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        # Frame mask: 1 marks frames that are given (first, and last if any).
        # NOTE(review): the //8 spatial factor assumes the Wan VAE's 8x
        # spatial compression — confirm against video_vae.upsampling_factor.
        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            # First-last-frame: real frames at both ends, zeros in between.
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
            msk[:, -1:] = 1
        else:
            # Image-to-video: real first frame, zeros afterwards.
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

        # Repeat the first frame's mask 4x so the mask length matches the
        # VAE's 4x temporal compression, then fold into latent-frame groups.
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        # Encode the padded frame sequence and prepend the mask channels.
        y = pipe.video_vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"y": y}
class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    """Forward the pipeline's USP flag into the shared model inputs."""

    def __init__(self):
        super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))

    def process(self, pipe: MovaAudioVideoPipeline):
        # The attribute only exists after enable_usp(); default to False.
        enabled = bool(getattr(pipe, "use_unified_sequence_parallel", False))
        return {"use_unified_sequence_parallel": enabled}
def model_fn_mova_audio_video(
    video_dit: WanModel,
    audio_dit: MovaAudioDit,
    dual_tower_bridge: DualTowerConditionalBridge,
    video_latents: torch.Tensor = None,
    audio_latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    y: Optional[torch.Tensor] = None,
    frame_rate: Optional[int] = 24,
    use_unified_sequence_parallel: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    **kwargs,
):
    """Single denoising step of the MOVA dual-tower model.

    Runs the video DiT and audio DiT in lockstep; at bridge-selected depths
    the dual-tower bridge exchanges information between the two token
    streams. Returns (video noise prediction, audio noise prediction).

    Args:
        video_latents: video latent tensor (batch, channels, frames, H, W —
            implied by the 'b c f h w' rearrange below).
        audio_latents: audio latent tensor (batch, channels, steps — implied
            by the 'b c f' rearrange below).
        y: optional image/FLF conditioning, concatenated on channel dim.
        context: T5 prompt embeddings shared by both towers.
    """
    video_x, audio_x = video_latents, audio_latents
    # First-Last Frame conditioning: extra channels from ImageEmbedderVAE.
    if y is not None:
        video_x = torch.cat([video_x, y], dim=1)

    # Timestep embeddings and per-block modulation (6 AdaLN params per tower).
    video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep))
    video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim))
    audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep))
    audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim))

    # Context: each tower projects the shared prompt embedding separately.
    video_context = video_dit.text_embedding(context)
    audio_context = audio_dit.text_embedding(context)

    # Patchify: flatten video to (b, f*h*w, c) tokens, audio to (b, f, c).
    video_x = video_dit.patch_embedding(video_x)
    f_v, h, w = video_x.shape[2:]
    video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous()
    seq_len_video = video_x.shape[1]

    audio_x = audio_dit.patch_embedding(audio_x)
    f_a = audio_x.shape[2]
    audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous()
    seq_len_audio = audio_x.shape[1]

    # RoPE frequencies: 3D (frame, height, width) for video; the audio table
    # concatenates three 1D tables indexed by the same temporal axis.
    video_freqs = torch.cat([
        video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1)
    ], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device)
    audio_freqs = torch.cat([
        audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1),
        audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1),
        audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1),
    ], dim=-1).reshape(f_a, 1, -1).to(audio_x.device)

    # Time-aligned rotary frequencies used inside the bridge's cross-tower
    # attention, aligning audio steps with video frames via the frame rate.
    video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs(
        video_fps=frame_rate,
        grid_size=(f_v, h, w),
        audio_steps=audio_x.shape[1],
        device=video_x.device,
        dtype=video_x.dtype,
    )
    # usp func: under USP the sequence is split across ranks for DiT blocks
    # and re-gathered for the bridge; otherwise both helpers are identities.
    if use_unified_sequence_parallel:
        from ..utils.xfuser import get_current_chunk, gather_all_chunks
    else:
        get_current_chunk = lambda x, dim=1: x
        gather_all_chunks = lambda x, seq_len, dim=1: x
    # Forward blocks: towers advance together for the first len(audio) blocks.
    for block_id in range(len(audio_dit.blocks)):
        # Bridge interaction operates on full (gathered) sequences.
        if dual_tower_bridge.should_interact(block_id, "a2v"):
            video_x, audio_x = dual_tower_bridge(
                block_id,
                video_x,
                audio_x,
                x_freqs=video_rope,
                y_freqs=audio_rope,
                condition_scale=1.0,
                video_grid_size=(f_v, h, w),
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            )
        # Chunk -> run this rank's slice through the block -> gather.
        video_x = get_current_chunk(video_x, dim=1)
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
        video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
        audio_x = get_current_chunk(audio_x, dim=1)
        audio_x = gradient_checkpoint_forward(
            audio_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            audio_x, audio_context, audio_t_mod, audio_freqs
        )
        audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1)

    # Remaining video-only blocks: no bridge, so chunk once before the loop.
    # NOTE(review): the gather below appears to run once after the loop
    # (chunk-once / gather-once), unlike the per-iteration pattern above —
    # confirm against the xfuser helpers if USP behavior looks off.
    video_x = get_current_chunk(video_x, dim=1)
    for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)):
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
    video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)

    # Head: project tokens back to latent space and un-patchify.
    video_x = video_dit.head(video_x, video_t)
    video_x = video_dit.unpatchify(video_x, (f_v, h, w))

    audio_x = audio_dit.head(audio_x, audio_t)
    audio_x = audio_dit.unpatchify(audio_x, (f_a,))
    return video_x, audio_x
Reference in New Issue
Block a user