* support mova inference

* mova media_io

* add unified audio_video api & fix bug of mono audio input for ltx

* support mova train

* mova docs

* fix bug
This commit is contained in:
Hong Zhang
2026-03-13 13:06:07 +08:00
committed by GitHub
parent 4741542523
commit 681df93a85
37 changed files with 3102 additions and 181 deletions

View File

@@ -32,8 +32,9 @@ We believe that a well-developed open-source code framework can lower the thresh
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **January 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/zh/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
- **March 3, 2026**: We released the [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) model, which is an updated version of Qwen-Image-Layered-Control. In addition to the originally supported text-guided functionality, it adds brush-controlled layer separation capabilities.
@@ -867,6 +868,8 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |
</details>

View File

@@ -32,8 +32,10 @@ DiffSynth 目前包括两个开源项目:
> DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年1月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
- **2026年3月3日** 我们发布了 [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) 模型,这是 Qwen-Image-Layered-Control 的更新版本。除了原本就支持的文本引导功能,新增了画笔控制的图层拆分能力。
- **2026年3月2日** 新增对[Anima](https://modelscope.cn/models/circlestone-labs/Anima)的支持,详见[文档](docs/zh/Model_Details/Anima.md)。这是一个有趣的动漫风格图像生成模型,我们期待其后续的模型更新。
@@ -866,6 +868,8 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |
</details>

View File

@@ -848,4 +848,26 @@ anima_series = [
"state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
}
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series
mova_series = [
# Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors")
{
"model_hash": "8c57e12790e2c45a64817e0ce28cde2f",
"model_name": "mova_audio_dit",
"model_class": "diffsynth.models.mova_audio_dit.MovaAudioDit",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
# Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors")
{
"model_hash": "418517fb2b4e919d2cac8f314fcf82ac",
"model_name": "mova_audio_vae",
"model_class": "diffsynth.models.mova_audio_vae.DacVAE",
},
# Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors")
{
"model_hash": "d1139dbbc8b4ab53cf4b4243d57bbceb",
"model_name": "mova_dual_tower_bridge",
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series

View File

@@ -249,6 +249,24 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.mova_audio_dit.MovaAudioDit": {
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.mova_audio_vae.DacVAE": {
"diffsynth.models.mova_audio_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():

View File

@@ -152,13 +152,6 @@ class BasePipeline(torch.nn.Module):
# remove batch dim
if audio_output.ndim == 3:
audio_output = audio_output.squeeze(0)
# Transform to stereo
if audio_output.shape[0] == 1:
audio_output = audio_output.repeat(2, 1)
elif audio_output.shape[0] == 2:
pass
else:
raise ValueError("The output audio should be [C, T] or [1, C, T] or [2, C, T].")
return audio_output.float()
def load_models_to_device(self, model_names):

View File

@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d
from einops import rearrange
from ..core import gradient_checkpoint_forward
def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
f_freqs_cis = precompute_freqs_cis(dim, end, theta)
return f_freqs_cis.chunk(3, dim=-1)
class MovaAudioDit(WanModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12)
self.freqs = precompute_freqs_cis_1d(head_dim)
self.patch_embedding = nn.Conv1d(
kwargs.get("in_dim", 128), kwargs.get("dim", 1536), kernel_size=[1], stride=[1]
)
def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):
self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta)
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
x, (f, ) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, -1).expand(f, -1),
self.freqs[1][:f].view(f, -1).expand(f, -1),
self.freqs[2][:f].view(f, -1).expand(f, -1),
], dim=-1).reshape(f, 1, -1).to(x.device)
for block in self.blocks:
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs,
)
x = self.head(x, t)
x = self.unpatchify(x, (f, ))
return x
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b f (p c) -> b c (f p)',
f=grid_size[0],
p=self.patch_size[0]
)

View File

