mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
@@ -201,6 +201,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)

| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
@@ -372,6 +373,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44

## Update History

- **August 28, 2025**: We support Wan2.2-S2V, an audio-driven cinematic video generation model open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
- **August 21, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared to the V1 version, the training dataset has been updated to the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), enabling generated images to better align with the inherent image distribution and style of Qwen-Image. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
- **August 21, 2025**: We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure-control LoRA model. Following the "In Context" approach, it supports various types of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
@@ -201,6 +201,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)

| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
@@ -388,6 +389,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44

## Update History

- **August 28, 2025**: We added support for Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).
- **August 21, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared to the V1 version, the training dataset has been switched to the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better match the inherent image distribution and style of Qwen-Image. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
- **August 21, 2025**: We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure-control LoRA model. It follows the "In Context" approach and supports various structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
@@ -56,11 +56,13 @@ from ..models.stepvideo_vae import StepVideoVAE
from ..models.stepvideo_dit import StepVideoModel
from ..models.wan_video_dit import WanModel
from ..models.wan_video_dit_s2v import WanS2VModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_vace import VaceWanModel
from ..models.wav2vec import WanS2VAudioEncoder
from ..models.step1x_connector import Qwen2Connector
@@ -155,6 +157,7 @@ model_loader_configs = [
    (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
    (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
    (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
    (None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"),
    (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
    (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
    (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
@@ -172,6 +175,7 @@ model_loader_configs = [
    (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
    (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
    (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
    (None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"),
]
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
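Each `model_loader_configs` entry keys a loader on a hash of the checkpoint's state-dict key names, so the architecture can be detected before any weights are inspected. As a rough illustration of the idea (the real `hash_state_dict_keys` in DiffSynth may compute its digest differently; the helper and registry below are hypothetical):

```python
import hashlib

def hash_keys(keys):
    # Hash the sorted parameter names: the same architecture always
    # yields the same digest, regardless of the weight values.
    joined = ",".join(sorted(keys))
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

# Hypothetical registry in the spirit of model_loader_configs:
registry = {
    hash_keys(["blocks.0.attn.q.weight", "blocks.0.attn.k.weight"]): "wan_video_dit",
}

def detect_model_type(state_dict_keys):
    # Unknown checkpoints fall through to a sentinel value.
    return registry.get(hash_keys(state_dict_keys), "unknown")
```

Because the keys are sorted before hashing, detection is insensitive to the order in which the checkpoint enumerates its tensors.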
@@ -1 +1 @@
from .video import VideoData, save_video, save_frames, merge_video_audio, save_video_with_audio
@@ -2,6 +2,8 @@ import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm
import subprocess
import shutil


class LowMemoryVideo:
@@ -146,3 +148,70 @@ def save_frames(frames, save_path):
    os.makedirs(save_path, exist_ok=True)
    for i, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{i}.png"))


def merge_video_audio(video_path: str, audio_path: str):
    # TODO: may need an in-Python implementation to avoid the subprocess dependency
    """
    Merge the video and audio into a new video, with the duration set to the shorter of the two,
    and overwrite the original video file.

    Parameters:
        video_path (str): Path to the original video file
        audio_path (str): Path to the audio file
    """
    # Check that both input files exist.
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"video file {video_path} does not exist")
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"audio file {audio_path} does not exist")

    base, ext = os.path.splitext(video_path)
    temp_output = f"{base}_temp{ext}"

    try:
        # Build the ffmpeg command.
        command = [
            'ffmpeg',
            '-y',  # overwrite output without asking
            '-i', video_path,
            '-i', audio_path,
            '-c:v', 'copy',  # copy the video stream without re-encoding
            '-c:a', 'aac',  # encode audio with AAC
            '-b:a', '192k',  # set the audio bitrate (optional)
            '-map', '0:v:0',  # select the first video stream
            '-map', '1:a:0',  # select the first audio stream
            '-shortest',  # stop at the shorter of the two inputs
            temp_output
        ]

        # Execute the command.
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # Check the result.
        if result.returncode != 0:
            error_msg = f"FFmpeg execution failed: {result.stderr}"
            print(error_msg)
            raise RuntimeError(error_msg)

        shutil.move(temp_output, video_path)
        print(f"Merge completed, saved to {video_path}")

    except Exception as e:
        if os.path.exists(temp_output):
            os.remove(temp_output)
        print(f"merge_video_audio failed with error: {e}")


def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
    save_video(frames, save_path, fps, quality, ffmpeg_params)
    merge_video_audio(save_path, audio_path)
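`merge_video_audio` shells out to `ffmpeg`, so it only works when the binary is on `PATH`. The command it assembles, and an availability guard one might add before calling it, can be sketched as follows (the helper names here are hypothetical, not part of the diff):

```python
import shutil

def build_merge_command(video_path, audio_path, output_path):
    # Mirror of the command assembled in merge_video_audio: copy the
    # video stream, re-encode audio to AAC, stop at the shorter input.
    return [
        "ffmpeg", "-y",
        "-i", video_path,
        "-i", audio_path,
        "-c:v", "copy",
        "-c:a", "aac",
        "-b:a", "192k",
        "-map", "0:v:0",
        "-map", "1:a:0",
        "-shortest",
        output_path,
    ]

def ffmpeg_available():
    # shutil.which returns None when the executable is not on PATH.
    return shutil.which("ffmpeg") is not None

cmd = build_merge_command("video1.mp4", "speech.wav", "video1_temp.mp4")
```

Writing to a temp file and moving it over the original, as the diff does, keeps the source video intact if `ffmpeg` fails partway through.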

diffsynth/models/wan_video_dit_s2v.py (new file, 627 lines)
@@ -0,0 +1,627 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .utils import hash_state_dict_keys
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d

def torch_dfs(model: nn.Module, parent_name='root'):
    module_names, modules = [], []
    current_name = parent_name if parent_name else 'root'
    module_names.append(current_name)
    modules.append(model)

    for name, child in model.named_children():
        if parent_name:
            child_name = f'{parent_name}.{name}'
        else:
            child_name = name
        child_modules, child_names = torch_dfs(child, child_name)
        module_names += child_names
        modules += child_modules
    return modules, module_names

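`torch_dfs` flattens a module tree into parallel lists of modules and dotted names, parent before children. The naming scheme can be sketched without torch, using a nested dict as a hypothetical stand-in for the module tree:

```python
def dfs_names(tree, parent_name="root"):
    # Collect dotted names in depth-first order, parent first,
    # mirroring the child_name construction in torch_dfs.
    names = [parent_name]
    for key, child in tree.items():
        child_name = f"{parent_name}.{key}" if parent_name else key
        names += dfs_names(child, child_name)
    return names

tree = {"transformer_blocks": {"0": {}, "1": {}}}
names = dfs_names(tree)
```

The `root.` prefix is harmless downstream because consumers such as `AudioInjector_WAN` match block names by substring rather than by exact path.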
def rope_precompute(x, grid_sizes, freqs, start=None):
    b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2

    # split freqs
    if type(freqs) is list:
        trainable_freqs = freqs[1]
        freqs = freqs[0]
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64))
    seq_bucket = [0]
    if not type(grid_sizes) is list:
        grid_sizes = [grid_sizes]
    for g in grid_sizes:
        if not type(g) is list:
            g = [torch.zeros_like(g), g, g]
        batch_size = g[0].shape[0]
        for i in range(batch_size):
            if start is None:
                f_o, h_o, w_o = g[0][i]
            else:
                f_o, h_o, w_o = start[i]

            f, h, w = g[1][i]
            t_f, t_h, t_w = g[2][i]
            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
            seq_len = int(seq_f * seq_h * seq_w)
            if seq_len > 0:
                if t_f > 0:
                    factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item()
                    # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
                    if f_o >= 0:
                        f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist()
                    else:
                        f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist()
                    h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist()
                    w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist()

                    assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
                    freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()
                    freqs_0 = freqs_0.view(seq_f, 1, 1, -1)

                    freqs_i = torch.cat(
                        [
                            freqs_0.expand(seq_f, seq_h, seq_w, -1),
                            freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),
                            freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),
                        ],
                        dim=-1
                    ).reshape(seq_len, 1, -1)
                elif t_f < 0:
                    freqs_i = trainable_freqs.unsqueeze(1)
                # apply rotary embedding
                output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
                seq_bucket.append(seq_bucket[-1] + seq_len)
    return output

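When the target grid `(t_f, t_h, t_w)` differs from the token grid, `rope_precompute` samples frequency-table indices with `np.linspace(...).astype(int)`, effectively stretching the RoPE positions over the tokens. A minimal dependency-free sketch of that index resampling (this simplifies away the sign handling of the real function):

```python
def linspace_indices(start, stop, num):
    # Evenly spaced integer indices from start to stop inclusive,
    # analogous to np.linspace(start, stop, num).astype(int).
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [int(start + i * step) for i in range(num)]

# Map 4 tokens onto 8 RoPE positions (indices 0..7):
idx = linspace_indices(0, 7, 4)
```

Sampling endpoints inclusively keeps the first and last tokens pinned to the first and last positions of the target span.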
class CausalConv1d(nn.Module):

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size - 1, 0)  # T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)

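`CausalConv1d` left-pads the time axis by `kernel_size - 1` so each output step sees only current and past inputs, and sequence length is preserved at stride 1. The same padding rule can be sketched without torch (plain Python, not the `nn.Module` above):

```python
def causal_conv1d(x, kernel):
    # Left-pad with the edge value (the 'replicate' pad mode above),
    # then slide the kernel so output[t] depends only on x[:t+1].
    k = len(kernel)
    padded = [x[0]] * (k - 1) + list(x)
    return [sum(kernel[j] * padded[t + j] for j in range(k)) for t in range(len(x))]

# A kernel whose weight sits on the newest sample acts as the identity.
out = causal_conv1d([1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 1.0])
```

No future sample ever enters `out[t]`, which is what makes the encoder usable on streaming audio features.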
class MotionEncoder_tc(nn.Module):

    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, need_global=True, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.need_global = need_global
        self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
        if need_global:
            self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1)
        self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)

        if need_global:
            self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)

        self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        x = rearrange(x, 'b t c -> b c t')
        x_ori = x.clone()
        b, c, t = x.shape
        x = self.conv1_local(x)
        x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        if not self.need_global:
            return x_local

        x = self.conv1_global(x_ori)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = self.final_linear(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)

        return x, x_local

class FramePackMotioner(nn.Module):

    def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
        self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long)

        self.inner_dim = inner_dim
        self.num_heads = num_heads
        self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1)
        self.drop_mode = drop_mode

    def forward(self, motion_latents, add_last_motion=2):
        motion_frames = motion_latents[0].shape[1]
        mot = []
        mot_remb = []
        for m in motion_latents:
            lat_height, lat_width = m.shape[2], m.shape[3]
            padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype)
            overlap_frame = min(padd_lat.shape[1], m.shape[1])
            if overlap_frame > 0:
                padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]

            if add_last_motion < 2 and self.drop_mode != "drop":
                zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()
                padd_lat[:, -zero_end_frame:] = 0

            padd_lat = padd_lat.unsqueeze(0)
            clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split(
                list(self.zip_frame_buckets)[::-1], dim=2
            )  # 16, 2, 1

            # patchify
            clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
            clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2)
            clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2)

            if add_last_motion < 2 and self.drop_mode == "drop":
                clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
                clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x

            motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)

            # rope
            start_time_id = -(self.zip_frame_buckets[:1].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[0]
            grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
                [
                    [
                        torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                        torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
                        torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
                    ]
                ]

            start_time_id = -(self.zip_frame_buckets[:2].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
            grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
                [
                    [
                        torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                        torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
                        torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
                    ]
                ]

            start_time_id = -(self.zip_frame_buckets[:3].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
            grid_sizes_4x = [
                [
                    torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                    torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1),
                    torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
                ]
            ]

            grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x

            motion_rope_emb = rope_precompute(
                motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads),
                grid_sizes,
                self.freqs,
                start=None
            )

            mot.append(motion_lat)
            mot_remb.append(motion_rope_emb)
        return mot, mot_remb

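`FramePackMotioner` keeps a rolling buffer whose last `zip_frame_buckets.sum()` frames are split, oldest first, into 4x-, 2x-, and 1x-compressed groups (`[1, 2, 16]` reversed to `[16, 2, 1]`), so older motion context is stored more coarsely. A plain-Python sketch of that split:

```python
def split_frame_buckets(frames, zip_frame_buckets=(1, 2, 16)):
    # Take the last sum(buckets) frames and cut them into chunks of the
    # reversed bucket sizes: oldest 16 -> 4x path, next 2 -> 2x, last 1 -> post.
    total = sum(zip_frame_buckets)
    tail = frames[-total:]
    chunks, start = [], 0
    for size in reversed(zip_frame_buckets):
        chunks.append(tail[start:start + size])
        start += size
    return chunks  # [clean_latents_4x, clean_latents_2x, clean_latents_post]

frames = list(range(25))  # 25 latent frames; only the last 19 are used
c4x, c2x, post = split_frame_buckets(frames)
```

The single newest frame lands in the uncompressed `post` group, matching the intuition that recent motion matters most.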
class AdaLayerNorm(nn.Module):

    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False)

    def forward(self, x, temb):
        temb = self.linear(F.silu(temb))
        shift, scale = temb.chunk(2, dim=1)
        shift = shift[:, None, :]
        scale = scale[:, None, :]
        x = self.norm(x) * (1 + scale) + shift
        return x

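`AdaLayerNorm` normalizes `x` and then applies a shift and scale predicted from the conditioning embedding: `norm(x) * (1 + scale) + shift`. The conditioning enters only through these affine parameters. A scalar-level sketch with hypothetical values (plain Python, not the module above):

```python
def layer_norm(xs, eps=1e-5):
    # Normalize a feature vector to zero mean and unit variance.
    mean = sum(xs) / len(xs)
    var = sum((v - mean) ** 2 for v in xs) / len(xs)
    return [(v - mean) / (var + eps) ** 0.5 for v in xs]

def ada_layer_norm(xs, shift, scale):
    # Mirror of x = self.norm(x) * (1 + scale) + shift above,
    # with shift/scale standing in for the projected embedding halves.
    return [n * (1 + scale) + shift for n in layer_norm(xs)]

out = ada_layer_norm([1.0, 3.0], shift=0.5, scale=0.0)
```

With `scale=0` the `(1 + scale)` factor leaves the normalized values unscaled, so the embedding contributes only a bias here.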
class AudioInjector_WAN(nn.Module):

    def __init__(
        self,
        all_modules,
        all_modules_names,
        dim=2048,
        num_heads=32,
        inject_layer=[0, 27],
        enable_adain=False,
        adain_dim=2048,
    ):
        super().__init__()
        self.injected_block_id = {}
        audio_injector_id = 0
        for mod_name, mod in zip(all_modules_names, all_modules):
            if isinstance(mod, DiTBlock):
                for inject_id in inject_layer:
                    if f'transformer_blocks.{inject_id}' in mod_name:
                        self.injected_block_id[inject_id] = audio_injector_id
                        audio_injector_id += 1

        self.injector = nn.ModuleList([CrossAttention(
            dim=dim,
            num_heads=num_heads,
        ) for _ in range(audio_injector_id)])
        self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm(
            dim,
            elementwise_affine=False,
            eps=1e-6,
        ) for _ in range(audio_injector_id)])
        self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm(
            dim,
            elementwise_affine=False,
            eps=1e-6,
        ) for _ in range(audio_injector_id)])
        if enable_adain:
            self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)])

class CausalAudioEncoder(nn.Module):

    def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False):
        super().__init__()
        self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global)
        weight = torch.ones((1, num_layers, 1, 1)) * 0.01
        self.weights = torch.nn.Parameter(weight)
        self.act = torch.nn.SiLU()

    def forward(self, features):
        # features: B * num_layers * dim * video_length
        weights = self.act(self.weights.to(device=features.device, dtype=features.dtype))
        weights_sum = weights.sum(dim=1, keepdim=True)
        weighted_feat = ((features * weights) / weights_sum).sum(dim=1)  # b dim f
        weighted_feat = weighted_feat.permute(0, 2, 1)  # b f dim
        res = self.encoder(weighted_feat)  # b f n dim
        return res  # b f n dim


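The forward pass above learns one SiLU-activated scalar per wav2vec layer and takes a normalized weighted sum over the layer axis. A standalone numpy sketch of that reduction (shapes are illustrative):

```python
import numpy as np

def silu(w):
    # SiLU activation: w * sigmoid(w)
    return w / (1 + np.exp(-w))

# features: [batch, num_layers, dim, frames]; one learned scalar per layer.
features = np.random.randn(1, 25, 16, 10)
raw_weights = np.full((1, 25, 1, 1), 0.01)  # initial value used above

w = silu(raw_weights)
# Normalized weighted sum over the layer axis -> [batch, dim, frames]
weighted = (features * w / w.sum(axis=1, keepdims=True)).sum(axis=1)
print(weighted.shape)  # (1, 16, 10)

# With equal raw weights the reduction is just the mean over layers.
assert np.allclose(weighted, features.mean(axis=1))
```

As training updates the per-layer weights away from the uniform initialization, the model can emphasize whichever wav2vec layers carry the most useful speech features.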
class WanS2VDiTBlock(DiTBlock):

    def forward(self, x, context, t_mod, seq_len_x, freqs):
        t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
        # t_mod[:, :, 0] modulates x; t_mod[:, :, 1] modulates the other tokens (ref, motion, etc.)
        t_mod = [
            torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1)
            for element in t_mod
        ]
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        x = x + self.cross_attn(self.norm3(x), context)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x


class WanS2VModel(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        cond_dim: int,
        audio_dim: int,
        num_audio_token: int,
        enable_adain: bool = True,
        audio_inject_layers: tuple = (0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39),
        zero_timestep: bool = True,
        add_last_motion: bool = True,
        framepack_drop_mode: str = "padd",
        fuse_vae_embedding_in_latents: bool = True,
        require_vae_embedding: bool = False,
        seperated_timestep: bool = False,
        require_clip_embedding: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.enable_adain = enable_adain
        self.add_last_motion = add_last_motion
        self.zero_timestep = zero_timestep
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
        self.require_vae_embedding = require_vae_embedding
        self.seperated_timestep = seperated_timestep
        self.require_clip_embedding = require_clip_embedding

        self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim))
        self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
        self.head = Head(dim, out_dim, patch_size, eps)
        self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)

        self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
        all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
        # TODO: refactor dfs
        self.audio_injector = AudioInjector_WAN(
            all_modules,
            all_modules_names,
            dim=dim,
            num_heads=num_heads,
            inject_layer=audio_inject_layers,
            enable_adain=enable_adain,
            adain_dim=dim,
        )
        self.trainable_cond_mask = nn.Embedding(3, dim)
        self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode)

    def patchify(self, x: torch.Tensor):
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size  # grid_size: (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        return rearrange(
            x,
            'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0],
            h=grid_size[1],
            w=grid_size[2],
            x=self.patch_size[0],
            y=self.patch_size[1],
            z=self.patch_size[2],
        )

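`patchify`/`unpatchify` above are inverse rearrangements between a 5-D latent grid and a flat token sequence. A numpy sketch of the round trip (patch size (1, 1, 1) and small illustrative dims, without einops):

```python
import numpy as np

# Flatten a [b, c, f, h, w] latent grid into tokens [b, f*h*w, c] and back,
# mirroring the rearrange patterns 'b c f h w -> b (f h w) c' and its inverse.
b, c, f, h, w = 1, 8, 2, 3, 4
grid = np.random.randn(b, c, f, h, w)

tokens = grid.reshape(b, c, f * h * w).transpose(0, 2, 1)    # patchify
restored = tokens.transpose(0, 2, 1).reshape(b, c, f, h, w)  # unpatchify

print(tokens.shape)  # (1, 24, 8)
assert np.array_equal(grid, restored)
```

In the model itself the channel axis also folds in the patch dims (`x y z`), which is why `unpatchify` additionally unfolds `(x y z c)` back into spatial positions.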
    def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2):
        flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion)
        if drop_motion_frames:
            return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb]
        else:
            return flattern_mot, mot_remb

    def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
        # Inject the motion-frame tokens into the hidden states.
        # TODO: check drop_motion_frames = False
        mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
        if len(mot) > 0:
            x = torch.cat([x, mot[0]], dim=1)
            rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
            mask_input = torch.cat(
                [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
            )
        return x, rope_embs, mask_input

    def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False):
        if block_idx in self.audio_injector.injected_block_id:
            audio_attn_id = self.audio_injector.injected_block_id[block_idx]
            num_frames = audio_emb.shape[1]
            if use_unified_sequence_parallel:
                from xfuser.core.distributed import get_sp_group
                hidden_states = get_sp_group().all_gather(hidden_states, dim=1)

            input_hidden_states = hidden_states[:, :original_seq_len].clone()  # b (f h w) c
            input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)

            audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
            adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
            attn_hidden_states = adain_hidden_states

            audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
            attn_audio_emb = audio_emb
            residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
            residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
            hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out
            if use_unified_sequence_parallel:
                from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank
                hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
        return hidden_states

    def cal_audio_emb(self, audio_input, motion_frames=(73, 19)):
        audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
        audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
        audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
        merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
        return audio_emb_global, merged_audio_emb

    def get_grid_sizes(self, grid_size_x, grid_size_ref):
        f, h, w = grid_size_x
        rf, rh, rw = grid_size_ref
        grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
        grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]]
        grid_sizes_ref = [[
            torch.tensor([30, 0, 0]).unsqueeze(0),
            torch.tensor([31, rh, rw]).unsqueeze(0),
            torch.tensor([1, rh, rw]).unsqueeze(0),
        ]]
        return grid_sizes_x + grid_sizes_ref

    def forward(
        self,
        latents,
        timestep,
        context,
        audio_input,
        motion_latents,
        pose_cond,
        use_gradient_checkpointing_offload=False,
        use_gradient_checkpointing=False,
    ):
        origin_ref_latents = latents[:, :, 0:1]
        x = latents[:, :, 1:]

        # context embedding
        context = self.text_embedding(context)

        # audio encode
        audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)

        # x and pose_cond
        pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
        x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond))  # e.g. torch.Size([1, 29120, 5120])
        seq_len_x = x.shape[1]

        # reference image
        ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents))  # e.g. torch.Size([1, 1456, 5120])
        grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw))
        x = torch.cat([x, ref_latents], dim=1)
        # mask: 0 for video tokens, 1 for reference tokens (2 is reserved for motion tokens)
        mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
        # freqs
        pre_compute_freqs = rope_precompute(
            x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None
        )
        # motion
        x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)

        x = x + self.trainable_cond_mask(mask).to(x.dtype)

        # t_mod
        timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block_id, block in enumerate(self.blocks):
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        context,
                        t_mod,
                        seq_len_x,
                        pre_compute_freqs[0],
                        use_reentrant=False,
                    )
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                        x,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x,
                    context,
                    t_mod,
                    seq_len_x,
                    pre_compute_freqs[0],
                    use_reentrant=False,
                )
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                    x,
                    use_reentrant=False,
                )
            else:
                x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
                x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)

        x = x[:, :seq_len_x]
        x = self.head(x, t[:-1])
        x = self.unpatchify(x, (f, h, w))
        # keep the output shape compatible with Wan video latents
        x = torch.cat([origin_ref_latents, x], dim=2)
        return x

    @staticmethod
    def state_dict_converter():
        return WanS2VModelStateDictConverter()


class WanS2VModelStateDictConverter:

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        config = {}
        if hash_state_dict_keys(state_dict) == "966cffdcc52f9c46c391768b27637614":
            config = {
                "dim": 5120,
                "in_dim": 16,
                "ffn_dim": 13824,
                "out_dim": 16,
                "text_dim": 4096,
                "freq_dim": 256,
                "eps": 1e-06,
                "patch_size": (1, 2, 2),
                "num_heads": 40,
                "num_layers": 40,
                "cond_dim": 16,
                "audio_dim": 1024,
                "num_audio_token": 4,
            }
        return state_dict, config
diffsynth/models/wav2vec.py (new file, 197 lines)
@@ -0,0 +1,197 @@
import math
import numpy as np
import torch
import torch.nn.functional as F


def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
    required_duration = num_sample / target_fps
    required_origin_frames = int(np.ceil(required_duration * original_fps))
    if required_duration > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    if fixed_start is not None and fixed_start >= 0:
        start_frame = fixed_start
    else:
        max_start = total_frames - required_origin_frames
        if max_start < 0:
            raise ValueError("video length is too short")
        start_frame = np.random.randint(0, max_start + 1)
    start_time = start_frame / original_fps

    end_time = start_time + required_duration
    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)

    frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
    frame_indices = np.clip(frame_indices, 0, total_frames - 1)
    return frame_indices


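The index computation above reduces to rounding evenly spaced timestamps to source-frame indices. A minimal numpy sketch of that resampling math, with illustrative rates:

```python
import numpy as np

# Resample a 30 fps clip at 15 fps: pick num_sample timestamps spaced 1/15 s
# apart starting at t=0, then round each to the nearest 30 fps frame index.
original_fps, target_fps, num_sample = 30, 15, 8
time_points = np.linspace(0, num_sample / target_fps, num_sample, endpoint=False)
frame_indices = np.round(time_points * original_fps).astype(int)
print(frame_indices.tolist())  # every second source frame: [0, 2, 4, 6, 8, 10, 12, 14]
```

When the rate ratio is not an integer, the rounding makes some source frames repeat or skip, which is why `get_sample_indices` clips indices to the valid range.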
def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """
    features: shape=[1, T, 512]
    input_fps: fps for audio, f_a
    output_fps: fps for video, f_m
    output_len: video length
    """
    features = features.transpose(1, 2)
    seq_len = features.shape[2] / float(input_fps)
    if output_len is None:
        output_len = int(seq_len * output_fps)
    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')  # [1, 512, output_len]
    return output_features.transpose(1, 2)


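For intuition, a single channel of the rate conversion above can be sketched with `np.interp`. With `align_corners=True`, the first and last input samples map exactly to the first and last output samples, and interior outputs are linear blends of their two nearest inputs (a standalone numpy sketch, not the module's code path):

```python
import numpy as np

# Resample a 50 Hz feature track to 30 Hz by linear interpolation,
# mirroring F.interpolate(mode='linear', align_corners=True) per channel.
def resample(track, input_fps, output_fps):
    output_len = int(len(track) / input_fps * output_fps)
    src_pos = np.linspace(0, len(track) - 1, output_len)  # align_corners=True
    return np.interp(src_pos, np.arange(len(track)), track)

track = np.arange(50, dtype=float)   # one second of 50 Hz features
out = resample(track, input_fps=50, output_fps=30)
print(len(out), out[0], out[-1])     # 30 samples, endpoints preserved
```

This is how 50 Hz wav2vec features get aligned to the 30 fps `video_rate` used by the encoder below.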
class WanS2VAudioEncoder(torch.nn.Module):

    def __init__(self):
        super().__init__()
        from transformers import Wav2Vec2ForCTC, Wav2Vec2Config
        config = {
            "_name_or_path": "facebook/wav2vec2-large-xlsr-53",
            "activation_dropout": 0.05,
            "apply_spec_augment": True,
            "architectures": ["Wav2Vec2ForCTC"],
            "attention_dropout": 0.1,
            "bos_token_id": 1,
            "conv_bias": True,
            "conv_dim": [512, 512, 512, 512, 512, 512, 512],
            "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
            "conv_stride": [5, 2, 2, 2, 2, 2, 2],
            "ctc_loss_reduction": "mean",
            "ctc_zero_infinity": True,
            "do_stable_layer_norm": True,
            "eos_token_id": 2,
            "feat_extract_activation": "gelu",
            "feat_extract_dropout": 0.0,
            "feat_extract_norm": "layer",
            "feat_proj_dropout": 0.05,
            "final_dropout": 0.0,
            "hidden_act": "gelu",
            "hidden_dropout": 0.05,
            "hidden_size": 1024,
            "initializer_range": 0.02,
            "intermediate_size": 4096,
            "layer_norm_eps": 1e-05,
            "layerdrop": 0.05,
            "mask_channel_length": 10,
            "mask_channel_min_space": 1,
            "mask_channel_other": 0.0,
            "mask_channel_prob": 0.0,
            "mask_channel_selection": "static",
            "mask_feature_length": 10,
            "mask_feature_prob": 0.0,
            "mask_time_length": 10,
            "mask_time_min_space": 1,
            "mask_time_other": 0.0,
            "mask_time_prob": 0.05,
            "mask_time_selection": "static",
            "model_type": "wav2vec2",
            "num_attention_heads": 16,
            "num_conv_pos_embedding_groups": 16,
            "num_conv_pos_embeddings": 128,
            "num_feat_extract_layers": 7,
            "num_hidden_layers": 24,
            "pad_token_id": 0,
            "transformers_version": "4.7.0.dev0",
            "vocab_size": 33,
        }
        self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
        self.video_rate = 30

    def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
        input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device)

        # run wav2vec2 and collect its hidden states
        res = self.model(input_values, output_hidden_states=True)
        if return_all_layers:
            feat = torch.cat(res.hidden_states)
        else:
            feat = res.hidden_states[-1]
        feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
        return feat

    def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
        num_layers, audio_frame_num, audio_dim = audio_embed.shape

        return_all_layers = num_layers > 1

        min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1

        bucket_num = min_batch_num * batch_frames
        batch_idx = [stride * i for i in range(bucket_num)]
        batch_audio_eb = []
        for bi in batch_idx:
            if bi < audio_frame_num:
                audio_sample_stride = 2
                chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
                chosen_idx = [max(0, min(c, audio_frame_num - 1)) for c in chosen_idx]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                frame_audio_embed = \
                    torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)

        return batch_audio_eb, min_batch_num

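The per-frame window above selects `2*m + 1` audio samples centered on position `bi`, spaced `audio_sample_stride` apart and clamped into the valid index range. A pure-Python sketch of that index logic:

```python
# Build a centered window of 2*m + 1 sample indices around position bi,
# spaced `stride` apart and clamped into [0, frame_num - 1], mirroring the
# chosen_idx computation in get_audio_embed_bucket above.
def window_indices(bi, frame_num, m=2, stride=2):
    idx = range(bi - m * stride, bi + (m + 1) * stride, stride)
    return [max(0, min(c, frame_num - 1)) for c in idx]

print(window_indices(0, 100))   # left edge clamps to 0: [0, 0, 0, 2, 4]
print(window_indices(98, 100))  # right edge clamps to 99: [94, 96, 98, 99, 99]
```

The clamping repeats the boundary sample instead of zero-padding, so edge frames still see a full-size context window.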
    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
        num_layers, audio_frame_num, audio_dim = audio_embed.shape

        return_all_layers = num_layers > 1

        scale = self.video_rate / fps

        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1

        bucket_num = min_batch_num * batch_frames
        padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
        batch_idx = get_sample_indices(
            original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0
        )
        batch_audio_eb = []
        audio_sample_stride = int(self.video_rate / fps)
        for bi in batch_idx:
            if bi < audio_frame_num:
                chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
                chosen_idx = [max(0, min(c, audio_frame_num - 1)) for c in chosen_idx]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                frame_audio_embed = \
                    torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)

        return batch_audio_eb, min_batch_num

    @staticmethod
    def state_dict_converter():
        return WanS2VAudioEncoderStateDictConverter()


class WanS2VAudioEncoderStateDictConverter:

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        state_dict = {'model.' + k: v for k, v in state_dict.items()}
        return state_dict
@@ -15,6 +15,7 @@ from typing_extensions import Literal
|
|||||||
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
|
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
|
||||||
from ..models import ModelManager, load_state_dict
|
from ..models import ModelManager, load_state_dict
|
||||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||||
|
from ..models.wan_video_dit_s2v import rope_precompute
|
||||||
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
||||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||||
@@ -49,8 +50,9 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.units = [
|
self.units = [
|
||||||
WanVideoUnit_ShapeChecker(),
|
WanVideoUnit_ShapeChecker(),
|
||||||
WanVideoUnit_NoiseInitializer(),
|
WanVideoUnit_NoiseInitializer(),
|
||||||
WanVideoUnit_InputVideoEmbedder(),
|
|
||||||
WanVideoUnit_PromptEmbedder(),
|
WanVideoUnit_PromptEmbedder(),
|
||||||
|
WanVideoUnit_S2V(),
|
||||||
|
WanVideoUnit_InputVideoEmbedder(),
|
||||||
WanVideoUnit_ImageEmbedderVAE(),
|
WanVideoUnit_ImageEmbedderVAE(),
|
||||||
WanVideoUnit_ImageEmbedderCLIP(),
|
WanVideoUnit_ImageEmbedderCLIP(),
|
||||||
WanVideoUnit_ImageEmbedderFused(),
|
WanVideoUnit_ImageEmbedderFused(),
|
||||||
@@ -127,6 +129,8 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||||
RMSNorm: AutoWrappedModule,
|
RMSNorm: AutoWrappedModule,
|
||||||
torch.nn.Conv2d: AutoWrappedModule,
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
|
torch.nn.Conv1d: AutoWrappedModule,
|
||||||
|
torch.nn.Embedding: AutoWrappedModule,
|
||||||
},
|
},
|
||||||
module_config = dict(
|
module_config = dict(
|
||||||
offload_dtype=dtype,
|
offload_dtype=dtype,
|
||||||
@@ -254,6 +258,25 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
),
|
),
|
||||||
vram_limit=vram_limit,
|
vram_limit=vram_limit,
|
||||||
)
|
)
|
||||||
|
if self.audio_encoder is not None:
|
||||||
|
# TODO: need check
|
||||||
|
dtype = next(iter(self.audio_encoder.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.audio_encoder,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
|
torch.nn.Conv1d: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def initialize_usp(self):
|
def initialize_usp(self):
|
||||||
@@ -290,6 +313,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
device: Union[str, torch.device] = "cuda",
|
device: Union[str, torch.device] = "cuda",
|
||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
|
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
|
||||||
|
audio_processor_config: ModelConfig = None,
|
||||||
redirect_common_files: bool = True,
|
redirect_common_files: bool = True,
|
||||||
use_usp=False,
|
use_usp=False,
|
||||||
):
|
):
|
||||||
@@ -332,7 +356,8 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||||
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
||||||
|
pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
|
||||||
|
|
||||||
# Size division factor
|
# Size division factor
|
||||||
if pipe.vae is not None:
|
if pipe.vae is not None:
|
||||||
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
||||||
@@ -342,7 +367,11 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
||||||
pipe.prompter.fetch_models(pipe.text_encoder)
|
pipe.prompter.fetch_models(pipe.text_encoder)
|
||||||
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
||||||
|
|
||||||
|
if audio_processor_config is not None:
|
||||||
|
audio_processor_config.download_if_necessary(use_usp=use_usp)
|
||||||
|
from transformers import Wav2Vec2Processor
|
||||||
|
pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
|
||||||
# Unified Sequence Parallel
|
# Unified Sequence Parallel
|
||||||
if use_usp: pipe.enable_usp()
|
if use_usp: pipe.enable_usp()
|
||||||
return pipe
|
return pipe
|
||||||
@@ -361,6 +390,10 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
# Video-to-video
|
# Video-to-video
|
||||||
input_video: Optional[list[Image.Image]] = None,
|
input_video: Optional[list[Image.Image]] = None,
|
||||||
denoising_strength: Optional[float] = 1.0,
|
denoising_strength: Optional[float] = 1.0,
|
||||||
|
# Speech-to-video
|
||||||
|
input_audio: Optional[str] = None,
|
||||||
|
audio_sample_rate: Optional[int] = 16000,
|
||||||
|
s2v_pose_video: Optional[list[Image.Image]] = None,
|
||||||
# ControlNet
|
# ControlNet
|
||||||
control_video: Optional[list[Image.Image]] = None,
|
control_video: Optional[list[Image.Image]] = None,
|
||||||
reference_image: Optional[Image.Image] = None,
|
reference_image: Optional[Image.Image] = None,
|
||||||
@@ -429,6 +462,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
"motion_bucket_id": motion_bucket_id,
|
"motion_bucket_id": motion_bucket_id,
|
||||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||||
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
||||||
|
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video,
|
||||||
}
|
}
|
||||||
for unit in self.units:
|
for unit in self.units:
|
||||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
@@ -868,6 +902,66 @@ class WanVideoUnit_CfgMerger(PipelineUnit):
        return inputs_shared, inputs_posi, inputs_nega


class WanVideoUnit_S2V(PipelineUnit):
    def __init__(self):
        super().__init__(
            take_over=True,
            onload_model_names=("audio_encoder", "vae",)
        )

    def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
        if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
            return {}
        pipe.load_models_to_device(["audio_encoder"])
        z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True, dtype=pipe.torch_dtype, device=pipe.device)
        audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
            z, fps=16, batch_frames=num_frames - 1, m=0
        )
        audio_embed_bucket = audio_embed_bucket.unsqueeze(0).to(pipe.device, pipe.torch_dtype)
        if len(audio_embed_bucket.shape) == 3:
            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
        elif len(audio_embed_bucket.shape) == 4:
            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
        audio_embed_bucket = audio_embed_bucket[..., 0:num_frames-1]
        return {"audio_input": audio_embed_bucket}

    def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
        pipe.load_models_to_device(["vae"])
        # TODO: may support input motion latents, which relates to `drop_motion_frames = False`
        motion_frames = 73
        lat_motion_frames = (motion_frames + 3) // 4  # 19
        motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
        motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"motion_latents": motion_latents}

    def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
        if s2v_pose_video is None:
            return {"pose_cond": None}
        pipe.load_models_to_device(["vae"])
        input_video = pipe.preprocess_video(s2v_pose_video)
        # keep the first num_frames frames (the first latent frame is dropped below)
        input_video = input_video[:, :, :num_frames]
        # pad with -1 frames if there are not enough frames
        padding_frames = num_frames - input_video.shape[2]
        input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
        # encode to latents
        input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"pose_cond": input_latents[:, :, 1:]}

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if inputs_shared.get("input_audio") is None or pipe.audio_encoder is None or pipe.audio_processor is None:
            return inputs_shared, inputs_posi, inputs_nega
        input_audio, audio_sample_rate, s2v_pose_video, num_frames, height, width = inputs_shared.get("input_audio"), inputs_shared.get("audio_sample_rate"), inputs_shared.get("s2v_pose_video"), inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width")
        tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")

        audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames)
        inputs_posi.update(audio_input_positive)
        inputs_nega.update({"audio_input": 0.0 * audio_input_positive["audio_input"]})

        inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride))
        inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride))
        return inputs_shared, inputs_posi, inputs_nega


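The hard-coded constants in `process_motion_latents` follow Wan's causal VAE, which maps 4 pixel frames to 1 latent frame (rounding up), while `process_pose_cond` truncates or pads pose videos to exactly `num_frames` frames. A standalone sketch of that bookkeeping (helper names are mine, not from the repo):

```python
def latent_frames(pixel_frames: int) -> int:
    # Wan's causal VAE compresses time 4x, rounding up: (n + 3) // 4.
    return (pixel_frames + 3) // 4

def pose_padding(num_frames: int, available_frames: int) -> int:
    # process_pose_cond first truncates to num_frames, then pads the
    # remainder with -1 frames so the pose video spans num_frames exactly.
    return num_frames - min(available_frames, num_frames)

print(latent_frames(73))      # 19, matching the comment in process_motion_latents
print(pose_padding(81, 50))   # 31 frames of -1 padding
print(pose_padding(81, 100))  # 0: the video is truncated instead
```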
class TeaCache:
    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
@@ -987,6 +1081,9 @@ def model_fn_wan_video(
    reference_latents = None,
    vace_context = None,
    vace_scale = 1.0,
    audio_input: Optional[torch.Tensor] = None,
    motion_latents: Optional[torch.Tensor] = None,
    pose_cond: Optional[torch.Tensor] = None,
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
@@ -1024,7 +1121,21 @@ def model_fn_wan_video(
        tensor_names=["latents", "y"],
        batch_size=2 if cfg_merge else 1
    )

    # Wan2.2 S2V: route to the dedicated model function
    if audio_input is not None:
        return model_fn_wans2v(
            dit=dit,
            latents=latents,
            timestep=timestep,
            context=context,
            audio_input=audio_input,
            motion_latents=motion_latents,
            pose_cond=pose_cond,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_unified_sequence_parallel=use_unified_sequence_parallel,
        )

    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
@@ -1143,3 +1254,104 @@ def model_fn_wan_video(
        f -= 1
    x = dit.unpatchify(x, (f, h, w))
    return x


def model_fn_wans2v(
    dit,
    latents,
    timestep,
    context,
    audio_input,
    motion_latents,
    pose_cond,
    use_gradient_checkpointing_offload=False,
    use_gradient_checkpointing=False,
    use_unified_sequence_parallel=False,
):
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)
    origin_ref_latents = latents[:, :, 0:1]
    x = latents[:, :, 1:]

    # context embedding
    context = dit.text_embedding(context)

    # audio encoding
    audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input)

    # x and pose_cond
    pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
    seq_len_x = seq_len_x_global = x.shape[1]  # the global value is used for unified sequence parallel

    # reference image
    ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
    grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
    x = torch.cat([x, ref_latents], dim=1)
    # mask
    mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
    # freqs
    pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
    # motion
    x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)

    x = x + dit.trainable_cond_mask(mask).to(x.dtype)

    # tmod
    timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)

    if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
        world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()
        assert x.shape[1] % world_size == 0, f"the sequence dimension must be divisible by the world size, but got {x.shape[1]} and {world_size}"
        x = torch.chunk(x, world_size, dim=1)[sp_rank]
        seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy())
        seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs) - 1)]
        seq_len_x = seq_len_x_list[sp_rank]

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    for block_id, block in enumerate(dit.blocks):
        if use_gradient_checkpointing_offload:
            with torch.autograd.graph.save_on_cpu():
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, seq_len_x, pre_compute_freqs[0],
                    use_reentrant=False,
                )
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                    x,
                    use_reentrant=False,
                )
        elif use_gradient_checkpointing:
            x = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                x, context, t_mod, seq_len_x, pre_compute_freqs[0],
                use_reentrant=False,
            )
            x = torch.utils.checkpoint.checkpoint(
                create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                x,
                use_reentrant=False,
            )
        else:
            x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
            x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)

    if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
        x = get_sp_group().all_gather(x, dim=1)

    x = x[:, :seq_len_x_global]
    x = dit.head(x, t[:-1])
    x = dit.unpatchify(x, (f, h, w))
    # prepend the reference latents to stay compatible with the standard Wan video path
    x = torch.cat([origin_ref_latents, x], dim=2)
    return x
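The `seg_idxs` / `seq_len_x_list` computation in `model_fn_wans2v` decides, after the token sequence (video tokens followed by reference and motion tokens) is split evenly across ranks, how many of each rank's tokens are video tokens. A pure-Python sketch of the same arithmetic (the function name is mine):

```python
def per_rank_video_tokens(total_len: int, seq_len_x: int, world_size: int) -> list:
    # The sequence is chunked evenly, then each rank counts how much of
    # its chunk falls inside the first seq_len_x (video) tokens.
    assert total_len % world_size == 0, "sequence must divide evenly across ranks"
    chunk = total_len // world_size
    seg_idxs = [i * chunk for i in range(world_size + 1)]
    return [min(max(0, seq_len_x - seg_idxs[i]), chunk) for i in range(world_size)]

# 6 video tokens followed by 2 reference tokens, split over 2 ranks:
print(per_rank_video_tokens(8, 6, 2))  # [4, 2]
# with 4 ranks, the last rank holds no video tokens:
print(per_rank_video_tokens(8, 6, 4))  # [2, 2, 2, 0]
```

Each rank then passes its own count as `seq_len_x` into the transformer blocks, so attention can treat video and non-video tokens differently within the local chunk.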
@@ -48,6 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)

| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
@@ -48,6 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)

| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|