Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-18 13:58:15 +00:00

Mova (#1337)

* support mova inference
* mova media_io
* add unified audio_video api & fix bug of mono audio input for ltx
* support mova train
* mova docs
* fix bug
@@ -32,8 +32,9 @@ We believe that a well-developed open-source code framework can lower the thresh

> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need the old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.

> This project currently has a small development team, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). New feature development therefore progresses relatively slowly, and our capacity to respond to and resolve issues is limited. We apologize for this and ask for developers' understanding.

- **January 19, 2026**: Added support for the [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.

- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features include text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting, with complete inference and training functionality supported. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).

- **March 3, 2026**: We released the [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) model, an updated version of Qwen-Image-Layered-Control. In addition to the originally supported text-guided functionality, it adds brush-controlled layer separation capabilities.
@@ -867,6 +868,8 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)

|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|

</details>
@@ -32,8 +32,10 @@ DiffSynth currently includes two open-source projects:

> DiffSynth-Studio has undergone a major version update, and some old features are no longer maintained. If you need the old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.

> This project currently has a small development team, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). New feature development therefore progresses relatively slowly, and the speed of replying to and resolving issues is limited. We sincerely apologize for this and ask for developers' understanding.

- **January 19, 2026**: Added support for the [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including complete training and inference functionality. [Documentation](/docs/zh/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.

- **March 12, 2026**: We added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. Supported features include text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and local audio-video inpainting; the framework supports complete inference and training. For details, see the [documentation](/docs/zh/Model_Details/LTX-2.md) and [example code](/examples/ltx2/).

- **March 3, 2026**: We released the [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) model, an updated version of Qwen-Image-Layered-Control. In addition to the originally supported text-guided functionality, it adds brush-controlled layer separation capabilities.

- **March 2, 2026**: Added support for [Anima](https://modelscope.cn/models/circlestone-labs/Anima); see the [documentation](docs/zh/Model_Details/Anima.md). This is an interesting anime-style image generation model, and we look forward to its future updates.
@@ -866,6 +868,8 @@ Example code for Wan is located at: [/examples/wanvideo/](/examples/wanvideo/)

|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|

</details>
@@ -848,4 +848,26 @@ anima_series = [
        "state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
    }
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series

mova_series = [
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "8c57e12790e2c45a64817e0ce28cde2f",
        "model_name": "mova_audio_dit",
        "model_class": "diffsynth.models.mova_audio_dit.MovaAudioDit",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "418517fb2b4e919d2cac8f314fcf82ac",
        "model_name": "mova_audio_vae",
        "model_class": "diffsynth.models.mova_audio_vae.DacVAE",
    },
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "d1139dbbc8b4ab53cf4b4243d57bbceb",
        "model_name": "mova_dual_tower_bridge",
        "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
    },
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
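As a sketch of how registry entries of this shape might be consumed (the `find_config` helper below is hypothetical and not part of DiffSynth-Studio), a loader could resolve a `model_name` to its dotted `model_class` path:

```python
# Hypothetical lookup over a MODEL_CONFIGS-style registry. The dict layout
# mirrors the mova_series entries above; the helper itself is illustrative.
mova_series = [
    {"model_hash": "8c57e12790e2c45a64817e0ce28cde2f",
     "model_name": "mova_audio_dit",
     "model_class": "diffsynth.models.mova_audio_dit.MovaAudioDit"},
    {"model_hash": "418517fb2b4e919d2cac8f314fcf82ac",
     "model_name": "mova_audio_vae",
     "model_class": "diffsynth.models.mova_audio_vae.DacVAE"},
]

def find_config(model_name, configs):
    """Return the first config entry whose model_name matches."""
    for cfg in configs:
        if cfg["model_name"] == model_name:
            return cfg
    raise KeyError(model_name)
```

The `model_hash` field lets a loader identify a checkpoint by content rather than by filename, which is why each entry pairs a hash with a class path.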
@@ -249,6 +249,24 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_audio_dit.MovaAudioDit": {
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_audio_vae.DacVAE": {
        "diffsynth.models.mova_audio_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
}

def QwenImageTextEncoder_Module_Map_Updater():
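A rough illustration of how a module map of this shape could be consulted (the `qualified_name` and `wrapper_for` helpers below are our own illustrative sketch, not DiffSynth-Studio code): match a module's fully qualified class name against the map and look up the wrapper to apply.

```python
# Illustrative sketch: resolve a module instance to a wrapper name by its
# fully qualified class name, the same keying scheme the map above uses.
def qualified_name(obj):
    """Return 'module.QualName' for an object's class."""
    cls = type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"

def wrapper_for(module, module_map):
    """Look up the wrapper class path for a module, or None if unmapped."""
    return module_map.get(qualified_name(module))
```

For example, under a map keyed by `"torch.nn.Linear"`, every `nn.Linear` submodule would resolve to `AutoWrappedLinear`, while unmapped modules fall through untouched.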
@@ -152,13 +152,6 @@ class BasePipeline(torch.nn.Module):
        # remove batch dim
        if audio_output.ndim == 3:
            audio_output = audio_output.squeeze(0)
        # Transform to stereo
        if audio_output.shape[0] == 1:
            audio_output = audio_output.repeat(2, 1)
        elif audio_output.shape[0] == 2:
            pass
        else:
            raise ValueError("The output audio should be [C, T] or [1, C, T] with C equal to 1 (mono) or 2 (stereo).")
        return audio_output.float()

    def load_models_to_device(self, model_names):
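The mono-to-stereo normalization in the hunk above can be sketched outside the pipeline. This standalone NumPy helper (our own illustration, not the pipeline method) mirrors the same squeeze-then-duplicate logic:

```python
import numpy as np

# Sketch of the audio normalization: drop a leading batch dim, then
# duplicate a mono channel so the output is always stereo [2, T].
def to_stereo(audio):
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 3:           # [1, C, T] -> [C, T]
        audio = audio.squeeze(0)
    if audio.shape[0] == 1:       # mono -> stereo by channel duplication
        audio = np.repeat(audio, 2, axis=0)
    elif audio.shape[0] != 2:
        raise ValueError("expected 1 or 2 channels")
    return audio
```

Duplicating the channel (rather than padding with silence) keeps mono content centered when played back on stereo hardware.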
57	diffsynth/models/mova_audio_dit.py	Normal file
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d
from einops import rearrange
from ..core import gradient_checkpoint_forward


def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
    f_freqs_cis = precompute_freqs_cis(dim, end, theta)
    return f_freqs_cis.chunk(3, dim=-1)


class MovaAudioDit(WanModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12)
        self.freqs = precompute_freqs_cis_1d(head_dim)
        # Audio latents are 1D, so the video patch embedding is replaced by a Conv1d
        self.patch_embedding = nn.Conv1d(
            kwargs.get("in_dim", 128), kwargs.get("dim", 1536), kernel_size=[1], stride=[1]
        )

    def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):
        self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta)

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                **kwargs,
                ):
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)
        x, (f, ) = self.patchify(x)
        # 1D RoPE: concatenate the three frequency chunks along the head dim
        freqs = torch.cat([
            self.freqs[0][:f].view(f, -1).expand(f, -1),
            self.freqs[1][:f].view(f, -1).expand(f, -1),
            self.freqs[2][:f].view(f, -1).expand(f, -1),
        ], dim=-1).reshape(f, 1, -1).to(x.device)

        for block in self.blocks:
            x = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing,
                use_gradient_checkpointing_offload,
                x, context, t_mod, freqs,
            )
        x = self.head(x, t)
        x = self.unpatchify(x, (f, ))
        return x

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        return rearrange(
            x, 'b f (p c) -> b c (f p)',
            f=grid_size[0],
            p=self.patch_size[0]
        )
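The `unpatchify` rearrange `'b f (p c) -> b c (f p)'` can be made concrete in plain NumPy. This sketch (assuming only the shape contract above, not the model's tensors) shows how per-token channels unfold back into a `[channels, time]` latent:

```python
import numpy as np

# NumPy restatement of einops 'b f (p c) -> b c (f p)': split the last axis
# into (patch, channel), move channels forward, then merge (frame, patch)
# into the time axis.
def unpatchify(x, p, c):
    b, f, pc = x.shape
    assert pc == p * c
    x = x.reshape(b, f, p, c)         # split (p c)
    x = x.transpose(0, 3, 1, 2)       # -> b c f p
    return x.reshape(b, c, f * p)     # merge (f p)

x = np.arange(12).reshape(1, 3, 4)    # 3 tokens, patch=2, channels=2
out = unpatchify(x, p=2, c=2)
```

With `p=1` (the MOVA audio DiT default `patch_size=[1]`), the rearrange reduces to a plain transpose of the frame and channel axes.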
796	diffsynth/models/mova_audio_vae.py	Normal file
@@ -0,0 +1,796 @@
import math
from typing import List, Union
import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
from einops import rearrange


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
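The `snake` activation above computes `x + (1/alpha) * sin(alpha * x)^2`. A minimal NumPy restatement (our own, for illustration) makes the periodic bias easy to inspect; with `alpha = 1` it reduces to `x + sin(x)^2`:

```python
import numpy as np

# Scalar/NumPy version of the snake activation used by Snake1d above.
# The 1e-9 term guards against division by zero for tiny alpha.
def snake_np(x, alpha=1.0):
    return x + np.sin(alpha * x) ** 2 / (alpha + 1e-9)
```

Unlike ReLU, snake is periodic on top of the identity, which is why it is popular in neural audio codecs: it biases the network toward oscillatory (harmonic) solutions.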
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses the following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
           for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
           improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantize the input tensor using a fixed codebook and return
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN): project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices
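The core of `decode_latents` is a nearest-neighbor lookup after L2 normalization, which turns euclidean distance into cosine similarity. Here is a small NumPy sketch of that step (an illustration of the technique, not the module's code), operating on `[N, D]` encodings and a `[K, D]` codebook:

```python
import numpy as np

# Nearest-neighbor codebook lookup on the unit sphere, as in decode_latents:
# after L2 normalization, minimizing ||e - c||^2 maximizes cosine similarity.
def nearest_codes(encodings, codebook):
    e = encodings / np.linalg.norm(encodings, axis=1, keepdims=True)
    c = codebook / np.linalg.norm(codebook, axis=1, keepdims=True)
    # ||e||^2 - 2 e.c + ||c||^2, expanded squared euclidean distance
    dist = (e ** 2).sum(1, keepdims=True) - 2 * e @ c.T + (c ** 2).sum(1)
    return dist.argmin(axis=1)

codebook = np.eye(3)                                   # 3 one-hot codes
encodings = np.array([[0.9, 0.1, 0.0], [0.0, 0.0, 2.0]])
idx = nearest_codes(encodings, codebook)
```

Because both sides are normalized, the magnitude of an encoding no longer affects which code it snaps to, only its direction does.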
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantize the input tensor using a fixed set of `n` codebooks and return
        the corresponding codebook vectors
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.
        Returns
        -------
        dict
            A dictionary with the following keys:

            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation
        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input
        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
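The residual idea behind `ResidualVectorQuantize` can be demonstrated with scalar rounding in place of learned codebooks (a toy stand-in, not the real quantizer): each stage encodes what earlier stages missed, so the leftover error shrinks as stages are added:

```python
# Toy residual quantizer: stage i rounds the running residual to a grid of
# step/2**i, mimicking how each RVQ codebook refines the previous stages.
def rvq_encode(x, n_stages, step=1.0):
    codes, residual = [], x
    for i in range(n_stages):
        grid = step / 2 ** i
        q = round(residual / grid) * grid
        codes.append(q)
        residual -= q
    return codes, residual

codes, r = rvq_encode(0.8, 3)
```

As in the module above, decoding is just the sum of the per-stage contributions, and truncating to the first `n_quantizers` stages yields a coarser but still valid reconstruction.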
class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2],
                )

    def nll(self, sample, dims=[1, 2]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]

    return 0.5 * (
        -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
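Restating `normal_kl` for scalars (same closed-form expression, plain `math` instead of torch) makes it easy to sanity-check: the KL of a Gaussian against itself is 0, and shifting the mean by 1 at unit variance gives 0.5:

```python
import math

# Scalar version of the normal_kl formula above:
# KL(N(m1, v1) || N(m2, v2)) with v = exp(logvar).
def gaussian_kl(mean1, logvar1, mean2, logvar2):
    return 0.5 * (-1.0 + logvar2 - logvar1
                  + math.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * math.exp(-logvar2))
```

Parameterizing by log-variance (as both `normal_kl` and `DiagonalGaussianDistribution` do) keeps variances positive by construction and makes the formula numerically stable.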
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)
class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
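Since each `EncoderBlock` downsamples by its stride, the encoder's overall hop length is the product of the strides, which is exactly what `DacVAE` below computes via `np.prod(encoder_rates)`. For the default rates:

```python
import math

# The encoder's total downsampling factor is the product of its strides.
# With the DacVAE defaults [2, 3, 4, 5, 8], one latent frame covers
# 960 samples, i.e. 50 latent frames per second at 48 kHz.
def hop_length(strides):
    return math.prod(strides)
```

This hop length is the unit in which all latent-domain lengths (and the codec's frame rate) are expressed.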
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class DacVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int = 128,
|
||||
encoder_rates: List[int] = [2, 3, 4, 5, 8],
|
||||
latent_dim: int = 128,
|
||||
decoder_dim: int = 2048,
|
||||
decoder_rates: List[int] = [8, 5, 4, 3, 2],
|
||||
n_codebooks: int = 9,
|
||||
codebook_size: int = 1024,
|
||||
codebook_dim: Union[int, list] = 8,
|
||||
quantizer_dropout: bool = False,
|
||||
sample_rate: int = 48000,
|
||||
continuous: bool = True,
|
||||
use_weight_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.sample_rate = sample_rate
|
||||
self.continuous = continuous
|
||||
self.use_weight_norm = use_weight_norm
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
|
||||
|
||||
if not continuous:
|
||||
self.n_codebooks = n_codebooks
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.quantizer = ResidualVectorQuantize(
|
||||
input_dim=latent_dim,
|
||||
n_codebooks=n_codebooks,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
quantizer_dropout=quantizer_dropout,
|
||||
)
|
||||
else:
|
||||
self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
|
||||
|
||||
self.decoder = Decoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates,
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.apply(init_weights)
|
||||
|
||||
self.delay = self.get_delay()
|
||||
|
||||
if not self.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
|
||||
    def get_delay(self):
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L
    @property
    def dtype(self):
        """Get the dtype of the model parameters."""
        # Return the dtype of the first parameter found
        for param in self.parameters():
            return param.dtype
        return torch.float32  # fallback

    @property
    def device(self):
        """Get the device of the model parameters."""
        # Return the device of the first parameter found
        for param in self.parameters():
            return param.device
        return torch.device('cpu')  # fallback

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data
    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data and return quantized latent codes

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        tuple
            A tuple `(z, codes, latents, commitment_loss, codebook_loss)`:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
                (a `DiagonalGaussianDistribution` when `continuous=True`)
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        z = self.encoder(audio_data)  # [B x D x T]
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
        else:
            z = self.quant_conv(z)  # [B x 2D x T]
            z = DiagonalGaussianDistribution(z)
            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0

        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input

        Returns
        -------
        audio : Tensor[B x 1 x length]
            Decoded audio data.
        """
        if not self.continuous:
            audio = self.decoder(z)
        else:
            z = self.post_quant_conv(z)
            audio = self.decoder(z)

        return audio
    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "audio" : Tensor[B x 1 x length]
                Decoded audio data.
            When `continuous=True`, the dictionary instead contains only
            "audio", "z", and "kl_loss".
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)

            x = self.decode(z)
            return {
                "audio": x[..., :length],
                "z": z,
                "codes": codes,
                "latents": latents,
                "vq/commitment_loss": commitment_loss,
                "vq/codebook_loss": codebook_loss,
            }
        else:
            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
            z = posterior.sample()
            x = self.decode(z)

            kl_loss = posterior.kl()
            kl_loss = kl_loss.mean()

            return {
                "audio": x[..., :length],
                "z": z,
                "kl_loss": kl_loss,
            }

    def remove_weight_norm(self):
        """
        Remove weight_norm from all modules in the model.
        This fuses the weight_g and weight_v parameters into a single weight parameter.
        Should be called before inference for better performance.
        Returns:
            self: The model with weight_norm removed
        """
        from torch.nn.utils import remove_weight_norm
        num_removed = 0
        for name, module in list(self.named_modules()):
            if hasattr(module, "_forward_pre_hooks"):
                for hook_id, hook in list(module._forward_pre_hooks.items()):
                    if "WeightNorm" in str(type(hook)):
                        try:
                            remove_weight_norm(module)
                            num_removed += 1
                        except ValueError as e:
                            print(f"Failed to remove weight_norm from {name}: {e}")
        if num_removed > 0:
            self.use_weight_norm = False
        else:
            print("No weight_norm found in the model")
        return self
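The temporal bookkeeping in `DacVAE` above (`hop_length`, the `preprocess` right-padding) reduces to simple integer arithmetic. A minimal pure-Python sketch of that arithmetic, assuming the default `encoder_rates` and `sample_rate` (no torch required):

```python
import math

# Mirrors DacVAE defaults: total temporal downsampling is the product of the
# encoder strides, so one latent step covers hop_length input samples.
encoder_rates = [2, 3, 4, 5, 8]
hop_length = math.prod(encoder_rates)      # 960 samples per latent step
sample_rate = 48000
latent_rate = sample_rate / hop_length     # 50.0 latent steps per second

# preprocess() right-pads the waveform to a whole number of latent steps.
def right_pad(length: int, hop: int) -> int:
    return math.ceil(length / hop) * hop - length

print(hop_length, latent_rate)             # 960 50.0
print(right_pad(48000, hop_length))        # 0   (exactly 50 latent steps)
print(right_pad(48001, hop_length))        # 959 (padded up to 51 steps)
```

Note that 48000 / 960 = 50 latent steps per second is consistent with the `audio_fps = 50.0` default used by `DualTowerConditionalBridge` below.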
595
diffsynth/models/mova_dual_tower_bridge.py
Normal file
@@ -0,0 +1,595 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
from einops import rearrange
from .wan_video_dit import AttentionModule, RMSNorm
from ..core import gradient_checkpoint_forward


class RotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, base: float, dim: int, device=None):
        super().__init__()
        self.base = base
        self.dim = dim
        self.attention_scaling = 1.0

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
@torch.compile(fullgraph=True)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
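For a single frequency and a two-dimensional head, the `rotate_half` construction above is an ordinary 2-D rotation. A pure-Python sketch of the same identity, with lists standing in for tensors (all names here are illustrative, not part of the module):

```python
import math

def rotate_half(x):
    # Pure-Python analogue of the tensor op above: concat(-x2, x1) over split halves.
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + list(x1)

def apply_rope(q, theta):
    # q * cos + rotate_half(q) * sin, with one shared angle for the pair.
    c, s = math.cos(theta), math.sin(theta)
    r = rotate_half(q)
    return [qi * c + ri * s for qi, ri in zip(q, r)]

q = [1.0, 0.0]                      # one (x1, x2) coordinate pair
out = apply_rope(q, math.pi / 2)    # a 90-degree rotation
print([round(v, 6) for v in out])
```

Rotating the unit vector `[1, 0]` by 90 degrees yields (up to float rounding) `[0, 1]`, which is how RoPE turns a position offset into a phase difference between query and key.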
class PerFrameAttentionPooling(nn.Module):
    """
    Per-frame multi-head attention pooling.

    Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a
    single-query attention pooling over the H*W tokens for each time frame, producing
    [B, T, D].

    Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack).
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads

        self.probe = nn.Parameter(torch.randn(1, 1, dim))
        nn.init.normal_(self.probe, std=0.02)

        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(dim, eps=eps)

    def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor:
        """
        Args:
            x: [B, L, D], where L = T*H*W
            grid_size: (T, H, W)
        Returns:
            pooled: [B, T, D]
        """
        B, L, D = x.shape
        T, H, W = grid_size
        assert D == self.dim, f"Channel dimension mismatch: D={D} vs dim={self.dim}"
        assert L == T * H * W, f"Flattened length mismatch: L={L} vs T*H*W={T*H*W}"

        S = H * W
        # Re-arrange tokens grouped by frame.
        x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D)  # [B*T, S, D]

        # A learnable probe as the query (one query per frame).
        probe = self.probe.expand(B * T, -1, -1)  # [B*T, 1, D]

        # Attention pooling: query=probe, key/value=H*W tokens within the frame.
        pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0]  # [B*T, 1, D]
        pooled_bt_d = pooled_bt_1_d.squeeze(1)  # [B*T, D]

        # Restore to [B, T, D].
        pooled = pooled_bt_d.view(B, T, D)
        pooled = self.layernorm(pooled)
        return pooled
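The reshape at the heart of `PerFrameAttentionPooling` simply regroups the flat `T*H*W` token sequence frame by frame, so the token at flat index `l` belongs to frame `l // (H*W)`. A pure-Python sanity check of that grouping, with illustrative sizes (the T, H, W values here are assumptions, not defaults):

```python
# Token indices stand in for the [L, D] token rows of one batch element.
T, H, W = 3, 2, 2
S = H * W
tokens = list(range(T * H * W))
frames = [tokens[t * S:(t + 1) * S] for t in range(T)]
print(frames)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

# Every token in frame t has flat index l with l // S == t.
assert all(l // S == t for t, frame in enumerate(frames) for l in frame)
```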
class CrossModalInteractionController:
    """
    Strategy class that controls interactions between two towers.
    Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 30 layers).
    """

    def __init__(self, visual_layers: int = 30, audio_layers: int = 30):
        self.visual_layers = visual_layers
        self.audio_layers = audio_layers
        self.min_layers = min(visual_layers, audio_layers)

    def get_interaction_layers(self, strategy: str = "shallow_focus") -> Dict[str, List[Tuple[int, int]]]:
        """
        Get interaction layer mappings.

        Args:
            strategy: interaction strategy
                - "shallow_focus": emphasize shallow layers to avoid deep-layer asymmetry
                - "distributed": distributed interactions across the network
                - "progressive": dense shallow interactions, sparse deeper interactions
                - "custom": custom interaction layers
                - "full": interact at every shared layer

        Returns:
            A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual).
        """

        if strategy == "shallow_focus":
            # Emphasize the first ~1/3 layers to avoid deep-layer asymmetry.
            num_interact = min(10, self.min_layers // 3)
            interact_layers = list(range(0, num_interact))

        elif strategy == "distributed":
            # Distribute interactions across the network (every few layers).
            step = 3
            interact_layers = list(range(0, self.min_layers, step))

        elif strategy == "progressive":
            # Progressive: dense shallow interactions, sparse deeper interactions.
            shallow = list(range(0, min(8, self.min_layers)))  # Dense for the first 8 layers.
            if self.min_layers > 8:
                deep = list(range(8, self.min_layers, 3))  # Every 3 layers afterwards.
                interact_layers = shallow + deep
            else:
                interact_layers = shallow

        elif strategy == "custom":
            # Custom strategy: adjust as needed.
            interact_layers = [0, 2, 4, 6, 8, 12, 16, 20]  # Explicit layer indices.
            interact_layers = [i for i in interact_layers if i < self.min_layers]

        elif strategy == "full":
            interact_layers = list(range(0, self.min_layers))

        else:
            raise ValueError(f"Unknown interaction strategy: {strategy}")

        # Build bidirectional mapping.
        mapping = {
            'v2a': [(i, i) for i in interact_layers],  # visual layer i -> audio layer i
            'a2v': [(i, i) for i in interact_layers]   # audio layer i -> visual layer i
        }

        return mapping

    def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool:
        """
        Check whether a given layer should interact.

        Args:
            layer_idx: current layer index
            direction: interaction direction ('v2a' or 'a2v')
            interaction_mapping: interaction mapping table

        Returns:
            bool: whether to interact
        """
        if direction not in interaction_mapping:
            return False

        return any(src == layer_idx for src, _ in interaction_mapping[direction])
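The layer selections produced by `get_interaction_layers` can be traced without instantiating anything. A pure-Python sketch of the three arithmetic strategies for the default 30/30-layer case (the mapping then pairs each selected index with itself in both directions):

```python
# min_layers = min(visual_layers, audio_layers) with the class defaults.
min_layers = 30

# "shallow_focus": the first min(10, min_layers // 3) layers.
shallow_focus = list(range(min(10, min_layers // 3)))

# "distributed": every 3rd layer across the shared depth.
distributed = list(range(0, min_layers, 3))

# "progressive": dense for the first 8 layers, then every 3rd layer.
progressive = list(range(min(8, min_layers))) + list(range(8, min_layers, 3))

print(shallow_focus)     # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(distributed)       # [0, 3, 6, ..., 27]
print(progressive[:10])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 11]
```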
class ConditionalCrossAttention(nn.Module):
    def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.q_dim = dim
        self.kv_dim = kv_dim
        self.num_heads = num_heads
        self.head_dim = self.q_dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(kv_dim, dim)
        self.v = nn.Linear(kv_dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        if x_freqs is not None:
            x_cos, x_sin = x_freqs
            q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim)
            x_cos = x_cos.to(q_view.dtype).to(q_view.device)
            x_sin = x_sin.to(q_view.dtype).to(q_view.device)
            # Expect x_cos/x_sin shape: [B or 1, L, head_dim]
            q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2)
            q = rearrange(q_view, 'b l h d -> b l (h d)')
        if y_freqs is not None:
            y_cos, y_sin = y_freqs
            k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim)
            y_cos = y_cos.to(k_view.dtype).to(k_view.device)
            y_sin = y_sin.to(k_view.dtype).to(k_view.device)
            # Expect y_cos/y_sin shape: [B or 1, L, head_dim]
            _, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2)
            k = rearrange(k_view, 'b l h d -> b l (h d)')
        x = self.attn(q, k, v)
        return self.o(x)
# from diffusers.models.attention import AdaLayerNorm
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
        output_dim (`int`, *optional*): Size of the projected (scale, shift) output; defaults to `embedding_dim * 2`.
        norm_elementwise_affine (`bool`, defaults to `False`): Whether the LayerNorm has learnable affine parameters.
        norm_eps (`float`, defaults to `1e-5`): Epsilon for the LayerNorm.
        chunk_dim (`int`, defaults to `0`): Dimension along which `temb` is split into scale and shift.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        self.chunk_dim = chunk_dim
        output_dim = output_dim or embedding_dim * 2

        if num_embeddings is not None:
            self.emb = nn.Embedding(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if self.emb is not None:
            temb = self.emb(timestep)

        temb = self.linear(self.silu(temb))

        if self.chunk_dim == 2:
            scale, shift = temb.chunk(2, dim=2)
        elif self.chunk_dim == 1:
            # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
            # other if-branch. This branch is specific to CogVideoX and OmniGen for now.
            shift, scale = temb.chunk(2, dim=1)
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        else:
            scale, shift = temb.chunk(2, dim=0)

        x = self.norm(x) * (1 + scale) + shift
        return x
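The modulation at the end of `AdaLayerNorm.forward` is just `norm(x) * (1 + scale) + shift`, with `scale` and `shift` projected from the conditioning embedding. In scalar form (the values here are illustrative):

```python
# Scalar sketch of the AdaLN modulation: the embedding-derived (scale, shift)
# pair rescales and offsets the normalized activations.
def modulate(x_norm, scale, shift):
    return [v * (1.0 + scale) + shift for v in x_norm]

print(modulate([1.0, -1.0, 0.0], scale=0.5, shift=0.25))  # [1.75, -1.25, 0.25]
```

With `scale = shift = 0` the layer reduces to a plain LayerNorm, which is why zero-initializing the projection is a common stabilization choice for this family of layers.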
class ConditionalCrossAttentionBlock(nn.Module):
    """
    A thin wrapper around ConditionalCrossAttention.
    Applies LayerNorm to the conditioning input `y` before cross-attention.
    """

    def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False):
        super().__init__()
        self.y_norm = nn.LayerNorm(kv_dim, eps=eps)
        self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps)
        self.pooled_adaln = pooled_adaln
        if pooled_adaln:
            self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps)
            self.adaln = AdaLayerNorm(kv_dim, output_dim=dim * 2, chunk_dim=2)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        video_grid_size: Optional[Tuple[int, int, int]] = None,
    ) -> torch.Tensor:
        if self.pooled_adaln:
            assert video_grid_size is not None, "video_grid_size must not be None"
            pooled_y = self.per_frame_pooling(y, video_grid_size)
            # Interpolate pooled_y along its temporal dimension to match x's sequence length.
            if pooled_y.shape[1] != x.shape[1]:
                pooled_y = F.interpolate(
                    pooled_y.permute(0, 2, 1),  # [B, C, T]
                    size=x.shape[1],
                    mode='linear',
                    align_corners=False,
                ).permute(0, 2, 1)  # [B, T, C]
            x = self.adaln(x, temb=pooled_y)
        y = self.y_norm(y)
        return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs)
class DualTowerConditionalBridge(nn.Module):
    """
    Dual-tower conditional bridge.
    """

    def __init__(self,
                 visual_layers: int = 40,
                 audio_layers: int = 30,
                 visual_hidden_dim: int = 5120,  # visual DiT hidden state dimension
                 audio_hidden_dim: int = 1536,  # audio DiT hidden state dimension
                 audio_fps: float = 50.0,
                 head_dim: int = 128,  # attention head dimension
                 interaction_strategy: str = "full",
                 apply_cross_rope: bool = True,  # whether to apply RoPE in cross-attention
                 apply_first_frame_bias_in_rope: bool = False,  # whether to account for 1/video_fps bias for the first frame in RoPE alignment
                 trainable_condition_scale: bool = False,
                 pooled_adaln: bool = False,
                 ):
        super().__init__()

        self.visual_hidden_dim = visual_hidden_dim
        self.audio_hidden_dim = audio_hidden_dim
        self.audio_fps = audio_fps
        self.head_dim = head_dim
        self.apply_cross_rope = apply_cross_rope
        self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope
        self.trainable_condition_scale = trainable_condition_scale
        self.pooled_adaln = pooled_adaln
        if self.trainable_condition_scale:
            self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32))
        else:
            self.condition_scale = 1.0

        self.controller = CrossModalInteractionController(visual_layers, audio_layers)
        self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy)

        # Conditional cross-attention modules operating at the DiT hidden-state level.
        self.audio_to_video_conditioners = nn.ModuleDict()  # audio hidden states -> visual DiT conditioning
        self.video_to_audio_conditioners = nn.ModuleDict()  # visual hidden states -> audio DiT conditioning

        # Build conditioners for layers that should interact.
        # audio hidden states condition the visual DiT
        self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim)
        for v_layer, _ in self.interaction_mapping['a2v']:
            self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock(
                dim=visual_hidden_dim,  # visual DiT hidden states
                kv_dim=audio_hidden_dim,  # audio DiT hidden states
                num_heads=visual_hidden_dim // head_dim,  # derive number of heads from hidden dim
                pooled_adaln=False  # a2v typically does not need pooled AdaLN
            )

        # visual hidden states condition the audio DiT
        for a_layer, _ in self.interaction_mapping['v2a']:
            self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock(
                dim=audio_hidden_dim,  # audio DiT hidden states
                kv_dim=visual_hidden_dim,  # visual DiT hidden states
                num_heads=audio_hidden_dim // head_dim,  # derive number of heads from hidden dim
                pooled_adaln=self.pooled_adaln
            )
    @torch.no_grad()
    def build_aligned_freqs(self,
                            video_fps: float,
                            grid_size: Tuple[int, int, int],
                            audio_steps: int,
                            device: Optional[torch.device] = None,
                            dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """
        Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w),
        and audio sequence length `audio_steps` (with fixed audio fps = 44100/2048).

        Returns:
            visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim]
            audio_freqs: (cos_a, sin_a), shape [1, audio_steps, head_dim]
        """
        f_v, h, w = grid_size
        L_v = f_v * h * w
        L_a = int(audio_steps)

        device = device or next(self.parameters()).device
        dtype = dtype or torch.float32

        # Audio positions: 0,1,2,...,L_a-1 (audio as reference).
        audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)

        # Video positions: align video frames to audio-step units.
        # FIXME(dhyu): hard-coded VAE temporal stride = 4
        if self.apply_first_frame_bias_in_rope:
            # Account for the "first frame lasts 1/video_fps" bias.
            video_effective_fps = float(video_fps) / 4.0
            if f_v > 0:
                t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)
                if f_v > 1:
                    t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps)
            else:
                t_starts = torch.zeros((0,), device=device, dtype=torch.float32)
            # Convert to audio-step units.
            video_pos_per_frame = t_starts * float(self.audio_fps)
        else:
            # No first-frame bias: uniform alignment.
            scale = float(self.audio_fps) / float(video_fps / 4.0)
            video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale
        # Flatten to f*h*w; tokens within the same frame share the same time position.
        video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)

        # Build dummy x to produce cos/sin, dim=head_dim.
        dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype)
        dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype)

        cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos)
        cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos)

        return (cos_v, sin_v), (cos_a, sin_a)
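The uniform-alignment branch of `build_aligned_freqs` maps latent video frames into audio-step units with a single scale factor, `audio_fps / (video_fps / 4)`, where the divide-by-4 is the hard-coded VAE temporal stride flagged in the FIXME. The arithmetic, with an illustrative `video_fps` (not a default of this module):

```python
# One latent video frame covers 4 raw frames (assumed VAE temporal stride),
# so its duration in audio steps is audio_fps / (video_fps / 4).
audio_fps = 50.0
video_fps = 16.0
scale = audio_fps / (video_fps / 4.0)   # 12.5 audio steps per latent frame

f_v = 5
video_pos = [i * scale for i in range(f_v)]
print(video_pos)                        # [0.0, 12.5, 25.0, 37.5, 50.0]
```

All `h * w` tokens of a frame then share that frame's position, so cross-attention RoPE phases line up between co-occurring video and audio tokens.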
    def should_interact(self, layer_idx: int, direction: str) -> bool:
        return self.controller.should_interact(layer_idx, direction, self.interaction_mapping)

    def apply_conditional_control(
        self,
        layer_idx: int,
        direction: str,
        primary_hidden_states: torch.Tensor,
        condition_hidden_states: torch.Tensor,
        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        condition_scale: Optional[float] = None,
        video_grid_size: Optional[Tuple[int, int, int]] = None,
        use_gradient_checkpointing: Optional[bool] = False,
        use_gradient_checkpointing_offload: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Apply conditional control (at the DiT hidden-state level).

        Args:
            layer_idx: current layer index
            direction: conditioning direction
                - 'a2v': audio hidden states -> visual DiT
                - 'v2a': visual hidden states -> audio DiT
            primary_hidden_states: primary DiT hidden states [B, L, hidden_dim]
            condition_hidden_states: condition DiT hidden states [B, L, hidden_dim]
            condition_scale: conditioning strength (similar to CFG scale)

        Returns:
            Conditioned primary DiT hidden states [B, L, hidden_dim]
        """

        if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping):
            return primary_hidden_states

        if direction == 'a2v':
            # audio hidden states condition the visual DiT
            conditioner = self.audio_to_video_conditioners[str(layer_idx)]
        elif direction == 'v2a':
            # visual hidden states condition the audio DiT
            conditioner = self.video_to_audio_conditioners[str(layer_idx)]
        else:
            raise ValueError(f"Invalid direction: {direction}")

        conditioned_features = gradient_checkpoint_forward(
            conditioner,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            x=primary_hidden_states,
            y=condition_hidden_states,
            x_freqs=x_freqs,
            y_freqs=y_freqs,
            video_grid_size=video_grid_size,
        )

        if self.trainable_condition_scale and condition_scale is not None:
            print(
                "[WARN] This model has a trainable condition_scale, but an external "
                f"condition_scale={condition_scale} was provided. The trainable condition_scale "
                "will be ignored in favor of the external value."
            )

        scale = condition_scale if condition_scale is not None else self.condition_scale

        primary_hidden_states = primary_hidden_states + conditioned_features * scale

        return primary_hidden_states
    def forward(
        self,
        layer_idx: int,
        visual_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        *,
        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        a2v_condition_scale: Optional[float] = None,
        v2a_condition_scale: Optional[float] = None,
        condition_scale: Optional[float] = None,
        video_grid_size: Optional[Tuple[int, int, int]] = None,
        use_gradient_checkpointing: Optional[bool] = False,
        use_gradient_checkpointing_offload: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply bidirectional conditional control to both visual/audio towers.

        Args:
            layer_idx: current layer index
            visual_hidden_states: visual DiT hidden states
            audio_hidden_states: audio DiT hidden states
            x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs.
                If provided, x_freqs is assumed to correspond to the primary tower and y_freqs
                to the conditioning tower.
            a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale)
            v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale)
            condition_scale: fallback conditioning strength when per-direction scale is None
            video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled

        Returns:
            (visual_hidden_states, audio_hidden_states), both conditioned in their respective directions.
        """

        visual_conditioned = self.apply_conditional_control(
            layer_idx=layer_idx,
            direction="a2v",
            primary_hidden_states=visual_hidden_states,
            condition_hidden_states=audio_hidden_states,
            x_freqs=x_freqs,
            y_freqs=y_freqs,
            condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale,
            video_grid_size=video_grid_size,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
        )

        audio_conditioned = self.apply_conditional_control(
|
||||
layer_idx=layer_idx,
|
||||
direction="v2a",
|
||||
primary_hidden_states=audio_hidden_states,
|
||||
condition_hidden_states=visual_hidden_states,
|
||||
x_freqs=y_freqs,
|
||||
y_freqs=x_freqs,
|
||||
condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale,
|
||||
video_grid_size=video_grid_size,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
return visual_conditioned, audio_conditioned
|
||||
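The bridge applies cross-modal conditioning as a scaled residual: each tower's hidden states receive `scale * conditioner(primary, condition)` added back, once in the `a2v` direction and once in `v2a`. A minimal NumPy sketch of that residual update, with a toy stand-in conditioner (all names here are illustrative, not the repo's API):

```python
import numpy as np

def apply_conditional_control(primary, condition, conditioner, scale):
    # Residual cross-modal conditioning: primary + scale * f(primary, condition)
    return primary + scale * conditioner(primary, condition)

# Stand-in conditioner: mean-pool the condition sequence and broadcast it.
toy_conditioner = lambda p, c: np.broadcast_to(c.mean(axis=0, keepdims=True), p.shape)

visual = np.ones((4, 8))       # [L_video, dim]
audio = np.full((2, 8), 3.0)   # [L_audio, dim]

# a2v: audio conditions video; v2a: video conditions audio
visual_out = apply_conditional_control(visual, audio, toy_conditioner, scale=0.5)
audio_out = apply_conditional_control(audio, visual, toy_conditioner, scale=0.5)
print(visual_out[0, 0], audio_out[0, 0])  # 2.5 3.5
```

Each tower is updated from the *unmodified* hidden states of the other, so the two calls are symmetric and order-independent.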
@@ -99,18 +99,30 @@ def rope_apply(x, freqs, num_heads):
    return x_out.to(x.dtype)


def set_to_torch_norm(models):
    for model in models:
        for module in model.modules():
            if isinstance(module, RMSNorm):
                module.use_torch_norm = True


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.use_torch_norm = False
        self.normalized_shape = (dim,)

    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        dtype = x.dtype
        if self.use_torch_norm:
            return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
        else:
            return self.norm(x.float()).to(dtype) * self.weight

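RMSNorm divides each vector by the root-mean-square of its last dimension (no mean subtraction, unlike LayerNorm); the `use_torch_norm` flag merely switches between the hand-rolled float32 path and `F.rms_norm`. A small NumPy sketch of the formula (illustrative, not the repo's module):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-5):
    # x / sqrt(mean(x^2) + eps) * weight, computed over the last dimension
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * weight

x = np.array([[3.0, 4.0]])   # mean(x^2) = 12.5, rms ~= 3.5355
w = np.ones(2)
out = rms_norm(x, w)
print(np.round(out, 4))      # [[0.8485 1.1314]]
```

Upcasting to float32 before the reduction (as the hand-rolled path does) avoids precision loss when activations are bf16.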
class AttentionModule(nn.Module):

@@ -22,6 +22,7 @@ from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Voco
from ..models.ltx2_upsampler import LTX2LatentUpsampler
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
from ..utils.data.media_io_ltx2 import ltx2_preprocess
from ..utils.data.audio import convert_to_stereo


class LTX2AudioVideoPipeline(BasePipeline):
@@ -389,6 +390,7 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
            return {"audio_latents": audio_noise}
        else:
            input_audio, sample_rate = input_audio
            input_audio = convert_to_stereo(input_audio)
            pipe.load_models_to_device(self.onload_model_names)
            input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
            audio_input_latents = pipe.audio_vae_encoder(input_audio)
@@ -441,6 +443,7 @@ class LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit):
            return {}
        else:
            input_audio, sample_rate = retake_audio
            input_audio = convert_to_stereo(input_audio)
            pipe.load_models_to_device(self.onload_model_names)
            input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device)
            input_latents_audio = pipe.audio_vae_encoder(input_audio)
diffsynth/pipelines/mova_audio_video.py (new file, 460 lines)
@@ -0,0 +1,460 @@
import sys
import torch, types
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional, Union
from einops import rearrange

from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit

from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
from ..models.wan_video_vae import WanVideoVAE
from ..models.mova_audio_dit import MovaAudioDit
from ..models.mova_audio_vae import DacVAE
from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge
from ..utils.data.audio import convert_to_mono, resample_waveform


class MovaAudioVideoPipeline(BasePipeline):

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
        )
        self.scheduler = FlowMatchScheduler("Wan")
        self.tokenizer: HuggingfaceTokenizer = None
        self.text_encoder: WanTextEncoder = None
        self.video_dit: WanModel = None  # high-noise model
        self.video_dit2: WanModel = None  # low-noise model
        self.audio_dit: MovaAudioDit = None
        self.dual_tower_bridge: DualTowerConditionalBridge = None
        self.video_vae: WanVideoVAE = None
        self.audio_vae: DacVAE = None

        self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge")
        self.in_iteration_models_2 = ("video_dit2", "audio_dit", "dual_tower_bridge")

        self.units = [
            MovaAudioVideoUnit_ShapeChecker(),
            MovaAudioVideoUnit_NoiseInitializer(),
            MovaAudioVideoUnit_InputVideoEmbedder(),
            MovaAudioVideoUnit_InputAudioEmbedder(),
            MovaAudioVideoUnit_PromptEmbedder(),
            MovaAudioVideoUnit_ImageEmbedderVAE(),
            MovaAudioVideoUnit_UnifiedSequenceParallel(),
        ]
        self.model_fn = model_fn_mova_audio_video

    def enable_usp(self):
        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
        for block in self.video_dit.blocks + self.audio_dit.blocks + self.video_dit2.blocks:
            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
        use_usp: bool = False,
        vram_limit: float = None,
    ):
        if use_usp:
            from ..utils.xfuser import initialize_usp
            initialize_usp(device)
            import torch.distributed as dist
            from ..core.device.npu_compatible_device import get_device_name
            if dist.is_available() and dist.is_initialized():
                device = get_device_name()

        # Initialize pipeline
        pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
        dit = model_pool.fetch_model("wan_video_dit", index=2)
        if isinstance(dit, list):
            pipe.video_dit, pipe.video_dit2 = dit
        else:
            pipe.video_dit = dit
        pipe.audio_dit = model_pool.fetch_model("mova_audio_dit")
        pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge")
        pipe.video_vae = model_pool.fetch_model("wan_video_vae")
        pipe.audio_vae = model_pool.fetch_model("mova_audio_vae")
        set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))

        # Size division factor
        if pipe.video_vae is not None:
            pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2

        # Initialize tokenizer
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')

        # Unified Sequence Parallel
        if use_usp: pipe.enable_usp()

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Video-to-video
        denoising_strength: Optional[float] = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 352,
        width: Optional[int] = 640,
        num_frames: Optional[int] = 81,
        frame_rate: Optional[int] = 24,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.9,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Inputs
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "denoising_strength": denoising_strength,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
            "cfg_scale": cfg_scale,
            "sigma_shift": sigma_shift,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch to the low-noise DiT if necessary
            if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and models["video_dit"] is not self.video_dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["video_dit"] = self.video_dit2
            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Scheduler step
            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)

        # Decode
        self.load_models_to_device(['video_vae'])
        video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        video = self.vae_output_to_video(video)
        self.load_models_to_device(["audio_vae"])
        audio = self.audio_vae.decode(inputs_shared["audio_latents"])
        audio = self.output_audio_format_check(audio)
        self.load_models_to_device([])
        return video, audio
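The denoising loop starts on the high-noise DiT (`video_dit`) and permanently switches to the low-noise `video_dit2` once the timestep drops below `switch_DiT_boundary * 1000`. The selection rule can be sketched in plain Python (`pick_dit` is illustrative, not a repo function):

```python
def pick_dit(timestep, boundary=0.9, max_timestep=1000):
    # High-noise model handles timesteps >= boundary * max_timestep;
    # the low-noise model handles the rest of the schedule.
    return "video_dit" if timestep >= boundary * max_timestep else "video_dit2"

steps = [999, 950, 899, 500, 10]
choices = [pick_dit(t) for t in steps]
print(choices)  # ['video_dit', 'video_dit', 'video_dit2', 'video_dit2', 'video_dit2']
```

Since flow-matching timesteps decrease monotonically, the pipeline only needs to perform the swap once, which is why it also reloads the model group (`in_iteration_models_2`) at that point.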


class MovaAudioVideoUnit_ShapeChecker(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames"),
            output_params=("height", "width", "num_frames"),
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames):
        height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
        return {"height": height, "width": width, "num_frames": num_frames}


class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
            output_params=("video_noise", "audio_noise")
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate):
        length = (num_frames - 1) // 4 + 1
        video_shape = (1, pipe.video_vae.model.z_dim, length, height // pipe.video_vae.upsampling_factor, width // pipe.video_vae.upsampling_factor)
        video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device)

        audio_num_samples = (int(pipe.audio_vae.sample_rate * num_frames / frame_rate) - 1) // int(pipe.audio_vae.hop_length) + 1
        audio_shape = (1, pipe.audio_vae.latent_dim, audio_num_samples)
        audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device)
        return {"video_noise": video_noise, "audio_noise": audio_noise}
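The noise initializer derives both latent shapes from the same clip duration: the video VAE compresses time 4x while keeping the first frame, and the audio latent length comes from the waveform sample count strided by the audio VAE's hop length. The arithmetic can be checked in isolation (the 44100 Hz sample rate and 512 hop below are illustrative values, not necessarily the DAC VAE's actual configuration):

```python
def video_latent_frames(num_frames, temporal_factor=4):
    # Temporal compression keeps the first frame: 81 frames -> 21 latent frames.
    return (num_frames - 1) // temporal_factor + 1

def audio_latent_steps(num_frames, frame_rate, sample_rate, hop_length):
    # Waveform samples covering the clip's duration, strided by the audio VAE hop.
    num_samples = int(sample_rate * num_frames / frame_rate)
    return (num_samples - 1) // hop_length + 1

print(video_latent_frames(81))                   # 21
print(audio_latent_steps(81, 24, 44100, 512))    # 291
```

Computing the audio length from `num_frames / frame_rate` (rather than independently) keeps the two noise tensors aligned to the same wall-clock duration.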


class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"),
            output_params=("video_latents", "input_latents"),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):
        if input_video is None or not pipe.scheduler.training:
            return {"video_latents": video_noise}
        else:
            pipe.load_models_to_device(self.onload_model_names)
            input_video = pipe.preprocess_video(input_video)
            input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
            return {"input_latents": input_latents}


class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_audio", "audio_noise"),
            output_params=("audio_latents", "audio_input_latents"),
            onload_model_names=("audio_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):
        if input_audio is None or not pipe.scheduler.training:
            return {"audio_latents": audio_noise}
        else:
            pipe.load_models_to_device(self.onload_model_names)
            input_audio, sample_rate = input_audio
            input_audio = convert_to_mono(input_audio)
            input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate)
            input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate)
            z, _, _, _, _ = pipe.audio_vae.encode(input_audio)
            return {"audio_input_latents": z.mode()}


class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("context",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt):
        ids, mask = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            add_special_tokens=True,
            return_mask=True,
            return_tensors="pt",
        )
        ids = ids.to(pipe.device)
        mask = mask.to(pipe.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = pipe.text_encoder(ids, mask)
        # Zero out padding positions beyond each prompt's actual length
        for i, v in enumerate(seq_lens):
            prompt_emb[i, v:] = 0
        return prompt_emb

    def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict:
        pipe.load_models_to_device(self.onload_model_names)
        prompt_emb = self.encode_prompt(pipe, prompt)
        return {"context": prompt_emb}


class MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.video_dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)

        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0, 1)], dim=1)
            msk[:, -1:] = 1
        else:
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        y = pipe.video_vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"y": y}
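The mask manipulation in `MovaAudioVideoUnit_ImageEmbedderVAE` compresses a per-frame conditioning mask to the latent timeline: the first frame's mask is repeated 4x so the total length becomes divisible by the VAE's temporal factor, then time is folded into groups of four that become extra channels. A NumPy sketch with small illustrative dimensions:

```python
import numpy as np

num_frames, latent_h, latent_w = 9, 2, 2
msk = np.ones((1, num_frames, latent_h, latent_w))
msk[:, 1:] = 0  # only the first frame is conditioned

# Repeat the first frame's mask 4x (length 4 + 8 = 12, divisible by 4),
# then fold time into groups of 4 to match the VAE's temporal compression.
msk = np.concatenate([np.repeat(msk[:, 0:1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, latent_h, latent_w)
msk = msk.transpose(0, 2, 1, 3, 4)[0]
print(msk.shape)  # (4, 3, 2, 2): 4 mask channels x 3 latent frames
```

The resulting mask channels are concatenated with the VAE latents so the DiT can tell which latent frames carry real (conditioned) content.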


class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))

    def process(self, pipe: MovaAudioVideoPipeline):
        if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel:
            return {"use_unified_sequence_parallel": True}
        return {"use_unified_sequence_parallel": False}


def model_fn_mova_audio_video(
    video_dit: WanModel,
    audio_dit: MovaAudioDit,
    dual_tower_bridge: DualTowerConditionalBridge,
    video_latents: torch.Tensor = None,
    audio_latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    y: Optional[torch.Tensor] = None,
    frame_rate: Optional[int] = 24,
    use_unified_sequence_parallel: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    **kwargs,
):
    video_x, audio_x = video_latents, audio_latents
    # First/last-frame conditioning
    if y is not None:
        video_x = torch.cat([video_x, y], dim=1)

    # Timestep
    video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep))
    video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim))
    audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep))
    audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim))

    # Context
    video_context = video_dit.text_embedding(context)
    audio_context = audio_dit.text_embedding(context)

    # Patchify
    video_x = video_dit.patch_embedding(video_x)
    f_v, h, w = video_x.shape[2:]
    video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous()
    seq_len_video = video_x.shape[1]

    audio_x = audio_dit.patch_embedding(audio_x)
    f_a = audio_x.shape[2]
    audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous()
    seq_len_audio = audio_x.shape[1]

    # RoPE frequencies
    video_freqs = torch.cat([
        video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1)
    ], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device)
    audio_freqs = torch.cat([
        audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1),
        audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1),
        audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1),
    ], dim=-1).reshape(f_a, 1, -1).to(audio_x.device)

    video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs(
        video_fps=frame_rate,
        grid_size=(f_v, h, w),
        audio_steps=audio_x.shape[1],
        device=video_x.device,
        dtype=video_x.dtype,
    )

    # Unified sequence parallel helpers (identity functions when USP is disabled)
    if use_unified_sequence_parallel:
        from ..utils.xfuser import get_current_chunk, gather_all_chunks
    else:
        get_current_chunk = lambda x, dim=1: x
        gather_all_chunks = lambda x, seq_len, dim=1: x

    # Forward through paired blocks with cross-modal bridging
    for block_id in range(len(audio_dit.blocks)):
        if dual_tower_bridge.should_interact(block_id, "a2v"):
            video_x, audio_x = dual_tower_bridge(
                block_id,
                video_x,
                audio_x,
                x_freqs=video_rope,
                y_freqs=audio_rope,
                condition_scale=1.0,
                video_grid_size=(f_v, h, w),
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            )
        video_x = get_current_chunk(video_x, dim=1)
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
        video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
        audio_x = get_current_chunk(audio_x, dim=1)
        audio_x = gradient_checkpoint_forward(
            audio_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            audio_x, audio_context, audio_t_mod, audio_freqs
        )
        audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1)

    # Remaining video-only blocks
    video_x = get_current_chunk(video_x, dim=1)
    for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)):
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
    video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)

    # Head
    video_x = video_dit.head(video_x, video_t)
    video_x = video_dit.unpatchify(video_x, (f_v, h, w))

    audio_x = audio_dit.head(audio_x, audio_t)
    audio_x = audio_dit.unpatchify(audio_x, (f_a,))
    return video_x, audio_x
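Both towers condition their AdaLN modulation on `sinusoidal_embedding_1d` of the timestep. The standard construction (half sine, half cosine, geometrically spaced frequencies) can be sketched in NumPy; the exact frequency schedule in `wan_video_dit` may differ, so treat this as illustrative:

```python
import numpy as np

def sinusoidal_embedding_1d(dim, position):
    # Standard transformer sinusoidal embedding: half sin terms, half cos terms.
    half = dim // 2
    freqs = np.power(10000.0, -np.arange(half) / half)
    args = position[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

emb = sinusoidal_embedding_1d(8, np.array([0.0, 500.0]))
print(emb.shape)  # (2, 8)
print(emb[0])     # position 0 -> all sin terms are 0, all cos terms are 1
```

The embedding is then passed through `time_embedding` / `time_projection` MLPs to produce the six per-block modulation vectors (`t_mod`).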
diffsynth/utils/data/audio.py (new file, 108 lines)
@@ -0,0 +1,108 @@
import torch
import torchaudio
from torchcodec.decoders import AudioDecoder
from torchcodec.encoders import AudioEncoder


def convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert audio to mono by averaging channels.
    Supports [C, T] or [B, C, T]. Output shape: [1, T] or [B, 1, T].
    """
    return audio_tensor.mean(dim=-2, keepdim=True)


def convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert audio to stereo.
    Supports [C, T] or [B, C, T]. Mono input is duplicated; stereo input is returned unchanged.
    """
    if audio_tensor.size(-2) == 1:
        return audio_tensor.repeat(1, 2, 1) if audio_tensor.dim() == 3 else audio_tensor.repeat(2, 1)
    return audio_tensor
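Both helpers operate on the second-to-last (channel) axis, which is what lets them accept either `[C, T]` or `[B, C, T]` layouts. A NumPy sketch of the same two operations (illustrative; the repo versions above are torch):

```python
import numpy as np

def convert_to_mono(x):
    # Average over the channel axis (second-to-last), keeping the dimension.
    return x.mean(axis=-2, keepdims=True)

def convert_to_stereo(x):
    # Duplicate a single channel; leave multi-channel audio untouched.
    return np.repeat(x, 2, axis=-2) if x.shape[-2] == 1 else x

stereo = np.array([[1.0, 1.0], [3.0, 3.0]])  # [C=2, T=2]
mono = convert_to_mono(stereo)
print(mono)                            # [[2. 2.]]
print(convert_to_stereo(mono).shape)   # (2, 2)
```

MOVA's mono DAC VAE uses `convert_to_mono`, while the LTX-2 pipeline (which expects stereo mels) uses `convert_to_stereo` — this is exactly the "fix bug of mono audio input for ltx" item in the commit message.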
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
    """Resample a waveform to the target sample rate if needed."""
    if source_rate == target_rate:
        return waveform
    resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
    return resampled.to(dtype=waveform.dtype)


def read_audio_with_torchcodec(
    path: str,
    start_time: float = 0,
    duration: float | None = None,
) -> tuple[torch.Tensor, int]:
    """
    Read audio from a file using torchcodec, with optional start time and duration.

    Args:
        path (str): The file path to the audio file.
        start_time (float, optional): The start time in seconds to read from. Defaults to 0.
        duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.

    Returns:
        tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
            The audio tensor shape is [C, T], where C is the number of channels and T is the number of audio frames.
    """
    decoder = AudioDecoder(path)
    stop_seconds = None if duration is None else start_time + duration
    waveform = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds).data
    return waveform, decoder.metadata.sample_rate


def read_audio(
    path: str,
    start_time: float = 0,
    duration: float | None = None,
    resample: bool = False,
    resample_rate: int = 48000,
    backend: str = "torchcodec",
) -> tuple[torch.Tensor, int]:
    """
    Read audio from a file, with optional start time, duration, and resampling.

    Args:
        path (str): The file path to the audio file.
        start_time (float, optional): The start time in seconds to read from. Defaults to 0.
        duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
        resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False.
        resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000.
        backend (str, optional): The audio backend to use for reading. Defaults to "torchcodec".

    Returns:
        tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
            The audio tensor shape is [C, T], where C is the number of channels and T is the number of audio frames.
    """
    if backend == "torchcodec":
        waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration)
    else:
        raise ValueError(f"Unsupported audio backend: {backend}")

    if resample:
        waveform = resample_waveform(waveform, sample_rate, resample_rate)
        sample_rate = resample_rate

    return waveform, sample_rate


def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = "torchcodec"):
    """
    Save an audio tensor to a file.

    Args:
        waveform (torch.Tensor): The audio tensor to save. Shape can be [C, T] or [B, C, T].
        sample_rate (int): The sample rate of the audio.
        save_path (str): The file path to save the audio to.
        backend (str, optional): The audio backend to use for saving. Defaults to "torchcodec".
    """
    if waveform.dim() == 3:
        waveform = waveform[0]

    if backend == "torchcodec":
        encoder = AudioEncoder(waveform, sample_rate=sample_rate)
        encoder.to_file(dest=save_path)
    else:
        raise ValueError(f"Unsupported audio backend: {backend}")

diffsynth/utils/data/audio_video.py (new file, 134 lines)
@@ -0,0 +1,134 @@
import av
from fractions import Fraction
import torch
from PIL import Image
from tqdm import tqdm
from .audio import convert_to_stereo


def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC usually expects fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
            audio_next_pts += rframe.samples
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # Flush the audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)


def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    if samples.ndim == 1:
        samples = samples.unsqueeze(0)
    samples = convert_to_stereo(samples)
    assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S]; C must be 1 or 2"
    samples = samples.T
    # Convert to packed int16 for ingestion; the resampler converts to the encoder format.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
|
||||
"""
|
||||
Prepare the audio stream for writing.
|
||||
"""
|
||||
audio_stream = container.add_stream("aac")
|
||||
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
|
||||
if supported_sample_rates:
|
||||
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
|
||||
if best_rate != audio_sample_rate:
|
||||
print(f"Using closest supported audio sample rate: {best_rate}")
|
||||
else:
|
||||
best_rate = audio_sample_rate
|
||||
audio_stream.codec_context.sample_rate = best_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, best_rate)
|
||||
return audio_stream
|
||||
|
||||
|
||||
def write_video_audio(
|
||||
video: list[Image.Image],
|
||||
audio: torch.Tensor | None,
|
||||
output_path: str,
|
||||
fps: int = 24,
|
||||
audio_sample_rate: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Writes a sequence of images and an audio tensor to a video file.
|
||||
|
||||
This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream
|
||||
and multiplex a PyTorch tensor as the audio stream into the output container.
|
||||
|
||||
Args:
|
||||
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
|
||||
The length of this list determines the total duration of the video based on the FPS.
|
||||
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
|
||||
The shape is typically (channels, samples). If no audio is required, pass None.
|
||||
channels can be 1 or 2. 1 for mono, 2 for stereo.
|
||||
output_path (str): The file path (including extension) where the output video will be saved.
|
||||
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
|
||||
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
|
||||
If the audio tensor is provided and this is None, the function attempts to infer the rate
|
||||
based on the audio tensor's length and the video duration.
|
||||
Raises:
|
||||
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
|
||||
"""
|
||||
duration = len(video) / fps
|
||||
if audio_sample_rate is None:
|
||||
audio_sample_rate = int(audio.shape[-1] / duration)
|
||||
|
||||
width, height = video[0].size
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
if audio is not None:
|
||||
if audio_sample_rate is None:
|
||||
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame in tqdm(video, total=len(video)):
|
||||
frame = av.VideoFrame.from_image(frame)
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
if audio is not None:
|
||||
_write_audio(container, audio_stream, audio, audio_sample_rate)
|
||||
|
||||
container.close()
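As a quick, PyAV-free illustration of the clip-and-scale step in `_write_audio` above, here is a pure-Python sketch; the real code applies the same logic to a `torch.Tensor` with `torch.clip` and `* 32767.0`:

```python
def float_to_int16(samples):
    # Clip floats to [-1, 1], then scale into the signed 16-bit range,
    # mirroring the clip/scale step in _write_audio
    return [int(max(-1.0, min(1.0, s)) * 32767) for s in samples]

print(float_to_int16([0.0, 0.5, 1.0, -2.0]))  # [0, 16383, 32767, -32767]
```

Out-of-range values are clipped rather than wrapped, which is why the -2.0 sample maps to -32767 instead of overflowing.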
@@ -1,166 +1,7 @@
from fractions import Fraction
import torch
import torchaudio
import av
from tqdm import tqdm
from PIL import Image
import numpy as np
from io import BytesIO


def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
            audio_next_pts += rframe.samples
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)


def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    if samples.ndim == 1:
        samples = samples[:, None]
    if samples.shape[0] == 1:
        samples = samples.repeat(2, 1)
    assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
    samples = samples.T
    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)


def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """
    Prepare the audio stream for writing.
    """
    audio_stream = container.add_stream("aac")
    supported_sample_rates = audio_stream.codec_context.codec.audio_rates
    if supported_sample_rates:
        best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
        if best_rate != audio_sample_rate:
            print(f"Using closest supported audio sample rate: {best_rate}")
    else:
        best_rate = audio_sample_rate
    audio_stream.codec_context.sample_rate = best_rate
    audio_stream.codec_context.layout = "stereo"
    audio_stream.codec_context.time_base = Fraction(1, best_rate)
    return audio_stream


def write_video_audio_ltx2(
    video: list[Image.Image],
    audio: torch.Tensor | None,
    output_path: str,
    fps: int = 24,
    audio_sample_rate: int | None = None,
) -> None:
    """
    Writes a sequence of images and an audio tensor to a video file.

    This function utilizes PyAV to encode a list of PIL images into a video stream
    and multiplex a PyTorch tensor as the audio stream into the output container.

    Args:
        video (list[Image.Image]): A list of PIL Image objects representing the video frames.
            The length of this list determines the total duration of the video based on the FPS.
        audio (torch.Tensor | None): The audio data as a PyTorch tensor.
            The shape is typically (channels, samples). If no audio is required, pass None.
            channels can be 1 or 2: 1 for mono, 2 for stereo.
        output_path (str): The file path (including extension) where the output video will be saved.
        fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
        audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
            If the audio tensor is provided and this is None, the function attempts to infer the rate
            based on the audio tensor's length and the video duration.

    Raises:
        ValueError: If an audio tensor is provided but the sample rate cannot be determined.
    """
    duration = len(video) / fps
    if audio_sample_rate is None:
        audio_sample_rate = int(audio.shape[-1] / duration)

    width, height = video[0].size
    container = av.open(output_path, mode="w")
    stream = container.add_stream("libx264", rate=int(fps))
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"

    if audio is not None:
        if audio_sample_rate is None:
            raise ValueError("audio_sample_rate is required when audio is provided")
        audio_stream = _prepare_audio_stream(container, audio_sample_rate)

    for frame in tqdm(video, total=len(video)):
        frame = av.VideoFrame.from_image(frame)
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush encoder
    for packet in stream.encode():
        container.mux(packet)

    if audio is not None:
        _write_audio(container, audio_stream, audio, audio_sample_rate)

    container.close()


def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
    """Resample waveform to target sample rate if needed."""
    if source_rate == target_rate:
        return waveform
    resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
    return resampled.to(dtype=waveform.dtype)


def read_audio_with_torchaudio(
    path: str,
    start_time: float = 0,
    duration: float | None = None,
    resample: bool = False,
    resample_rate: int = 48000,
) -> tuple[torch.Tensor, int]:
    waveform, sample_rate = torchaudio.load(path, channels_first=True)
    if resample:
        waveform = resample_waveform(waveform, sample_rate, resample_rate)
        sample_rate = resample_rate
    start_frame = int(start_time * sample_rate)
    if start_frame > waveform.shape[-1]:
        raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
    end_frame = None if duration is None else int(duration * sample_rate + start_frame)
    return waveform[..., start_frame:end_frame], sample_rate
from .audio_video import write_video_audio as write_video_audio_ltx2


def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
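When `audio_sample_rate` is omitted, the unified `write_video_audio` infers it from the audio length and the video duration; the arithmetic reduces to this small sketch:

```python
def infer_sample_rate(num_audio_samples, num_frames, fps):
    # duration (seconds) = num_frames / fps; rate = samples per second
    duration = num_frames / fps
    return int(num_audio_samples / duration)

print(infer_sample_rate(48000, 24, 24))    # 1 s of video -> 48000
print(infer_sample_rate(242000, 121, 24))  # 48000
```

Because of the `int()` truncation, audio whose length is not an exact multiple of the duration yields a slightly low rate, so passing the true `audio_sample_rate` explicitly is safer.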
@@ -1 +1 @@
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
@@ -143,4 +143,31 @@ def usp_attn_forward(self, x, freqs):

    del q, k, v
    getattr(torch, parse_device_type(x.device)).empty_cache()
    return self.o(x)
    return self.o(x)


def get_current_chunk(x, dim=1):
    chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim)
    ndims = len(chunks[0].shape)
    pad_list = [0] * (2 * ndims)
    pad_end_index = 2 * (ndims - 1 - dim) + 1
    max_size = chunks[0].size(dim)
    chunks = [
        torch.nn.functional.pad(
            chunk,
            tuple(pad_list[:pad_end_index] + [max_size - chunk.size(dim)] + pad_list[pad_end_index+1:]),
            value=0,
        )
        for chunk in chunks
    ]
    x = chunks[get_sequence_parallel_rank()]
    return x


def gather_all_chunks(x, seq_len=None, dim=1):
    x = get_sp_group().all_gather(x, dim=dim)
    if seq_len is not None:
        slices = [slice(None)] * x.ndim
        slices[dim] = slice(0, seq_len)
        x = x[tuple(slices)]
    return x
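`get_current_chunk` and `gather_all_chunks` pad uneven chunks so every rank holds the same shape, then trim the padding after the all-gather. A minimal pure-Python analogue of that round trip (lists stand in for tensors and for the collective all-gather):

```python
def chunk_with_pad(seq, world_size):
    # Split like torch.chunk (ceil-sized leading chunks), then zero-pad
    # shorter chunks to the first chunk's length, as get_current_chunk does
    base = -(-len(seq) // world_size)  # ceil division
    chunks = [seq[i * base:(i + 1) * base] for i in range(world_size)]
    max_size = len(chunks[0])
    return [c + [0] * (max_size - len(c)) for c in chunks]

def gather_and_trim(chunks, seq_len):
    # Concatenate every rank's chunk, then slice off the padding,
    # as gather_all_chunks does with seq_len
    flat = [x for c in chunks for x in c]
    return flat[:seq_len]

seq = [1, 2, 3, 4, 5]
chunks = chunk_with_pad(seq, 2)           # [[1, 2, 3], [4, 5, 0]]
print(gather_and_trim(chunks, len(seq)))  # [1, 2, 3, 4, 5]
```

The trim with `seq_len` is what makes the padding invisible to the caller: the gathered sequence is exact even when the sequence length does not divide the world size.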
@@ -137,6 +137,8 @@ graph LR;
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |

* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
@@ -138,6 +138,8 @@ graph LR;
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |

* FP8 precision training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage split training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
@@ -1,6 +1,7 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from diffsynth.utils.data.audio import read_audio
from modelscope import dataset_snapshot_download

vram_config = {

@@ -42,7 +43,7 @@ negative_prompt = (
)
height, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24
duration = num_frames / frame_rate
audio, audio_sample_rate = read_audio_with_torchaudio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
audio, audio_sample_rate = read_audio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
@@ -1,6 +1,7 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from diffsynth.utils.data.audio import read_audio
from modelscope import dataset_snapshot_download
from diffsynth.utils.data import VideoData

@@ -47,7 +48,7 @@ path = "data/example_video_dataset/ltx2/video2.mp4"
video = VideoData(path, height=height, width=width).raw_data()[:num_frames]
assert len(video) == num_frames, f"Input video has {len(video)} frames, but expected {num_frames} frames based on the specified num_frames argument."
duration = num_frames / frame_rate
audio, audio_sample_rate = read_audio_with_torchaudio(path)
audio, audio_sample_rate = read_audio(path)

# Regenerate the video within time regions. You can specify different time regions for video frame and audio retakes.
# Retake regions are given in seconds; the example below retakes video frames in the regions [1s, 2s] and [3s, 4s], and audio in [0s, 1s] and [4s, 5s].
@@ -1,6 +1,7 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from diffsynth.utils.data.audio import read_audio
from modelscope import dataset_snapshot_download

vram_config = {

@@ -43,7 +44,7 @@ negative_prompt = (
)
height, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24
duration = num_frames / frame_rate
audio, audio_sample_rate = read_audio_with_torchaudio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
audio, audio_sample_rate = read_audio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration)
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
@@ -1,6 +1,7 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from diffsynth.utils.data.audio import read_audio
from modelscope import dataset_snapshot_download
from diffsynth.utils.data import VideoData

@@ -48,7 +49,7 @@ path = "data/example_video_dataset/ltx2/video2.mp4"
video = VideoData(path, height=height, width=width).raw_data()[:num_frames]
assert len(video) == num_frames, f"Input video has {len(video)} frames, but expected {num_frames} frames based on the specified num_frames argument."
duration = num_frames / frame_rate
audio, audio_sample_rate = read_audio_with_torchaudio(path)
audio, audio_sample_rate = read_audio(path)

# Regenerate the video within time regions. You can specify different time regions for video frame and audio retakes.
# Retake regions are given in seconds; the example below retakes video frames in the regions [1s, 2s] and [3s, 4s], and audio in [0s, 1s] and [4s, 5s].
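The retake regions in the comment above are given in seconds; converting them to frame indices at the pipeline's frame rate is a one-liner. This helper is hypothetical, written only to illustrate the conversion, and is not part of the DiffSynth API:

```python
def seconds_to_frame_ranges(regions, fps):
    # [(start_s, end_s), ...] -> [(start_frame, end_frame), ...]
    return [(int(s * fps), int(e * fps)) for s, e in regions]

print(seconds_to_frame_ranges([(1, 2), (3, 4)], 24))  # [(24, 48), (72, 96)]
```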
examples/mova/acceleration/unified_sequence_parallel.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
import torch.distributed as dist

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    use_usp=True,
    model_configs=[
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)

prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
if dist.get_rank() == 0:
    write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
examples/mova/model_inference/MOVA-360p-I2AV.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import torch
from PIL import Image
from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline
from diffsynth.utils.data.audio_video import write_video_audio

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)

prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
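In the example above, 121 frames at 24 fps is about five seconds of video, and the audio tensor returned by the pipeline must span the same duration at the audio VAE's sample rate. The 48000 below is an assumed rate for illustration; the real value comes from `pipe.audio_vae.sample_rate`:

```python
def expected_audio_samples(num_frames, fps, sample_rate):
    # Number of audio samples that exactly covers the video duration
    return round(num_frames / fps * sample_rate)

print(expected_audio_samples(121, 24, 48000))  # 242000
```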
examples/mova/model_inference/MOVA-720p-I2AV.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)

negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)
prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 720, 1280, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-720p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
||||
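Across the inference examples above, `num_frames` is always 121 at a frame rate of 24. A quick sanity check of those numbers (assuming the Wan-style causal video VAE used here compresses time 4x and therefore expects frame counts of the form 4k+1 — an assumption based on the Wan2.1 VAE, not stated in this diff):

```python
def check_frame_count(num_frames: int, frame_rate: float) -> float:
    # Assumption: a Wan-style causal video VAE expects 4k+1 frames
    # (temporal compression factor 4, remainder 1).
    assert (num_frames - 1) % 4 == 0, "num_frames should be of the form 4k+1"
    # Return the clip duration in seconds.
    return num_frames / frame_rate

duration = check_frame_count(121, 24)  # 121 frames at 24 fps ≈ 5.04 s of audio+video
```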
53
examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py
Normal file
@@ -0,0 +1,53 @@
import torch
from PIL import Image
from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline
from diffsynth.utils.data.audio_video import write_video_audio

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)

prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
53
examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py
Normal file
@@ -0,0 +1,53 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)

negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)
prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other."
height, width, num_frames = 720, 1280, 121
frame_rate = 24
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB")
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-720p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
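The low-VRAM variants differ from the plain inference scripts mainly by passing `vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2`. `torch.cuda.mem_get_info` returns `(free_bytes, total_bytes)`, so index `[1]` is the card's total memory; the expression converts it to GiB and reserves a fixed 2 GiB of headroom. Factoring the arithmetic out (a sketch that can be checked without a GPU; the function name is ours, not DiffSynth's):

```python
def vram_limit_gib(total_bytes: int, headroom_gib: float = 2.0) -> float:
    # Mirrors the expression in the low-VRAM scripts:
    # total device memory in GiB, minus a fixed headroom.
    return total_bytes / (1024 ** 3) - headroom_gib

# e.g. a 24 GiB card leaves a 22 GiB budget for model weights
limit = vram_limit_gib(24 * 1024 ** 3)  # -> 22.0
```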
39
examples/mova/model_training/full/MOVA-360P-I2AV.sh
Normal file
@@ -0,0 +1,39 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
  --dataset_base_path data/example_video_dataset/ltx2 \
  --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
  --data_file_keys "video,input_audio" \
  --extra_inputs "input_audio,input_image" \
  --height 352 \
  --width 640 \
  --num_frames 121 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.video_dit." \
  --output_path "./models/train/MOVA-360p-I2AV_high_noise_full" \
  --trainable_models "dit" \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0 \
  --use_gradient_checkpointing
# boundary corresponds to timesteps [900, 1000]

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
  --dataset_base_path data/example_video_dataset/ltx2 \
  --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
  --data_file_keys "video,input_audio" \
  --extra_inputs "input_audio,input_image" \
  --height 352 \
  --width 640 \
  --num_frames 121 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.video_dit." \
  --output_path "./models/train/MOVA-360p-I2AV_low_noise_full" \
  --trainable_models "dit" \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358 \
  --use_gradient_checkpointing
# boundary corresponds to timesteps [0, 900)
39
examples/mova/model_training/full/MOVA-720P-I2AV.sh
Normal file
@@ -0,0 +1,39 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
  --dataset_base_path data/example_video_dataset/ltx2 \
  --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
  --data_file_keys "video,input_audio" \
  --extra_inputs "input_audio,input_image" \
  --height 720 \
  --width 1280 \
  --num_frames 121 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.video_dit." \
  --output_path "./models/train/MOVA-720p-I2AV_high_noise_full" \
  --trainable_models "dit" \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0 \
  --use_gradient_checkpointing
# boundary corresponds to timesteps [900, 1000]

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
  --dataset_base_path data/example_video_dataset/ltx2 \
  --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
  --data_file_keys "video,input_audio" \
  --extra_inputs "input_audio,input_image" \
  --height 720 \
  --width 1280 \
  --num_frames 121 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.video_dit." \
  --output_path "./models/train/MOVA-720p-I2AV_low_noise_full" \
  --trainable_models "dit" \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358 \
  --use_gradient_checkpointing
# boundary corresponds to timesteps [0, 900)
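Each training script runs two passes: one trains the high-noise expert (`video_dit`) on the slice of the schedule bounded by `--max_timestep_boundary 0.358`, the other trains the low-noise expert (`video_dit_2`) on the remainder; per the script comments, the 0.358 boundary corresponds to timesteps [900, 1000] under the pipeline's shifted noise schedule. A minimal sketch of the routing idea, under the assumption that the boundary is a fraction of the denoising schedule position (0 = pure noise, 1 = clean data) — the actual sampling lives inside the training framework:

```python
def assign_expert(step_fraction: float, boundary: float = 0.358) -> str:
    # step_fraction: position along the denoising schedule,
    # 0.0 = pure noise (timestep ~1000), 1.0 = clean data (timestep 0).
    # Per the script comments, [0, 0.358) covers timesteps [900, 1000] and
    # is handled by the high-noise model; the rest by the low-noise model.
    assert 0.0 <= step_fraction <= 1.0
    return "high_noise" if step_fraction < boundary else "low_noise"
```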
43
examples/mova/model_training/lora/MOVA-360P-I2AV.sh
Normal file
@@ -0,0 +1,43 @@
accelerate launch examples/mova/model_training/train.py \
  --dataset_base_path data/example_video_dataset/ltx2 \
  --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
  --data_file_keys "video,input_audio" \
  --extra_inputs "input_audio,input_image" \
  --height 352 \
  --width 640 \
  --num_frames 121 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.video_dit." \
  --output_path "./models/train/MOVA-360p-I2AV_high_noise_lora" \
  --lora_base_model "video_dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0 \
  --use_gradient_checkpointing
# boundary corresponds to timesteps [900, 1000]

# accelerate launch examples/mova/model_training/train.py \
#   --dataset_base_path data/example_video_dataset/ltx2 \
#   --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
#   --data_file_keys "video,input_audio" \
#   --extra_inputs "input_audio,input_image" \
#   --height 352 \
#   --width 640 \
#   --num_frames 121 \
#   --dataset_repeat 100 \
#   --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
#   --learning_rate 1e-4 \
#   --num_epochs 5 \
#   --remove_prefix_in_ckpt "pipe.video_dit." \
#   --output_path "./models/train/MOVA-360p-I2AV_low_noise_lora" \
#   --lora_base_model "video_dit" \
#   --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
#   --lora_rank 32 \
#   --max_timestep_boundary 1 \
#   --min_timestep_boundary 0.358 \
#   --use_gradient_checkpointing
# boundary corresponds to timesteps [0, 900)
43
examples/mova/model_training/lora/MOVA-720P-I2AV.sh
Normal file
@@ -0,0 +1,43 @@
accelerate launch examples/mova/model_training/train.py \
  --dataset_base_path data/example_video_dataset/ltx2 \
  --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
  --data_file_keys "video,input_audio" \
  --extra_inputs "input_audio,input_image" \
  --height 720 \
  --width 1280 \
  --num_frames 121 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.video_dit." \
  --output_path "./models/train/MOVA-720p-I2AV_high_noise_lora" \
  --lora_base_model "video_dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0 \
  --use_gradient_checkpointing
# boundary corresponds to timesteps [900, 1000]

accelerate launch examples/mova/model_training/train.py \
  --dataset_base_path data/example_video_dataset/ltx2 \
  --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
  --data_file_keys "video,input_audio" \
  --extra_inputs "input_audio,input_image" \
  --height 720 \
  --width 1280 \
  --num_frames 121 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.video_dit." \
  --output_path "./models/train/MOVA-720p-I2AV_low_noise_lora" \
  --lora_base_model "video_dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358 \
  --use_gradient_checkpointing
# boundary corresponds to timesteps [0, 900)
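The LoRA scripts pass `--lora_target_modules "q,k,v,o,ffn.0,ffn.2"`, i.e. the attention projections and the two FFN linears of the video DiT. A rough illustration of suffix-based name matching (the target names come from the scripts; the helper itself is hypothetical, not DiffSynth's actual matcher):

```python
def match_lora_targets(module_names, targets=("q", "k", "v", "o", "ffn.0", "ffn.2")):
    # Keep modules whose dotted name ends with one of the target patterns,
    # e.g. "blocks.0.self_attn.q" matches target "q".
    return [name for name in module_names
            if any(name == t or name.endswith("." + t) for t in targets)]

names = ["blocks.0.self_attn.q", "blocks.0.self_attn.norm_q",
         "blocks.0.ffn.0", "blocks.0.ffn.1"]
# -> ["blocks.0.self_attn.q", "blocks.0.ffn.0"]
```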
193
examples/mova/model_training/train.py
Normal file
@@ -0,0 +1,193 @@
import torch, os, argparse, accelerate, warnings
from diffsynth.core import UnifiedDataset
from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath, RouteByType, SequencialProcess
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class MOVATrainingModule(DiffusionTrainingModule):
    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        tokenizer_path=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        fp8_models=None,
        offload_models=None,
        device="cpu",
        task="sft",
        max_timestep_boundary=1.0,
        min_timestep_boundary=0.0,
    ):
        super().__init__()
        # Warning
        if not use_gradient_checkpointing:
            warnings.warn("Gradient checkpointing is detected as disabled. To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.")
            use_gradient_checkpointing = True

        # Load models
        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
        tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path)
        self.pipe = MovaAudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
        self.pipe = self.split_pipeline_units(
            task, self.pipe, trainable_models, lora_base_model,
            remove_unnecessary_params=True,
            force_remove_params_shared=("audio_latents", "video_latents"),
            force_remove_params_nega=("audio_context", "video_context")
        )
        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
            preset_lora_path, preset_lora_model,
            task=task,
        )

        # Store other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.fp8_models = fp8_models
        self.task = task
        self.task_to_loss = {
            "sft:data_process": lambda pipe, *args: args,
            "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),
            "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),
        }
        self.max_timestep_boundary = max_timestep_boundary
        self.min_timestep_boundary = min_timestep_boundary

    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        for extra_input in extra_inputs:
            if extra_input == "input_image":
                inputs_shared["input_image"] = data["video"][0]
            else:
                inputs_shared[extra_input] = data[extra_input]
        return inputs_shared

    def get_pipeline_inputs(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}
        inputs_shared = {
            # If you use this pipeline for inference,
            # please fill in the input parameters.
            "input_video": data["video"],
            "height": data["video"][0].size[1],
            "width": data["video"][0].size[0],
            "num_frames": len(data["video"]),
            "frame_rate": data.get("frame_rate", 24),
            # Please do not modify the following parameters
            # unless you clearly know what this will cause.
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "max_timestep_boundary": self.max_timestep_boundary,
            "min_timestep_boundary": self.min_timestep_boundary,
        }
        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
        return inputs_shared, inputs_posi, inputs_nega

    def forward(self, data, inputs=None):
        if inputs is None: inputs = self.get_pipeline_inputs(data)
        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
        for unit in self.pipe.units:
            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
        loss = self.task_to_loss[self.task](self.pipe, *inputs)
        return loss


def mova_parser():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser = add_general_config(parser)
    parser = add_video_size_config(parser)
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
    parser.add_argument("--frame_rate", type=float, default=24, help="Frame rate of the training videos. MOVA is trained with a frame rate of 24, so it is recommended to use the same frame rate.")
    parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
    parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
    parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
    return parser


if __name__ == "__main__":
    parser = mova_parser()
    args = parser.parse_args()
    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
    )
    model = MOVATrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        tokenizer_path=args.tokenizer_path,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        preset_lora_path=args.preset_lora_path,
        preset_lora_model=args.preset_lora_model,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        fp8_models=args.fp8_models,
        offload_models=args.offload_models,
        task=args.task,
        device="cpu" if args.initialize_model_on_cpu else accelerator.device,
        max_timestep_boundary=args.max_timestep_boundary,
        min_timestep_boundary=args.min_timestep_boundary,
    )
    video_processor = UnifiedDataset.default_video_operator(
        base_path=args.dataset_base_path,
        max_pixels=args.max_pixels,
        height=args.height,
        width=args.width,
        height_division_factor=model.pipe.height_division_factor,
        width_division_factor=model.pipe.width_division_factor,
        num_frames=args.num_frames,
        time_division_factor=model.pipe.time_division_factor,
        time_division_remainder=model.pipe.time_division_remainder,
        frame_rate=args.frame_rate,
        fix_frame_rate=True,
    )
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=args.data_file_keys.split(","),
        main_data_operator=video_processor,
        special_operator_map={
            "input_audio":
                ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(
                    num_frames=args.num_frames,
                    time_division_factor=model.pipe.time_division_factor,
                    time_division_remainder=model.pipe.time_division_remainder,
                    frame_rate=args.frame_rate,
                ),
            "in_context_videos":
                RouteByType(operator_map=[
                    (str, video_processor),
                    (list, SequencialProcess(video_processor)),
                ]),
        },
    )

    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
    )
    launcher_map = {
        "sft:data_process": launch_data_process_task,
        "direct_distill:data_process": launch_data_process_task,
        "sft": launch_training_task,
        "sft:train": launch_training_task,
        "direct_distill": launch_training_task,
        "direct_distill:train": launch_training_task,
    }
    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
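In `train.py`, `parse_extra_inputs` is what makes `--extra_inputs "input_audio,input_image"` work for I2AV training: `input_image` is not a separate dataset column but is derived as the first frame of the ground-truth video, while every other extra input is copied through from the sample. The routing can be replicated in isolation:

```python
def parse_extra_inputs(data, extra_inputs, inputs_shared):
    # Same logic as MOVATrainingModule.parse_extra_inputs: "input_image"
    # is special-cased to the first video frame; everything else is
    # copied through from the dataset sample.
    for extra_input in extra_inputs:
        if extra_input == "input_image":
            inputs_shared["input_image"] = data["video"][0]
        else:
            inputs_shared[extra_input] = data[extra_input]
    return inputs_shared

sample = {"video": ["frame0", "frame1"], "input_audio": "waveform"}
out = parse_extra_inputs(sample, ["input_audio", "input_image"], {})
# -> {"input_audio": "waveform", "input_image": "frame0"}
```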
53
examples/mova/model_training/validate_full/MOVA-360p-I2AV.py
Normal file
@@ -0,0 +1,53 @@
import torch
from PIL import Image
from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.utils.data import VideoData


vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(path="./models/train/MOVA-360p-I2AV_high_noise_full/epoch-4.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)
prompt = "A beautiful sunset over the ocean."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-360p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
54
examples/mova/model_training/validate_full/MOVA-720p-I2AV.py
Normal file
@@ -0,0 +1,54 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
from diffsynth.utils.data import VideoData


vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(path="./models/train/MOVA-720p-I2AV_high_noise_full/epoch-4.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.video_dit, "models/train/MOVA-720p-I2AV_high_noise_lora/epoch-4.safetensors")
negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)
prompt = "A beautiful sunset over the ocean."
height, width, num_frames = 720, 1280, 121
frame_rate = 24
input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-720p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import torch
from PIL import Image
from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.utils.data import VideoData


vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.video_dit, "models/train/MOVA-360p-I2AV_high_noise_lora/epoch-4.safetensors")
negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)
prompt = "A beautiful sunset over the ocean."
height, width, num_frames = 352, 640, 121
frame_rate = 24
input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-360p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio
from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig
from diffsynth.utils.data import VideoData


vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = MovaAudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.video_dit, "models/train/MOVA-720p-I2AV_high_noise_lora/epoch-4.safetensors")
negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
    "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指"
)
prompt = "A beautiful sunset over the ocean."
height, width, num_frames = 720, 1280, 121
frame_rate = 24
input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
# Image-to-video
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    input_image=input_image,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    frame_rate=frame_rate,
)
write_video_audio(video, audio, "MOVA-720p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate)
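Each validation script hard-codes `epoch-4.safetensors`. To validate every saved epoch instead, a small hypothetical helper can enumerate the checkpoint paths; the path template mirrors the hard-coded ones above, but the helper itself is illustrative and not part of DiffSynth-Studio.

```python
def checkpoint_paths(run_name: str, num_epochs: int) -> list[str]:
    """Enumerate checkpoints laid out as models/train/<run>/epoch-<n>.safetensors."""
    return [f"models/train/{run_name}/epoch-{e}.safetensors" for e in range(num_epochs)]

# One pipe.load_lora(...) call per path would then validate each epoch in turn.
paths = checkpoint_paths("MOVA-720p-I2AV_high_noise_lora", 5)
```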
pyproject.toml
@@ -12,7 +12,6 @@ requires-python = ">=3.10.1"
dependencies = [
    "torch>=2.0.0",
    "torchvision",
    "torchaudio",
    "transformers",
    "imageio",
    "imageio[ffmpeg]",
@@ -48,6 +47,10 @@ npu = [
    "torch-npu==2.7.1",
    "torchvision==0.22.1+cpu"
]
audio = [
    "torchaudio",
    "torchcodec"
]

[tool.setuptools]
include-package-data = true