@@ -0,0 +1,796 @@
import math
from typing import List, Union
import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
from einops import rearrange
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
class VectorQuantize(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
Additionally uses following tricks from Improved VQGAN
(https://arxiv.org/pdf/2110.04627.pdf):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
super().__init__()
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
self.codebook = nn.Embedding(codebook_size, codebook_dim)
def forward(self, z):
"""Quantized the input tensor using a fixed codebook and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
Tensor[1]
Codebook loss to update the codebook
Tensor[B x T]
Codebook indices (quantized discrete representation of input)
Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
z_e = self.in_proj(z) # z_e : (B x D x T)
z_q, indices = self.decode_latents(z_e)
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
z_q = (
z_e + (z_q - z_e).detach()
) # noop in forward pass, straight-through gradient estimator in backward pass
z_q = self.out_proj(z_q)
return z_q, commitment_loss, codebook_loss, indices, z_e
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
return z_q, indices
class ResidualVectorQuantize(nn.Module):
"""
Introduced in SoundStream: An end2end neural audio codec
https://arxiv.org/abs/2107.03312
"""
def __init__(
self,
input_dim: int = 512,
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: float = 0.0,
):
super().__init__()
if isinstance(codebook_dim, int):
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
self.n_codebooks = n_codebooks
self.codebook_dim = codebook_dim
self.codebook_size = codebook_size
self.quantizers = nn.ModuleList(
[
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
for i in range(n_codebooks)
]
)
self.quantizer_dropout = quantizer_dropout
def forward(self, z, n_quantizers: int = None):
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"""
z_q = 0
residual = z
commitment_loss = 0
codebook_loss = 0
codebook_indices = []
latents = []
if n_quantizers is None:
n_quantizers = self.n_codebooks
if self.training:
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
residual
)
# Create mask to apply quantizer dropout
mask = (
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
)
z_q = z_q + z_q_i * mask[:, None, None]
residual = residual - z_q_i
# Sum losses
commitment_loss += (commitment_loss_i * mask).mean()
codebook_loss += (codebook_loss_i * mask).mean()
codebook_indices.append(indices_i)
latents.append(z_e_i)
codes = torch.stack(codebook_indices, dim=1)
latents = torch.cat(latents, dim=1)
return z_q, codes, latents, commitment_loss, codebook_loss
def from_codes(self, codes: torch.Tensor):
"""Given the quantized codes, reconstruct the continuous representation
Parameters
----------
codes : Tensor[B x N x T]
Quantized discrete representation of input
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
"""
z_q = 0.0
z_p = []
n_codebooks = codes.shape[1]
for i in range(n_codebooks):
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
z_p.append(z_p_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), codes
def from_latents(self, latents: torch.Tensor):
"""Given the unquantized latents, reconstruct the
continuous representation after quantization.
Parameters
----------
latents : Tensor[B x N x T]
Continuous representation of input after projection
Returns
-------
Tensor[B x D x T]
Quantized representation of full-projected space
Tensor[B x D x T]
Quantized representation of latent space
"""
z_q = 0
z_p = []
codes = []
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
0
]
for i in range(n_codebooks):
j, k = dims[i], dims[i + 1]
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
z_p.append(z_p_i)
codes.append(codes_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.mean(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2],
)
else:
return 0.5 * torch.mean(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2],
)
def nll(self, sample, dims=[1, 2]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
return 0.5 * (
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
if pad > 0:
x = x[..., pad:-pad]
return x + y
class EncoderBlock(nn.Module):
def __init__(self, dim: int = 16, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1),
ResidualUnit(dim // 2, dilation=3),
ResidualUnit(dim // 2, dilation=9),
Snake1d(dim // 2),
WNConv1d(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
)
def forward(self, x):
return self.block(x)
class Encoder(nn.Module):
def __init__(
self,
d_model: int = 64,
strides: list = [2, 4, 8, 8],
d_latent: int = 64,
):
super().__init__()
# Create first convolution
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride)]
# Create last convolution
self.block += [
Snake1d(d_model),
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
]
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
return self.block(x)
class DecoderBlock(nn.Module):
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
Snake1d(input_dim),
WNConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
ResidualUnit(output_dim, dilation=1),
ResidualUnit(output_dim, dilation=3),
ResidualUnit(output_dim, dilation=9),
)
def forward(self, x):
return self.block(x)
class Decoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
d_out: int = 1,
):
super().__init__()
# Add first conv layer
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride)]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class DacVAE(nn.Module):
def __init__(
self,
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 3, 4, 5, 8],
latent_dim: int = 128,
decoder_dim: int = 2048,
decoder_rates: List[int] = [8, 5, 4, 3, 2],
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: bool = False,
sample_rate: int = 48000,
continuous: bool = True,
use_weight_norm: bool = False,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.sample_rate = sample_rate
self.continuous = continuous
self.use_weight_norm = use_weight_norm
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
if not continuous:
self.n_codebooks = n_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer = ResidualVectorQuantize(
input_dim=latent_dim,
n_codebooks=n_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
else:
self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
self.decoder = Decoder(
latent_dim,
decoder_dim,
decoder_rates,
)
self.sample_rate = sample_rate
self.apply(init_weights)
self.delay = self.get_delay()
if not self.use_weight_norm:
self.remove_weight_norm()
def get_delay(self):
# Any number works here, delay is invariant to input length
l_out = self.get_output_length(0)
L = l_out
layers = []
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
layers.append(layer)
for layer in reversed(layers):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.ConvTranspose1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.Conv1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.ceil(L)
l_in = L
return (l_in - l_out) // 2
def get_output_length(self, input_length):
L = input_length
# Calculate output length
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.Conv1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.ConvTranspose1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.floor(L)
return L
@property
def dtype(self):
"""Get the dtype of the model parameters."""
# Return the dtype of the first parameter found
for param in self.parameters():
return param.dtype
return torch.float32 # fallback
@property
def device(self):
"""Get the device of the model parameters."""
# Return the device of the first parameter found
for param in self.parameters():
return param.device
return torch.device('cpu') # fallback
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
length = audio_data.shape[-1]
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def encode(
self,
audio_data: torch.Tensor,
n_quantizers: int = None,
):
"""Encode given audio data and return quantized latent codes
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
n_quantizers : int, optional
Number of quantizers to use, by default None
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"length" : int
Number of samples in input audio
"""
z = self.encoder(audio_data) # [B x D x T]
if not self.continuous:
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
else:
z = self.quant_conv(z) # [B x 2D x T]
z = DiagonalGaussianDistribution(z)
codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
return z, codes, latents, commitment_loss, codebook_loss
def decode(self, z: torch.Tensor):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
if not self.continuous:
audio = self.decoder(z)
else:
z = self.post_quant_conv(z)
audio = self.decoder(z)
return audio
def forward(
self,
audio_data: torch.Tensor,
sample_rate: int = None,
n_quantizers: int = None,
):
"""Model forward pass
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
sample_rate : int, optional
Sample rate of audio data in Hz, by default None
If None, defaults to `self.sample_rate`
n_quantizers : int, optional
Number of quantizers to use, by default None.
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"length" : int
Number of samples in input audio
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
length = audio_data.shape[-1]
audio_data = self.preprocess(audio_data, sample_rate)
if not self.continuous:
z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
x = self.decode(z)
return {
"audio": x[..., :length],
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
else:
posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
z = posterior.sample()
x = self.decode(z)
kl_loss = posterior.kl()
kl_loss = kl_loss.mean()
return {
"audio": x[..., :length],
"z": z,
"kl_loss": kl_loss,
}
def remove_weight_norm(self):
"""
Remove weight_norm from all modules in the model.
This fuses the weight_g and weight_v parameters into a single weight parameter.
Should be called before inference for better performance.
Returns:
self: The model with weight_norm removed
"""
from torch.nn.utils import remove_weight_norm
num_removed = 0
for name, module in list(self.named_modules()):
if hasattr(module, "_forward_pre_hooks"):
for hook_id, hook in list(module._forward_pre_hooks.items()):
if "WeightNorm" in str(type(hook)):
try:
remove_weight_norm(module)
num_removed += 1
# print(f"Removed weight_norm from: {name}")
except ValueError as e:
print(f"Failed to remove weight_norm from {name}: {e}")
if num_removed > 0:
# print(f"Successfully removed weight_norm from {num_removed} modules")
self.use_weight_norm = False
else:
print("No weight_norm found in the model")
return self

View File

@@ -0,0 +1,595 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
from einops import rearrange
from .wan_video_dit import AttentionModule, RMSNorm
from ..core import gradient_checkpoint_forward
class RotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, base: float, dim: int, device=None):
super().__init__()
self.base = base
self.dim = dim
self.attention_scaling = 1.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@torch.compile(fullgraph=True)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class PerFrameAttentionPooling(nn.Module):
"""
Per-frame multi-head attention pooling.
Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a
single-query attention pooling over the H*W tokens for each time frame, producing
[B, T, D].
Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack).
"""
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.probe = nn.Parameter(torch.randn(1, 1, dim))
nn.init.normal_(self.probe, std=0.02)
self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
self.layernorm = nn.LayerNorm(dim, eps=eps)
def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor:
"""
Args:
x: [B, L, D], where L = T*H*W
grid_size: (T, H, W)
Returns:
pooled: [B, T, D]
"""
B, L, D = x.shape
T, H, W = grid_size
assert D == self.dim, f"Channel dimension mismatch: D={D} vs dim={self.dim}"
assert L == T * H * W, f"Flattened length mismatch: L={L} vs T*H*W={T*H*W}"
S = H * W
# Re-arrange tokens grouped by frame.
x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D) # [B*T, S, D]
# A learnable probe as the query (one query per frame).
probe = self.probe.expand(B * T, -1, -1) # [B*T, 1, D]
# Attention pooling: query=probe, key/value=H*W tokens within the frame.
pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0] # [B*T, 1, D]
pooled_bt_d = pooled_bt_1_d.squeeze(1) # [B*T, D]
# Restore to [B, T, D].
pooled = pooled_bt_d.view(B, T, D)
pooled = self.layernorm(pooled)
return pooled
class CrossModalInteractionController:
"""
Strategy class that controls interactions between two towers.
Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 30 layers).
"""
def __init__(self, visual_layers: int = 30, audio_layers: int = 30):
self.visual_layers = visual_layers
self.audio_layers = audio_layers
self.min_layers = min(visual_layers, audio_layers)
def get_interaction_layers(self, strategy: str = "shallow_focus") -> Dict[str, List[Tuple[int, int]]]:
"""
Get interaction layer mappings.
Args:
strategy: interaction strategy
- "shallow_focus": emphasize shallow layers to avoid deep-layer asymmetry
- "distributed": distributed interactions across the network
- "progressive": dense shallow interactions, sparse deeper interactions
- "custom": custom interaction layers
Returns:
A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual).
"""
if strategy == "shallow_focus":
# Emphasize the first ~1/3 layers to avoid deep-layer asymmetry.
num_interact = min(10, self.min_layers // 3)
interact_layers = list(range(0, num_interact))
elif strategy == "distributed":
# Distribute interactions across the network (every few layers).
step = 3
interact_layers = list(range(0, self.min_layers, step))
elif strategy == "progressive":
# Progressive: dense shallow interactions, sparse deeper interactions.
shallow = list(range(0, min(8, self.min_layers))) # Dense for the first 8 layers.
if self.min_layers > 8:
deep = list(range(8, self.min_layers, 3)) # Every 3 layers afterwards.
interact_layers = shallow + deep
else:
interact_layers = shallow
elif strategy == "custom":
# Custom strategy: adjust as needed.
interact_layers = [0, 2, 4, 6, 8, 12, 16, 20] # Explicit layer indices.
interact_layers = [i for i in interact_layers if i < self.min_layers]
elif strategy == "full":
interact_layers = list(range(0, self.min_layers))
else:
raise ValueError(f"Unknown interaction strategy: {strategy}")
# Build bidirectional mapping.
mapping = {
'v2a': [(i, i) for i in interact_layers], # visual layer i -> audio layer i
'a2v': [(i, i) for i in interact_layers] # audio layer i -> visual layer i
}
return mapping
def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool:
"""
Check whether a given layer should interact.
Args:
layer_idx: current layer index
direction: interaction direction ('v2a' or 'a2v')
interaction_mapping: interaction mapping table
Returns:
bool: whether to interact
"""
if direction not in interaction_mapping:
return False
return any(src == layer_idx for src, _ in interaction_mapping[direction])
class ConditionalCrossAttention(nn.Module):
def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
self.q_dim = dim
self.kv_dim = kv_dim
self.num_heads = num_heads
self.head_dim = self.q_dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(kv_dim, dim)
self.v = nn.Linear(kv_dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads)
def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
ctx = y
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(ctx))
v = self.v(ctx)
if x_freqs is not None:
x_cos, x_sin = x_freqs
B, L, _ = q.shape
q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim)
x_cos = x_cos.to(q_view.dtype).to(q_view.device)
x_sin = x_sin.to(q_view.dtype).to(q_view.device)
# Expect x_cos/x_sin shape: [B or 1, L, head_dim]
q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2)
q = rearrange(q_view, 'b l h d -> b l (h d)')
if y_freqs is not None:
y_cos, y_sin = y_freqs
Bc, Lc, _ = k.shape
k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim)
y_cos = y_cos.to(k_view.dtype).to(k_view.device)
y_sin = y_sin.to(k_view.dtype).to(k_view.device)
# Expect y_cos/y_sin shape: [B or 1, L, head_dim]
_, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2)
k = rearrange(k_view, 'b l h d -> b l (h d)')
x = self.attn(q, k, v)
return self.o(x)
# from diffusers.models.attention import AdaLayerNorm
class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
output_dim (`int`, *optional*):
norm_elementwise_affine (`bool`, defaults to `False):
norm_eps (`bool`, defaults to `False`):
chunk_dim (`int`, defaults to `0`):
"""
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = output_dim or embedding_dim * 2
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 2:
scale, shift = temb.chunk(2, dim=2)
# print(f"{x.shape = }, {scale.shape = }, {shift.shape = }")
elif self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
# other if-branch. This branch is specific to CogVideoX and OmniGen for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
x = self.norm(x) * (1 + scale) + shift
return x
class ConditionalCrossAttentionBlock(nn.Module):
"""
A thin wrapper around ConditionalCrossAttention.
Applies LayerNorm to the conditioning input `y` before cross-attention.
"""
def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False):
super().__init__()
self.y_norm = nn.LayerNorm(kv_dim, eps=eps)
self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps)
self.pooled_adaln = pooled_adaln
if pooled_adaln:
self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps)
self.adaln = AdaLayerNorm(kv_dim, output_dim=dim*2, chunk_dim=2)
def forward(
self,
x: torch.Tensor,
y: torch.Tensor,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
) -> torch.Tensor:
if self.pooled_adaln:
assert video_grid_size is not None, "video_grid_size must not be None"
pooled_y = self.per_frame_pooling(y, video_grid_size)
# Interpolate pooled_y along its temporal dimension to match x's sequence length.
if pooled_y.shape[1] != x.shape[1]:
pooled_y = F.interpolate(
pooled_y.permute(0, 2, 1), # [B, C, T]
size=x.shape[1],
mode='linear',
align_corners=False,
).permute(0, 2, 1) # [B, T, C]
x = self.adaln(x, temb=pooled_y)
y = self.y_norm(y)
return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs)
class DualTowerConditionalBridge(nn.Module):
"""
Dual-tower conditional bridge.
"""
def __init__(self,
visual_layers: int = 40,
audio_layers: int = 30,
visual_hidden_dim: int = 5120, # visual DiT hidden state dimension
audio_hidden_dim: int = 1536, # audio DiT hidden state dimension
audio_fps: float = 50.0,
head_dim: int = 128, # attention head dimension
interaction_strategy: str = "full",
apply_cross_rope: bool = True, # whether to apply RoPE in cross-attention
apply_first_frame_bias_in_rope: bool = False, # whether to account for 1/video_fps bias for the first frame in RoPE alignment
trainable_condition_scale: bool = False,
pooled_adaln: bool = False,
):
super().__init__()
self.visual_hidden_dim = visual_hidden_dim
self.audio_hidden_dim = audio_hidden_dim
self.audio_fps = audio_fps
self.head_dim = head_dim
self.apply_cross_rope = apply_cross_rope
self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope
self.trainable_condition_scale = trainable_condition_scale
self.pooled_adaln = pooled_adaln
if self.trainable_condition_scale:
self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32))
else:
self.condition_scale = 1.0
self.controller = CrossModalInteractionController(visual_layers, audio_layers)
self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy)
# Conditional cross-attention modules operating at the DiT hidden-state level.
self.audio_to_video_conditioners = nn.ModuleDict() # audio hidden states -> visual DiT conditioning
self.video_to_audio_conditioners = nn.ModuleDict() # visual hidden states -> audio DiT conditioning
# Build conditioners for layers that should interact.
# audio hidden states condition the visual DiT
self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim)
for v_layer, _ in self.interaction_mapping['a2v']:
self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock(
dim=visual_hidden_dim, # 3072 (visual DiT hidden states)
kv_dim=audio_hidden_dim, # 1536 (audio DiT hidden states)
num_heads=visual_hidden_dim // head_dim, # derive number of heads from hidden dim
pooled_adaln=False # a2v typically does not need pooled AdaLN
)
# visual hidden states condition the audio DiT
for a_layer, _ in self.interaction_mapping['v2a']:
self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock(
dim=audio_hidden_dim, # 1536 (audio DiT hidden states)
kv_dim=visual_hidden_dim, # 3072 (visual DiT hidden states)
num_heads=audio_hidden_dim // head_dim, # safe head count derivation
pooled_adaln=self.pooled_adaln
)
@torch.no_grad()
def build_aligned_freqs(self,
video_fps: float,
grid_size: Tuple[int, int, int],
audio_steps: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
"""
Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w),
and audio sequence length `audio_steps` (with fixed audio fps = 44100/2048).
Returns:
visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim]
audio_freqs: (cos_a, sin_a), shape [1, audio_steps, head_dim]
"""
f_v, h, w = grid_size
L_v = f_v * h * w
L_a = int(audio_steps)
device = device or next(self.parameters()).device
dtype = dtype or torch.float32
# Audio positions: 0,1,2,...,L_a-1 (audio as reference).
audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)
# Video positions: align video frames to audio-step units.
# FIXME(dhyu): hard-coded VAE temporal stride = 4
if self.apply_first_frame_bias_in_rope:
# Account for the "first frame lasts 1/video_fps" bias.
video_effective_fps = float(video_fps) / 4.0
if f_v > 0:
t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)
if f_v > 1:
t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps)
else:
t_starts = torch.zeros((0,), device=device, dtype=torch.float32)
# Convert to audio-step units.
video_pos_per_frame = t_starts * float(self.audio_fps)
else:
# No first-frame bias: uniform alignment.
scale = float(self.audio_fps) / float(video_fps / 4.0)
video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale
# Flatten to f*h*w; tokens within the same frame share the same time position.
video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)
# print(f"video fps: {video_fps}, audio fps: {self.audio_fps}, scale: {scale}")
# print(f"video pos: {video_pos.shape}, audio pos: {audio_pos.shape}")
# Build dummy x to produce cos/sin, dim=head_dim.
dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype)
dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype)
cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos)
cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos)
return (cos_v, sin_v), (cos_a, sin_a)
def should_interact(self, layer_idx: int, direction: str) -> bool:
return self.controller.should_interact(layer_idx, direction, self.interaction_mapping)
def apply_conditional_control(
self,
layer_idx: int,
direction: str,
primary_hidden_states: torch.Tensor,
condition_hidden_states: torch.Tensor,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
condition_scale: Optional[float] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
use_gradient_checkpointing: Optional[bool] = False,
use_gradient_checkpointing_offload: Optional[bool] = False,
) -> torch.Tensor:
"""
Apply conditional control (at the DiT hidden-state level).
Args:
layer_idx: current layer index
direction: conditioning direction
- 'a2v': audio hidden states -> visual DiT
- 'v2a': visual hidden states -> audio DiT
primary_hidden_states: primary DiT hidden states [B, L, hidden_dim]
condition_hidden_states: condition DiT hidden states [B, L, hidden_dim]
condition_scale: conditioning strength (similar to CFG scale)
Returns:
Conditioned primary DiT hidden states [B, L, hidden_dim]
"""
if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping):
return primary_hidden_states
if direction == 'a2v':
# audio hidden states condition the visual DiT
conditioner = self.audio_to_video_conditioners[str(layer_idx)]
elif direction == 'v2a':
# visual hidden states condition the audio DiT
conditioner = self.video_to_audio_conditioners[str(layer_idx)]
else:
raise ValueError(f"Invalid direction: {direction}")
conditioned_features = gradient_checkpoint_forward(
conditioner,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x=primary_hidden_states,
y=condition_hidden_states,
x_freqs=x_freqs,
y_freqs=y_freqs,
video_grid_size=video_grid_size,
)
if self.trainable_condition_scale and condition_scale is not None:
print(
"[WARN] This model has a trainable condition_scale, but an external "
f"condition_scale={condition_scale} was provided. The trainable condition_scale "
"will be ignored in favor of the external value."
)
scale = condition_scale if condition_scale is not None else self.condition_scale
primary_hidden_states = primary_hidden_states + conditioned_features * scale
return primary_hidden_states
def forward(
self,
layer_idx: int,
visual_hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
*,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
a2v_condition_scale: Optional[float] = None,
v2a_condition_scale: Optional[float] = None,
condition_scale: Optional[float] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
use_gradient_checkpointing: Optional[bool] = False,
use_gradient_checkpointing_offload: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply bidirectional conditional control to both visual/audio towers.
Args:
layer_idx: current layer index
visual_hidden_states: visual DiT hidden states
audio_hidden_states: audio DiT hidden states
x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs.
If provided, x_freqs is assumed to correspond to the primary tower and y_freqs
to the conditioning tower.
a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale)
v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale)
condition_scale: fallback conditioning strength when per-direction scale is None
video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled
Returns:
(visual_hidden_states, audio_hidden_states), both conditioned in their respective directions.
"""
visual_conditioned = self.apply_conditional_control(
layer_idx=layer_idx,
direction="a2v",
primary_hidden_states=visual_hidden_states,
condition_hidden_states=audio_hidden_states,
x_freqs=x_freqs,
y_freqs=y_freqs,
condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale,
video_grid_size=video_grid_size,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
audio_conditioned = self.apply_conditional_control(
layer_idx=layer_idx,
direction="v2a",
primary_hidden_states=audio_hidden_states,
condition_hidden_states=visual_hidden_states,
x_freqs=y_freqs,
y_freqs=x_freqs,
condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale,
video_grid_size=video_grid_size,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return visual_conditioned, audio_conditioned

View File

@@ -99,18 +99,30 @@ def rope_apply(x, freqs, num_heads):
return x_out.to(x.dtype)
def set_to_torch_norm(models):
for model in models:
for module in model.modules():
if isinstance(module, RMSNorm):
module.use_torch_norm = True
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.use_torch_norm = False
self.normalized_shape = (dim,)
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x):
dtype = x.dtype
return self.norm(x.float()).to(dtype) * self.weight
if self.use_torch_norm:
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
return self.norm(x.float()).to(dtype) * self.weight
class AttentionModule(nn.Module):

View File

@@ -22,6 +22,7 @@ from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Voco
from ..models.ltx2_upsampler import LTX2LatentUpsampler
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
from ..utils.data.media_io_ltx2 import ltx2_preprocess
from ..utils.data.audio import convert_to_stereo
class LTX2AudioVideoPipeline(BasePipeline):
@@ -389,6 +390,7 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
return {"audio_latents": audio_noise}
else:
input_audio, sample_rate = input_audio
input_audio = convert_to_stereo(input_audio)
pipe.load_models_to_device(self.onload_model_names)
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
audio_input_latents = pipe.audio_vae_encoder(input_audio)
@@ -441,6 +443,7 @@ class LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit):
return {}
else:
input_audio, sample_rate = retake_audio
input_audio = convert_to_stereo(input_audio)
pipe.load_models_to_device(self.onload_model_names)
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device)
input_latents_audio = pipe.audio_vae_encoder(input_audio)

View File

@@ -0,0 +1,460 @@
import sys
import torch, types
from PIL import Image
from typing import Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
from ..models.wan_video_vae import WanVideoVAE
from ..models.mova_audio_dit import MovaAudioDit
from ..models.mova_audio_vae import DacVAE
from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge
from ..utils.data.audio import convert_to_mono, resample_waveform
class MovaAudioVideoPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
)
self.scheduler = FlowMatchScheduler("Wan")
self.tokenizer: HuggingfaceTokenizer = None
self.text_encoder: WanTextEncoder = None
self.video_dit: WanModel = None # high noise model
self.video_dit2: WanModel = None # low noise model
self.audio_dit: MovaAudioDit = None
self.dual_tower_bridge: DualTowerConditionalBridge = None
self.video_vae: WanVideoVAE = None
self.audio_vae: DacVAE = None
self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge")
self.in_iteration_models_2 = ("video_dit2", "audio_dit", "dual_tower_bridge")
self.units = [
MovaAudioVideoUnit_ShapeChecker(),
MovaAudioVideoUnit_NoiseInitializer(),
MovaAudioVideoUnit_InputVideoEmbedder(),
MovaAudioVideoUnit_InputAudioEmbedder(),
MovaAudioVideoUnit_PromptEmbedder(),
MovaAudioVideoUnit_ImageEmbedderVAE(),
MovaAudioVideoUnit_UnifiedSequenceParallel(),
]
self.model_fn = model_fn_mova_audio_video
def enable_usp(self):
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
for block in self.video_dit.blocks + self.audio_dit.blocks + self.video_dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
use_usp: bool = False,
vram_limit: float = None,
):
if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp(device)
import torch.distributed as dist
from ..core.device.npu_compatible_device import get_device_name
if dist.is_available() and dist.is_initialized():
device = get_device_name()
# Initialize pipeline
pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
dit = model_pool.fetch_model("wan_video_dit", index=2)
if isinstance(dit, list):
pipe.video_dit, pipe.video_dit2 = dit
else:
pipe.video_dit = dit
pipe.audio_dit = model_pool.fetch_model("mova_audio_dit")
pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge")
pipe.video_vae = model_pool.fetch_model("wan_video_vae")
pipe.audio_vae = model_pool.fetch_model("mova_audio_vae")
set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))
# Size division factor
if pipe.video_vae is not None:
pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2
pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2
# Initialize tokenizer and processor
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')
# Unified Sequence Parallel
if use_usp: pipe.enable_usp()
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
end_image: Optional[Image.Image] = None,
# Video-to-video
denoising_strength: Optional[float] = 1.0,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height: Optional[int] = 352,
width: Optional[int] = 640,
num_frames: Optional[int] = 81,
frame_rate: Optional[int] = 24,
# Classifier-free guidance
cfg_scale: Optional[float] = 5.0,
# Boundary
switch_DiT_boundary: Optional[float] = 0.9,
# Scheduler
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
# VAE tiling
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# progress_bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Inputs
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"input_image": input_image,
"end_image": end_image,
"denoising_strength": denoising_strength,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
"cfg_scale": cfg_scale,
"sigma_shift": sigma_shift,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# Switch DiT if necessary
if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and not models["video_dit"] is self.video_dit2:
self.load_models_to_device(self.in_iteration_models_2)
models["video_dit"] = self.video_dit2
# Timestep
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
# Scheduler
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
# Decode
self.load_models_to_device(['video_vae'])
video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
video = self.vae_output_to_video(video)
self.load_models_to_device(["audio_vae"])
audio = self.audio_vae.decode(inputs_shared["audio_latents"])
audio = self.output_audio_format_check(audio)
self.load_models_to_device([])
return video, audio
class MovaAudioVideoUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames"),
output_params=("height", "width", "num_frames"),
)
def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames):
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
return {"height": height, "width": width, "num_frames": num_frames}
class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
output_params=("video_noise", "audio_noise")
)
def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate):
length = (num_frames - 1) // 4 + 1
video_shape = (1, pipe.video_vae.model.z_dim, length, height // pipe.video_vae.upsampling_factor, width // pipe.video_vae.upsampling_factor)
video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device)
audio_num_samples = (int(pipe.audio_vae.sample_rate * num_frames / frame_rate) - 1) // int(pipe.audio_vae.hop_length) + 1
audio_shape = (1, pipe.audio_vae.latent_dim, audio_num_samples)
audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device)
return {"video_noise": video_noise, "audio_noise": audio_noise}
class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"),
output_params=("video_latents", "input_latents"),
onload_model_names=("video_vae",)
)
def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):
if input_video is None or not pipe.scheduler.training:
return {"video_latents": video_noise}
else:
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"input_latents": input_latents}
class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_audio", "audio_noise"),
output_params=("audio_latents", "audio_input_latents"),
onload_model_names=("audio_vae",)
)
def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):
if input_audio is None or not pipe.scheduler.training:
return {"audio_latents": audio_noise}
else:
pipe.load_models_to_device(self.onload_model_names)
input_audio, sample_rate = input_audio
input_audio = convert_to_mono(input_audio)
input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate)
input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate)
z, _, _, _, _ = pipe.audio_vae.encode(input_audio)
return {"audio_input_latents": z.mode()}
class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("context",),
onload_model_names=("text_encoder",)
)
def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt):
ids, mask = pipe.tokenizer(
prompt,
padding="max_length",
max_length=512,
truncation=True,
add_special_tokens=True,
return_mask=True,
return_tensors="pt",
)
ids = ids.to(pipe.device)
mask = mask.to(pipe.device)
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_emb = pipe.text_encoder(ids, mask)
for i, v in enumerate(seq_lens):
prompt_emb[:, v:] = 0
return prompt_emb
def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict:
pipe.load_models_to_device(self.onload_model_names)
prompt_emb = self.encode_prompt(pipe, prompt)
return {"context": prompt_emb}
class MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("y",),
onload_model_names=("video_vae",)
)
def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.video_dit.require_vae_embedding:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.video_vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"y": y}
class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):
def __init__(self):
super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))
def process(self, pipe: MovaAudioVideoPipeline):
if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel:
return {"use_unified_sequence_parallel": True}
return {"use_unified_sequence_parallel": False}
def model_fn_mova_audio_video(
video_dit: WanModel,
audio_dit: MovaAudioDit,
dual_tower_bridge: DualTowerConditionalBridge,
video_latents: torch.Tensor = None,
audio_latents: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
y: Optional[torch.Tensor] = None,
frame_rate: Optional[int] = 24,
use_unified_sequence_parallel: bool = False,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
video_x, audio_x = video_latents, audio_latents
# First-Last Frame
if y is not None:
video_x = torch.cat([video_x, y], dim=1)
# Timestep
video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep))
video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim))
audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep))
audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim))
# Context
video_context = video_dit.text_embedding(context)
audio_context = audio_dit.text_embedding(context)
# Patchify
video_x = video_dit.patch_embedding(video_x)
f_v, h, w = video_x.shape[2:]
video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous()
seq_len_video = video_x.shape[1]
audio_x = audio_dit.patch_embedding(audio_x)
f_a = audio_x.shape[2]
audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous()
seq_len_audio = audio_x.shape[1]
# Freqs
video_freqs = torch.cat([
video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1),
video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1),
video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1)
], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device)
audio_freqs = torch.cat([
audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1),
audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1),
audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1),
], dim=-1).reshape(f_a, 1, -1).to(audio_x.device)
video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs(
video_fps=frame_rate,
grid_size=(f_v, h, w),
audio_steps=audio_x.shape[1],
device=video_x.device,
dtype=video_x.dtype,
)
# usp func
if use_unified_sequence_parallel:
from ..utils.xfuser import get_current_chunk, gather_all_chunks
else:
get_current_chunk = lambda x, dim=1: x
gather_all_chunks = lambda x, seq_len, dim=1: x
# Forward blocks
for block_id in range(len(audio_dit.blocks)):
if dual_tower_bridge.should_interact(block_id, "a2v"):
video_x, audio_x = dual_tower_bridge(
block_id,
video_x,
audio_x,
x_freqs=video_rope,
y_freqs=audio_rope,
condition_scale=1.0,
video_grid_size=(f_v, h, w),
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
video_x = get_current_chunk(video_x, dim=1)
video_x = gradient_checkpoint_forward(
video_dit.blocks[block_id],
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
video_x, video_context, video_t_mod, video_freqs
)
video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
audio_x = get_current_chunk(audio_x, dim=1)
audio_x = gradient_checkpoint_forward(
audio_dit.blocks[block_id],
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
audio_x, audio_context, audio_t_mod, audio_freqs
)
audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1)
video_x = get_current_chunk(video_x, dim=1)
for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)):
video_x = gradient_checkpoint_forward(
video_dit.blocks[block_id],
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
video_x, video_context, video_t_mod, video_freqs
)
video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
# Head
video_x = video_dit.head(video_x, video_t)
video_x = video_dit.unpatchify(video_x, (f_v, h, w))
audio_x = audio_dit.head(audio_x, audio_t)
audio_x = audio_dit.unpatchify(audio_x, (f_a,))
return video_x, audio_x

View File

@@ -0,0 +1,108 @@
import torch
import torchaudio
from torchcodec.decoders import AudioDecoder
from torchcodec.encoders import AudioEncoder
def convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor:
"""
Convert audio to mono by averaging channels.
Supports [C, T] or [B, C, T]. Output shape: [1, T] or [B, 1, T].
"""
return audio_tensor.mean(dim=-2, keepdim=True)
def convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor:
"""
Convert audio to stereo.
Supports [C, T] or [B, C, T]. Duplicate mono, keep stereo.
"""
if audio_tensor.size(-2) == 1:
return audio_tensor.repeat(1, 2, 1) if audio_tensor.dim() == 3 else audio_tensor.repeat(2, 1)
return audio_tensor
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(dtype=waveform.dtype)
def read_audio_with_torchcodec(
path: str,
start_time: float = 0,
duration: float | None = None,
) -> tuple[torch.Tensor, int]:
"""
Read audio from file natively using torchcodec, with optional start time and duration.
Args:
path (str): The file path to the audio file.
start_time (float, optional): The start time in seconds to read from. Defaults to 0.
duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
Returns:
tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
"""
decoder = AudioDecoder(path)
stop_seconds = None if duration is None else start_time + duration
waveform = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds).data
return waveform, decoder.metadata.sample_rate
def read_audio(
path: str,
start_time: float = 0,
duration: float | None = None,
resample: bool = False,
resample_rate: int = 48000,
backend: str = "torchcodec",
) -> tuple[torch.Tensor, int]:
"""
Read audio from file, with optional start time, duration, and resampling.
Args:
path (str): The file path to the audio file.
start_time (float, optional): The start time in seconds to read from. Defaults to 0.
duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False.
resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000.
backend (str, optional): The audio backend to use for reading. Defaults to "torchcodec".
Returns:
tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
"""
if backend == "torchcodec":
waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration)
else:
raise ValueError(f"Unsupported audio backend: {backend}")
if resample:
waveform = resample_waveform(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
return waveform, sample_rate
def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = "torchcodec"):
"""
Save audio tensor to file.
Args:
waveform (torch.Tensor): The audio tensor to save. Shape can be [C, T] or [B, C, T].
sample_rate (int): The sample rate of the audio.
save_path (str): The file path to save the audio to.
backend (str, optional): The audio backend to use for saving. Defaults to "torchcodec".
"""
if waveform.dim() == 3:
waveform = waveform[0]
if backend == "torchcodec":
encoder = AudioEncoder(waveform, sample_rate=sample_rate)
encoder.to_file(dest=save_path)
else:
raise ValueError(f"Unsupported audio backend: {backend}")

View File

@@ -0,0 +1,134 @@
import av
from fractions import Fraction
import torch
from PIL import Image
from tqdm import tqdm
from .audio import convert_to_stereo
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
if samples.ndim == 1:
samples = samples.unsqueeze(0)
samples = convert_to_stereo(samples)
assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
samples = samples.T
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac")
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
if supported_sample_rates:
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
else:
best_rate = audio_sample_rate
audio_stream.codec_context.sample_rate = best_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, best_rate)
return audio_stream
def write_video_audio(
video: list[Image.Image],
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = None,
) -> None:
"""
Writes a sequence of images and an audio tensor to a video file.
This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream
and multiplex a PyTorch tensor as the audio stream into the output container.
Args:
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
The length of this list determines the total duration of the video based on the FPS.
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
The shape is typically (channels, samples). If no audio is required, pass None.
channels can be 1 or 2. 1 for mono, 2 for stereo.
output_path (str): The file path (including extension) where the output video will be saved.
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
If the audio tensor is provided and this is None, the function attempts to infer the rate
based on the audio tensor's length and the video duration.
Raises:
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
"""
duration = len(video) / fps
if audio_sample_rate is None:
audio_sample_rate = int(audio.shape[-1] / duration)
width, height = video[0].size
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame in tqdm(video, total=len(video)):
frame = av.VideoFrame.from_image(frame)
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()

View File

@@ -1,166 +1,7 @@
from fractions import Fraction
import torch
import torchaudio
import av
from tqdm import tqdm
from PIL import Image
import numpy as np
from io import BytesIO
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[0] == 1:
samples = samples.repeat(2, 1)
assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
samples = samples.T
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac")
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
if supported_sample_rates:
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
else:
best_rate = audio_sample_rate
audio_stream.codec_context.sample_rate = best_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, best_rate)
return audio_stream
def write_video_audio_ltx2(
video: list[Image.Image],
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = None,
) -> None:
"""
Writes a sequence of images and an audio tensor to a video file.
This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream
and multiplex a PyTorch tensor as the audio stream into the output container.
Args:
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
The length of this list determines the total duration of the video based on the FPS.
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
The shape is typically (channels, samples). If no audio is required, pass None.
channels can be 1 or 2. 1 for mono, 2 for stereo.
output_path (str): The file path (including extension) where the output video will be saved.
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
If the audio tensor is provided and this is None, the function attempts to infer the rate
based on the audio tensor's length and the video duration.
Raises:
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
"""
duration = len(video) / fps
if audio_sample_rate is None:
audio_sample_rate = int(audio.shape[-1] / duration)
width, height = video[0].size
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame in tqdm(video, total=len(video)):
frame = av.VideoFrame.from_image(frame)
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(dtype=waveform.dtype)
def read_audio_with_torchaudio(
path: str,
start_time: float = 0,
duration: float | None = None,
resample: bool = False,
resample_rate: int = 48000,
) -> tuple[torch.Tensor, int]:
waveform, sample_rate = torchaudio.load(path, channels_first=True)
if resample:
waveform = resample_waveform(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
start_frame = int(start_time * sample_rate)
if start_frame > waveform.shape[-1]:
raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
end_frame = None if duration is None else int(duration * sample_rate + start_frame)
return waveform[..., start_frame:end_frame], sample_rate
from .audio_video import write_video_audio as write_video_audio_ltx2
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:

View File

@@ -1 +1 @@
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks

View File

@@ -143,4 +143,31 @@ def usp_attn_forward(self, x, freqs):
del q, k, v
getattr(torch, parse_device_type(x.device)).empty_cache()
return self.o(x)
return self.o(x)
def get_current_chunk(x, dim=1):
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim)
ndims = len(chunks[0].shape)
pad_list = [0] * (2 * ndims)
pad_end_index = 2 * (ndims - 1 - dim) + 1
max_size = chunks[0].size(dim)
chunks = [
torch.nn.functional.pad(
chunk,
tuple(pad_list[:pad_end_index] + [max_size - chunk.size(dim)] + pad_list[pad_end_index+1:]),
value=0
)
for chunk in chunks
]
x = chunks[get_sequence_parallel_rank()]
return x
def gather_all_chunks(x, seq_len=None, dim=1):
x = get_sp_group().all_gather(x, dim=dim)
if seq_len is not None:
slices = [slice(None)] * x.ndim
slices[dim] = slice(0, seq_len)
x = x[tuple(slices)]
return x

View File

@@ -137,6 +137,8 @@ graph LR;
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)

View File

@@ -138,6 +138,8 @@ graph LR;
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |
* FP8 精度训练:[doc](../Training/FP8_Precision.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* 两阶段拆分训练:[doc](../Training/Split_Training.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)

View File

@@ -1,6 +1,7 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from diffsynth.utils.data.audio import read_audio
from modelscope import dataset_snapshot_download
vram_config = {
@@ -42,7 +43,7 @@ negative_prompt = (
)
height, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24
duration = num_frames / frame_rate
audio, audio_sample_rate = read_audio_with_torchaudio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
audio, audio_sample_rate = read_audio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,

View File

@@ -1,6 +1,7 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from diffsynth.utils.data.audio import read_audio
from modelscope import dataset_snapshot_download
from diffsynth.utils.data import VideoData
@@ -47,7 +48,7 @@ path = "data/example_video_dataset/ltx2/video2.mp4"
video = VideoData(path, height=height, width=width).raw_data()[:num_frames]
assert len(video) == num_frames, f"Input video has {len(video)} frames, but expected {num_frames} frames based on the specified num_frames argument."
duration = num_frames / frame_rate
audio, audio_sample_rate = read_audio_with_torchaudio(path)
audio, audio_sample_rate = read_audio(path)
# Regenerate the video within time regions. You can specify different time regions for video frames and audio retake.
# retake regions are in seconds, and the example below retakes video frames in the time regions of [1s, 2s] and [3s, 4s], and retakes audio in the time regions of [0s, 1s] and [4s, 5s].

View File

@@ -1,6 +1,7 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from diffsynth.utils.data.audio import read_audio
from modelscope import dataset_snapshot_download
vram_config = {
@@ -43,7 +44,7 @@ negative_prompt = (
)
height, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24
duration = num_frames / frame_rate
audio, audio_sample_rate = read_audio_with_torchaudio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
audio, audio_sample_rate = read_audio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,

View File

@@ -1,6 +1,7 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from diffsynth.utils.data.audio import read_audio
from modelscope import dataset_snapshot_download
from diffsynth.utils.data import VideoData
@@ -48,7 +49,7 @@ path = "data/example_video_dataset/ltx2/video2.mp4"
video = VideoData(path, height=height, width=width).raw_data()[:num_frames]
assert len(video) == num_frames, f"Input video has {len(video)} frames, but expected {num_frames} frames based on the specified num_frames argument."
duration = num_frames / frame_rate
audio, audio_sample_rate = read_audio_with_torchaudio(path)
audio, audio_sample_rate = read_audio(path)
# Regenerate the video within time regions. You can specify different time regions for video frames and audio retake.
# retake regions are in seconds, and the example below retakes video frames in the time regions of [1s, 2s] and [3s, 4s], and retakes audio in the time regions of [0s, 1s] and [4s, 5s].

View File

@@ -0,0 +1,55 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
import torch.distributed as dist
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
use_usp=True,
model_configs=[
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 352, 640, 121
frame_rate=24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
if dist.get_rank() == 0:
write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -0,0 +1,52 @@
import torch
from PIL import Image
from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline
from diffsynth.utils.data.audio_video import write_video_audio
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -0,0 +1,52 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 720, 1280, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-720p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -0,0 +1,53 @@
import torch
from PIL import Image
from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline
from diffsynth.utils.data.audio_video import write_video_audio
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -0,0 +1,53 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 720, 1280, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-720p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -0,0 +1,39 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio,input_image" \
--height 352 \
--width 640 \
--num_frames 121 \
--dataset_repeat 100 \
--model_id_with_origin_paths "openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.video_dit." \
--output_path "./models/train/MOVA-360p-I2AV_high_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 0.358 \
--min_timestep_boundary 0 \
--use_gradient_checkpointing
# boundary corresponds to timesteps [900, 1000]
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio,input_image" \
--height 352 \
--width 640 \
--num_frames 121 \
--dataset_repeat 100 \
--model_id_with_origin_paths "openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.video_dit." \
--output_path "./models/train/MOVA-360p-I2AV_low_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.358 \
--use_gradient_checkpointing
# boundary corresponds to timesteps [0, 900)

View File

@@ -0,0 +1,39 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio,input_image" \
--height 720 \
--width 1280 \
--num_frames 121 \
--dataset_repeat 100 \
--model_id_with_origin_paths "openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.video_dit." \
--output_path "./models/train/MOVA-720p-I2AV_high_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 0.358 \
--min_timestep_boundary 0 \
--use_gradient_checkpointing
# boundary corresponds to timesteps [900, 1000]
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio,input_image" \
--height 720 \
--width 1280 \
--num_frames 121 \
--dataset_repeat 100 \
--model_id_with_origin_paths "openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.video_dit." \
--output_path "./models/train/MOVA-720p-I2AV_low_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.358 \
--use_gradient_checkpointing
# boundary corresponds to timesteps [0, 900)

View File

@@ -0,0 +1,43 @@
accelerate launch examples/mova/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio,input_image" \
--height 352 \
--width 640 \
--num_frames 121 \
--dataset_repeat 100 \
--model_id_with_origin_paths "openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.video_dit." \
--output_path "./models/train/MOVA-360p-I2AV_high_noise_lora" \
--lora_base_model "video_dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 0.358 \
--min_timestep_boundary 0 \
--use_gradient_checkpointing
# boundary corresponds to timesteps [900, 1000]
# accelerate launch examples/mova/model_training/train.py \
# --dataset_base_path data/example_video_dataset/ltx2 \
# --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
# --data_file_keys "video,input_audio" \
# --extra_inputs "input_audio,input_image" \
# --height 352 \
# --width 640 \
# --num_frames 121 \
# --dataset_repeat 100 \
# --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.video_dit." \
# --output_path "./models/train/MOVA-360p-I2AV_low_noise_lora" \
# --lora_base_model "video_dit" \
# --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
# --lora_rank 32 \
# --max_timestep_boundary 1 \
# --min_timestep_boundary 0.358 \
# --use_gradient_checkpointing
# boundary corresponds to timesteps [0, 900)

View File

@@ -0,0 +1,43 @@
accelerate launch examples/mova/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio,input_image" \
--height 720 \
--width 1280 \
--num_frames 121 \
--dataset_repeat 100 \
--model_id_with_origin_paths "openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.video_dit." \
--output_path "./models/train/MOVA-720p-I2AV_high_noise_lora" \
--lora_base_model "video_dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 0.358 \
--min_timestep_boundary 0 \
--use_gradient_checkpointing
# boundary corresponds to timesteps [900, 1000]
accelerate launch examples/mova/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio,input_image" \
--height 720 \
--width 1280 \
--num_frames 121 \
--dataset_repeat 100 \
--model_id_with_origin_paths "openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.video_dit." \
--output_path "./models/train/MOVA-720p-I2AV_low_noise_lora" \
--lora_base_model "video_dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.358 \
--use_gradient_checkpointing
# boundary corresponds to timesteps [0, 900)

View File

@@ -0,0 +1,193 @@
import torch, os, argparse, accelerate, warnings
from diffsynth.core import UnifiedDataset
from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath, RouteByType, SequencialProcess
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class MOVATrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
fp8_models=None,
offload_models=None,
device="cpu",
task="sft",
max_timestep_boundary=1.0,
min_timestep_boundary=0.0,
):
super().__init__()
# Warning
if not use_gradient_checkpointing:
warnings.warn("Gradient checkpointing is detected as disabled. To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.")
use_gradient_checkpointing = True
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path)
self.pipe = MovaAudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.split_pipeline_units(
task, self.pipe, trainable_models, lora_base_model,
remove_unnecessary_params=True,
force_remove_params_shared=("audio_latents", "video_latents"),
force_remove_params_nega=("audio_context", "video_context")
)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),
}
self.max_timestep_boundary = max_timestep_boundary
self.min_timestep_boundary = min_timestep_boundary
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
for extra_input in extra_inputs:
if extra_input == "input_image":
inputs_shared["input_image"] = data["video"][0]
else:
inputs_shared[extra_input] = data[extra_input]
return inputs_shared
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {}
inputs_shared = {
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
"input_video": data["video"],
"height": data["video"][0].size[1],
"width": data["video"][0].size[0],
"num_frames": len(data["video"]),
"frame_rate": data.get("frame_rate", 24),
# Please do not modify the following parameters
# unless you clearly know what this will cause.
"cfg_scale": 1,
"tiled": False,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"max_timestep_boundary": self.max_timestep_boundary,
"min_timestep_boundary": self.min_timestep_boundary,
}
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
loss = self.task_to_loss[self.task](self.pipe, *inputs)
return loss
def ltx2_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser = add_general_config(parser)
parser = add_video_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--frame_rate", type=float, default=24, help="Frame rate of the training videos. Mova is trained with a frame rate of 24, so it's recommended to use the same frame rate.")
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
return parser
if __name__ == "__main__":
parser = ltx2_parser()
args = parser.parse_args()
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
model = MOVATrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
preset_lora_path=args.preset_lora_path,
preset_lora_model=args.preset_lora_model,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
max_timestep_boundary=args.max_timestep_boundary,
min_timestep_boundary=args.min_timestep_boundary,
)
video_processor = UnifiedDataset.default_video_operator(
base_path=args.dataset_base_path,
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=model.pipe.height_division_factor,
width_division_factor=model.pipe.width_division_factor,
num_frames=args.num_frames,
time_division_factor=model.pipe.time_division_factor,
time_division_remainder=model.pipe.time_division_remainder,
frame_rate=args.frame_rate,
fix_frame_rate=True,
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=video_processor,
special_operator_map={
"input_audio":
ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(
num_frames=args.num_frames,
time_division_factor=model.pipe.time_division_factor,
time_division_remainder=model.pipe.time_division_remainder,
frame_rate=args.frame_rate,
),
"in_context_videos":
RouteByType(operator_map=[
(str, video_processor),
(list, SequencialProcess(video_processor)),
]),
},
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
launcher_map = {
"sft:data_process": launch_data_process_task,
"direct_distill:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
"direct_distill": launch_training_task,
"direct_distill:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

View File

@@ -0,0 +1,53 @@
import torch
from PIL import Image
from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.utils.data import VideoData
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(path="./models/train/MOVA-360p-I2AV_high_noise_full/epoch-4.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "A beautiful sunset over the ocean."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-360p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -0,0 +1,54 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
from diffsynth.utils.data import VideoData
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(path="./models/train/MOVA-720p-I2AV_high_noise_full/epoch-4.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.video_dit, "models/train/MOVA-720p-I2AV_high_noise_lora/epoch-4.safetensors")
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "A beautiful sunset over the ocean."
height, width, num_frames = 720, 1280, 121
frame_rate = 24
input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-720p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -0,0 +1,54 @@
import torch
from PIL import Image
from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.utils.data import VideoData
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.video_dit, "models/train/MOVA-360p-I2AV_high_noise_lora/epoch-4.safetensors")
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "A beautiful sunset over the ocean."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-360p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -0,0 +1,54 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
from diffsynth.utils.data import VideoData
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.video_dit, "models/train/MOVA-720p-I2AV_high_noise_lora/epoch-4.safetensors")
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
"整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指"
)
prompt = "A beautiful sunset over the ocean."
height, width, num_frames = 720, 1280, 121
frame_rate = 24
input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
# Image-to-video
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
input_image=input_image,
num_inference_steps=50,
seed=0,
tiled=True,
frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-720p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)

View File

@@ -12,7 +12,6 @@ requires-python = ">=3.10.1"
dependencies = [
"torch>=2.0.0",
"torchvision",
"torchaudio",
"transformers",
"imageio",
"imageio[ffmpeg]",
@@ -48,6 +47,10 @@ npu = [
"torch-npu==2.7.1",
"torchvision==0.22.1+cpu"
]
audio = [
"torchaudio",
"torchcodec"
]
[tool.setuptools]
include-package-data = true