mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
Compare commits
12 Commits
dpo-refine
...
v1.1.9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
afd101f345 | ||
|
|
1313f4dd63 | ||
|
|
8332ecebb7 | ||
|
|
401d7d74a5 | ||
|
|
b8d7d55568 | ||
|
|
a30ed9093f | ||
|
|
b73e713028 | ||
|
|
e0eabaa426 | ||
|
|
538017177a | ||
|
|
30292d9411 | ||
|
|
b168d7aa8b | ||
|
|
8ea45b0daa |
11
README.md
11
README.md
@@ -208,7 +208,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
@@ -235,6 +235,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|
||||
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|
||||
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -385,6 +388,12 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
|
||||
## Update History
|
||||
|
||||
- **November 4, 2025**: We support [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained on Wan 2.1 and enables motion generation conditioned on reference videos.
|
||||
|
||||
- **October 30, 2025**: We support [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which enables text-to-video, image-to-video, and video continuation capabilities. This model adopts Wan's framework for both inference and training in this project.
|
||||
|
||||
- **October 27, 2025**: We support [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, further expanding Wan's ecosystem.
|
||||
|
||||
- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! This model is jointly developed and open-sourced by us and the Taobao Design Team. The model is built upon Qwen-Image, specifically designed for e-commerce poster scenarios, and supports precise partition layout control. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
|
||||
|
||||
- **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is now also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support comprehensive model training capabilities.
|
||||
|
||||
11
README_zh.md
11
README_zh.md
@@ -208,7 +208,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
@@ -235,6 +235,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|
||||
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|
||||
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -401,6 +404,12 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
|
||||
## 更新历史
|
||||
|
||||
- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
|
||||
|
||||
- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
|
||||
|
||||
- **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型,Wan 模型生态再添一员。
|
||||
|
||||
- **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。
|
||||
|
||||
- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image,除标准 SFT 训练模式外,已支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善已支持更全面的模型训练功能。
|
||||
|
||||
@@ -64,6 +64,7 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wav2vec import WanS2VAudioEncoder
|
||||
from ..models.wan_video_animate_adapter import WanAnimateAdapter
|
||||
from ..models.wan_video_mot import MotWanModel
|
||||
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
|
||||
@@ -80,6 +81,8 @@ from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
||||
@@ -153,11 +156,14 @@ model_loader_configs = [
|
||||
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5ec04e02b42d2580483ad69f4e76346a", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5f90e66a0672219f12d9a626c8c21f61", ["wan_video_dit", "wan_video_vap"], [WanModel,MotWanModel], "diffusers"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"),
|
||||
(None, "8b27900f680d7251ce44e2dc8ae1ffef", ["wan_video_dit"], [LongCatVideoTransformer3DModel], "civitai"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
|
||||
901
diffsynth/models/longcat_video_dit.py
Normal file
901
diffsynth/models/longcat_video_dit.py
Normal file
@@ -0,0 +1,901 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.amp as amp
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from .wan_video_dit import flash_attention
|
||||
from ..vram_management import gradient_checkpoint_forward
|
||||
|
||||
|
||||
class RMSNorm_FP32(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
def broadcat(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatentation"
|
||||
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
||||
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
||||
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
x1, x2 = x.unbind(dim=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return rearrange(x, "... d r -> ... (d r)")
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
head_dim,
|
||||
cp_split_hw=None
|
||||
):
|
||||
"""Rotary positional embedding for 3D
|
||||
Reference : https://blog.eleuther.ai/rotary-embeddings/
|
||||
Paper: https://arxiv.org/pdf/2104.09864.pdf
|
||||
Args:
|
||||
dim: Dimension of embedding
|
||||
base: Base value for exponential
|
||||
"""
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'
|
||||
self.cp_split_hw = cp_split_hw
|
||||
# We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels
|
||||
self.base = 10000
|
||||
self.freqs_dict = {}
|
||||
|
||||
def register_grid_size(self, grid_size):
|
||||
if grid_size not in self.freqs_dict:
|
||||
self.freqs_dict.update({
|
||||
grid_size: self.precompute_freqs_cis_3d(grid_size)
|
||||
})
|
||||
|
||||
def precompute_freqs_cis_3d(self, grid_size):
|
||||
num_frames, height, width = grid_size
|
||||
dim_t = self.head_dim - 4 * (self.head_dim // 6)
|
||||
dim_h = 2 * (self.head_dim // 6)
|
||||
dim_w = 2 * (self.head_dim // 6)
|
||||
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
||||
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
||||
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
||||
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
|
||||
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
|
||||
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
|
||||
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
|
||||
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
||||
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
||||
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
# (T H W D)
|
||||
freqs = rearrange(freqs, "T H W D -> (T H W) D")
|
||||
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||
# with torch.no_grad():
|
||||
# freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
|
||||
# freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
|
||||
# freqs = rearrange(freqs, "T H W D -> (T H W) D")
|
||||
|
||||
return freqs
|
||||
|
||||
def forward(self, q, k, grid_size):
|
||||
"""3D RoPE.
|
||||
|
||||
Args:
|
||||
query: [B, head, seq, head_dim]
|
||||
key: [B, head, seq, head_dim]
|
||||
Returns:
|
||||
query and key with the same shape as input.
|
||||
"""
|
||||
|
||||
if grid_size not in self.freqs_dict:
|
||||
self.register_grid_size(grid_size)
|
||||
|
||||
freqs_cis = self.freqs_dict[grid_size].to(q.device)
|
||||
q_, k_ = q.float(), k.float()
|
||||
freqs_cis = freqs_cis.float().to(q.device)
|
||||
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
||||
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
||||
q_ = (q_ * cos) + (rotate_half(q_) * sin)
|
||||
k_ = (k_ * cos) + (rotate_half(k_) * sin)
|
||||
|
||||
return q_.type_as(q), k_.type_as(k)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
enable_flashattn3: bool = False,
|
||||
enable_flashattn2: bool = False,
|
||||
enable_xformers: bool = False,
|
||||
enable_bsa: bool = False,
|
||||
bsa_params: dict = None,
|
||||
cp_split_hw: Optional[List[int]] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.enable_flashattn3 = enable_flashattn3
|
||||
self.enable_flashattn2 = enable_flashattn2
|
||||
self.enable_xformers = enable_xformers
|
||||
self.enable_bsa = enable_bsa
|
||||
self.bsa_params = bsa_params
|
||||
self.cp_split_hw = cp_split_hw
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
||||
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.rope_3d = RotaryPositionalEmbedding(
|
||||
self.head_dim,
|
||||
cp_split_hw=cp_split_hw
|
||||
)
|
||||
|
||||
def _process_attn(self, q, k, v, shape):
|
||||
q = rearrange(q, "B H S D -> B S (H D)")
|
||||
k = rearrange(k, "B H S D -> B S (H D)")
|
||||
v = rearrange(v, "B H S D -> B S (H D)")
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
|
||||
"""
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if return_kv:
|
||||
k_cache, v_cache = k.clone(), v.clone()
|
||||
|
||||
q, k = self.rope_3d(q, k, shape)
|
||||
|
||||
# cond mode
|
||||
if num_cond_latents is not None and num_cond_latents > 0:
|
||||
num_cond_latents_thw = num_cond_latents * (N // shape[0])
|
||||
# process the condition tokens
|
||||
q_cond = q[:, :, :num_cond_latents_thw].contiguous()
|
||||
k_cond = k[:, :, :num_cond_latents_thw].contiguous()
|
||||
v_cond = v[:, :, :num_cond_latents_thw].contiguous()
|
||||
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
|
||||
# process the noise tokens
|
||||
q_noise = q[:, :, num_cond_latents_thw:].contiguous()
|
||||
x_noise = self._process_attn(q_noise, k, v, shape)
|
||||
# merge x_cond and x_noise
|
||||
x = torch.cat([x_cond, x_noise], dim=2).contiguous()
|
||||
else:
|
||||
x = self._process_attn(q, k, v, shape)
|
||||
|
||||
x_output_shape = (B, N, C)
|
||||
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
|
||||
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
|
||||
x = self.proj(x)
|
||||
|
||||
if return_kv:
|
||||
return x, (k_cache, v_cache)
|
||||
else:
|
||||
return x
|
||||
|
||||
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
|
||||
"""
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
T, H, W = shape
|
||||
k_cache, v_cache = kv_cache
|
||||
assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
|
||||
if k_cache.shape[0] == 1:
|
||||
k_cache = k_cache.repeat(B, 1, 1, 1)
|
||||
v_cache = v_cache.repeat(B, 1, 1, 1)
|
||||
|
||||
if num_cond_latents is not None and num_cond_latents > 0:
|
||||
k_full = torch.cat([k_cache, k], dim=2).contiguous()
|
||||
v_full = torch.cat([v_cache, v], dim=2).contiguous()
|
||||
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
|
||||
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
|
||||
q = q_padding[:, :, -N:].contiguous()
|
||||
|
||||
x = self._process_attn(q, k_full, v_full, shape)
|
||||
|
||||
x_output_shape = (B, N, C)
|
||||
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
|
||||
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiHeadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
enable_flashattn3=False,
|
||||
enable_flashattn2=False,
|
||||
enable_xformers=False,
|
||||
):
|
||||
super(MultiHeadCrossAttention, self).__init__()
|
||||
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q_linear = nn.Linear(dim, dim)
|
||||
self.kv_linear = nn.Linear(dim, dim * 2)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||
|
||||
self.enable_flashattn3 = enable_flashattn3
|
||||
self.enable_flashattn2 = enable_flashattn2
|
||||
self.enable_xformers = enable_xformers
|
||||
|
||||
def _process_cross_attn(self, x, cond, kv_seqlen):
|
||||
B, N, C = x.shape
|
||||
assert C == self.dim and cond.shape[2] == self.dim
|
||||
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
k, v = kv.unbind(2)
|
||||
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
q = rearrange(q, "B S H D -> B S (H D)")
|
||||
k = rearrange(k, "B S H D -> B S (H D)")
|
||||
v = rearrange(v, "B S H D -> B S (H D)")
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
|
||||
x = x.view(B, -1, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
|
||||
"""
|
||||
x: [B, N, C]
|
||||
cond: [B, M, C]
|
||||
"""
|
||||
if num_cond_latents is None or num_cond_latents == 0:
|
||||
return self._process_cross_attn(x, cond, kv_seqlen)
|
||||
else:
|
||||
B, N, C = x.shape
|
||||
if num_cond_latents is not None and num_cond_latents > 0:
|
||||
assert shape is not None, "SHOULD pass in the shape"
|
||||
num_cond_latents_thw = num_cond_latents * (N // shape[0])
|
||||
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
|
||||
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
|
||||
output = torch.cat([
|
||||
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
|
||||
output_noise
|
||||
], dim=1).contiguous()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class LayerNorm_FP32(nn.LayerNorm):
|
||||
def __init__(self, dim, eps, elementwise_affine):
|
||||
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
origin_dtype = inputs.dtype
|
||||
out = F.layer_norm(
|
||||
inputs.float(),
|
||||
self.normalized_shape,
|
||||
None if self.weight is None else self.weight.float(),
|
||||
None if self.bias is None else self.bias.float() ,
|
||||
self.eps
|
||||
).to(origin_dtype)
|
||||
return out
|
||||
|
||||
|
||||
def modulate_fp32(norm_func, x, shift, scale):
|
||||
# Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
|
||||
# ensure the modulation params be fp32
|
||||
assert shift.dtype == torch.float32, scale.dtype == torch.float32
|
||||
dtype = x.dtype
|
||||
x = norm_func(x.to(torch.float32))
|
||||
x = x * (scale + 1) + shift
|
||||
x = x.to(dtype)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer_FP32(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_patch = num_patch
|
||||
self.out_channels = out_channels
|
||||
self.adaln_tembed_dim = adaln_tembed_dim
|
||||
|
||||
self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x, t, latent_shape):
|
||||
# timestep shape: [B, T, C]
|
||||
assert t.dtype == torch.float32
|
||||
B, N, C = x.shape
|
||||
T, _, _ = latent_shape
|
||||
|
||||
with amp.autocast('cuda', dtype=torch.float32):
|
||||
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
|
||||
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class FeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.dim = dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, t_embed_dim, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.t_embed_dim = t_embed_dim
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(t_embed_dim, t_embed_dim, bias=True),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
|
||||
freqs = freqs.to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
if t_freq.dtype != dtype:
|
||||
t_freq = t_freq.to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class CaptionEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, hidden_size):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.y_proj = nn.Sequential(
|
||||
nn.Linear(in_channels, hidden_size, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
B, _, N, C = caption.shape
|
||||
caption = self.y_proj(caption)
|
||||
return caption
|
||||
|
||||
|
||||
class PatchEmbed3D(nn.Module):
|
||||
"""Video to Patch Embedding.
|
||||
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: (2,4,4).
|
||||
in_chans (int): Number of input video channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(2, 4, 4),
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.flatten = flatten
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, D, H, W = x.size()
|
||||
if W % self.patch_size[2] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
|
||||
if H % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
|
||||
if D % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
|
||||
|
||||
B, C, T, H, W = x.shape
|
||||
x = self.proj(x) # (B C T H W)
|
||||
if self.norm is not None:
|
||||
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
class LongCatSingleStreamBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: int,
|
||||
adaln_tembed_dim: int,
|
||||
enable_flashattn3: bool = False,
|
||||
enable_flashattn2: bool = False,
|
||||
enable_xformers: bool = False,
|
||||
enable_bsa: bool = False,
|
||||
bsa_params=None,
|
||||
cp_split_hw=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# scale and gate modulation
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
|
||||
self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
|
||||
self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
|
||||
|
||||
self.attn = Attention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
enable_flashattn3=enable_flashattn3,
|
||||
enable_flashattn2=enable_flashattn2,
|
||||
enable_xformers=enable_xformers,
|
||||
enable_bsa=enable_bsa,
|
||||
bsa_params=bsa_params,
|
||||
cp_split_hw=cp_split_hw
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
enable_flashattn3=enable_flashattn3,
|
||||
enable_flashattn2=enable_flashattn2,
|
||||
enable_xformers=enable_xformers,
|
||||
)
|
||||
self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
|
||||
|
||||
def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
|
||||
"""
|
||||
x: [B, N, C]
|
||||
y: [1, N_valid_tokens, C]
|
||||
t: [B, T, C_t]
|
||||
y_seqlen: [B]; type of a list
|
||||
latent_shape: latent shape of a single item
|
||||
"""
|
||||
x_dtype = x.dtype
|
||||
|
||||
B, N, C = x.shape
|
||||
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
|
||||
|
||||
# compute modulation params in fp32
|
||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
||||
shift_msa, scale_msa, gate_msa, \
|
||||
shift_mlp, scale_mlp, gate_mlp = \
|
||||
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
|
||||
|
||||
# self attn with modulation
|
||||
x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
|
||||
|
||||
if kv_cache is not None:
|
||||
kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
|
||||
attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
|
||||
else:
|
||||
attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
|
||||
|
||||
if return_kv:
|
||||
x_s, kv_cache = attn_outputs
|
||||
else:
|
||||
x_s = attn_outputs
|
||||
|
||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
||||
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
# cross attn
|
||||
if not skip_crs_attn:
|
||||
if kv_cache is not None:
|
||||
num_cond_latents = None
|
||||
x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
|
||||
|
||||
# ffn with modulation
|
||||
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
|
||||
x_s = self.ffn(x_m)
|
||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
||||
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
if return_kv:
|
||||
return x, kv_cache
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class LongCatVideoTransformer3DModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
hidden_size: int = 4096,
|
||||
depth: int = 48,
|
||||
num_heads: int = 32,
|
||||
caption_channels: int = 4096,
|
||||
mlp_ratio: int = 4,
|
||||
adaln_tembed_dim: int = 512,
|
||||
frequency_embedding_size: int = 256,
|
||||
# default params
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
# attention config
|
||||
enable_flashattn3: bool = False,
|
||||
enable_flashattn2: bool = True,
|
||||
enable_xformers: bool = False,
|
||||
enable_bsa: bool = False,
|
||||
bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
|
||||
cp_split_hw: Optional[List[int]] = [1, 1],
|
||||
text_tokens_zero_pad: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.cp_split_hw = cp_split_hw
|
||||
|
||||
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels,
|
||||
hidden_size=hidden_size,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
LongCatSingleStreamBlock(
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
adaln_tembed_dim=adaln_tembed_dim,
|
||||
enable_flashattn3=enable_flashattn3,
|
||||
enable_flashattn2=enable_flashattn2,
|
||||
enable_xformers=enable_xformers,
|
||||
enable_bsa=enable_bsa,
|
||||
bsa_params=bsa_params,
|
||||
cp_split_hw=cp_split_hw
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer_FP32(
|
||||
hidden_size,
|
||||
np.prod(self.patch_size),
|
||||
out_channels,
|
||||
adaln_tembed_dim,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.text_tokens_zero_pad = text_tokens_zero_pad
|
||||
|
||||
self.lora_dict = {}
|
||||
self.active_loras = []
|
||||
|
||||
def enable_loras(self, lora_key_list=[]):
|
||||
self.disable_all_loras()
|
||||
|
||||
module_loras = {} # {module_name: [lora1, lora2, ...]}
|
||||
model_device = next(self.parameters()).device
|
||||
model_dtype = next(self.parameters()).dtype
|
||||
|
||||
for lora_key in lora_key_list:
|
||||
if lora_key in self.lora_dict:
|
||||
for lora in self.lora_dict[lora_key].loras:
|
||||
lora.to(model_device, dtype=model_dtype, non_blocking=True)
|
||||
module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
|
||||
if module_name not in module_loras:
|
||||
module_loras[module_name] = []
|
||||
module_loras[module_name].append(lora)
|
||||
self.active_loras.append(lora_key)
|
||||
|
||||
for module_name, loras in module_loras.items():
|
||||
module = self._get_module_by_name(module_name)
|
||||
if not hasattr(module, 'org_forward'):
|
||||
module.org_forward = module.forward
|
||||
module.forward = self._create_multi_lora_forward(module, loras)
|
||||
|
||||
def _create_multi_lora_forward(self, module, loras):
|
||||
def multi_lora_forward(x, *args, **kwargs):
|
||||
weight_dtype = x.dtype
|
||||
org_output = module.org_forward(x, *args, **kwargs)
|
||||
|
||||
total_lora_output = 0
|
||||
for lora in loras:
|
||||
if lora.use_lora:
|
||||
lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
|
||||
lx = lora.lora_up(lx)
|
||||
lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
|
||||
total_lora_output += lora_output
|
||||
|
||||
return org_output + total_lora_output
|
||||
|
||||
return multi_lora_forward
|
||||
|
||||
def _get_module_by_name(self, module_name):
|
||||
try:
|
||||
module = self
|
||||
for part in module_name.split('.'):
|
||||
module = getattr(module, part)
|
||||
return module
|
||||
except AttributeError as e:
|
||||
raise ValueError(f"Cannot find module: {module_name}, error: {e}")
|
||||
|
||||
def disable_all_loras(self):
|
||||
for name, module in self.named_modules():
|
||||
if hasattr(module, 'org_forward'):
|
||||
module.forward = module.org_forward
|
||||
delattr(module, 'org_forward')
|
||||
|
||||
for lora_key, lora_network in self.lora_dict.items():
|
||||
for lora in lora_network.loras:
|
||||
lora.to("cpu")
|
||||
|
||||
self.active_loras.clear()
|
||||
|
||||
def enable_bsa(self,):
|
||||
for block in self.blocks:
|
||||
block.attn.enable_bsa = True
|
||||
|
||||
def disable_bsa(self,):
|
||||
for block in self.blocks:
|
||||
block.attn.enable_bsa = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
num_cond_latents=0,
|
||||
return_kv=False,
|
||||
kv_cache_dict={},
|
||||
skip_crs_attn=False,
|
||||
offload_kv_cache=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
|
||||
B, _, T, H, W = hidden_states.shape
|
||||
|
||||
N_t = T // self.patch_size[0]
|
||||
N_h = H // self.patch_size[1]
|
||||
N_w = W // self.patch_size[2]
|
||||
|
||||
assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
|
||||
|
||||
# expand the shape of timestep from [B] to [B, T]
|
||||
if len(timestep.shape) == 1:
|
||||
timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]
|
||||
timestep[:, :num_cond_latents] = 0
|
||||
|
||||
dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
timestep = timestep.to(dtype)
|
||||
encoder_hidden_states = encoder_hidden_states.to(dtype)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
|
||||
|
||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
||||
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
|
||||
|
||||
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
|
||||
|
||||
if self.text_tokens_zero_pad and encoder_attention_mask is not None:
|
||||
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
|
||||
encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
|
||||
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
|
||||
y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
|
||||
else:
|
||||
y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
|
||||
|
||||
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||
# hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
|
||||
# hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
|
||||
# hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
|
||||
|
||||
# blocks
|
||||
kv_cache_dict_ret = {}
|
||||
for i, block in enumerate(self.blocks):
|
||||
block_outputs = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=hidden_states,
|
||||
y=encoder_hidden_states,
|
||||
t=t,
|
||||
y_seqlen=y_seqlens,
|
||||
latent_shape=(N_t, N_h, N_w),
|
||||
num_cond_latents=num_cond_latents,
|
||||
return_kv=return_kv,
|
||||
kv_cache=kv_cache_dict.get(i, None),
|
||||
skip_crs_attn=skip_crs_attn,
|
||||
)
|
||||
|
||||
if return_kv:
|
||||
hidden_states, kv_cache = block_outputs
|
||||
if offload_kv_cache:
|
||||
kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
|
||||
else:
|
||||
kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
|
||||
else:
|
||||
hidden_states = block_outputs
|
||||
|
||||
hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
|
||||
|
||||
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||
# hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
|
||||
|
||||
hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W]
|
||||
|
||||
# cast to float32 for better accuracy
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
if return_kv:
|
||||
return hidden_states, kv_cache_dict_ret
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unpatchify(self, x, N_t, N_h, N_w):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): of shape [B, N, C]
|
||||
|
||||
Return:
|
||||
x (torch.Tensor): of shape [B, C_out, T, H, W]
|
||||
"""
|
||||
T_p, H_p, W_p = self.patch_size
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
|
||||
N_t=N_t,
|
||||
N_h=N_h,
|
||||
N_w=N_w,
|
||||
T_p=T_p,
|
||||
H_p=H_p,
|
||||
W_p=W_p,
|
||||
C_out=self.out_channels,
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return LongCatVideoTransformer3DModelDictConverter()
|
||||
|
||||
|
||||
class LongCatVideoTransformer3DModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
@@ -362,7 +362,7 @@ class WanModel(torch.nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
t = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@@ -437,6 +437,11 @@ class WanModelStateDictConverter:
|
||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
|
||||
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
|
||||
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
|
||||
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
|
||||
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
|
||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||
@@ -454,6 +459,14 @@ class WanModelStateDictConverter:
|
||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
|
||||
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
|
||||
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
|
||||
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
|
||||
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
|
||||
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
|
||||
"patch_embedding.bias": "patch_embedding.bias",
|
||||
"patch_embedding.weight": "patch_embedding.weight",
|
||||
"scale_shift_table": "head.modulation",
|
||||
@@ -470,7 +483,7 @@ class WanModelStateDictConverter:
|
||||
name_ = rename_dict[name_]
|
||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||
state_dict_[name_] = param
|
||||
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
||||
if hash_state_dict_keys(state_dict_) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
||||
config = {
|
||||
"model_type": "t2v",
|
||||
"patch_size": (1, 2, 2),
|
||||
@@ -488,6 +501,20 @@ class WanModelStateDictConverter:
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-6,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict_) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
@@ -495,6 +522,12 @@ class WanModelStateDictConverter:
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
||||
state_dict = {name: param for name, param in state_dict.items() if name.split(".")[0] not in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("model."):
|
||||
name = name[len("model."):]
|
||||
state_dict_[name] = param
|
||||
state_dict = state_dict_
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
|
||||
281
diffsynth/models/wan_video_mot.py
Normal file
281
diffsynth/models/wan_video_mot.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP
|
||||
from .utils import hash_state_dict_keys
|
||||
import einops
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MotSelfAttention(SelfAttention):
|
||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
||||
super().__init__(dim, num_heads, eps)
|
||||
def forward(self, x, freqs, is_before_attn=False):
|
||||
if is_before_attn:
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
v = self.v(x)
|
||||
q = rope_apply(q, freqs, self.num_heads)
|
||||
k = rope_apply(k, freqs, self.num_heads)
|
||||
return q, k, v
|
||||
else:
|
||||
return self.o(x)
|
||||
|
||||
|
||||
class MotWanAttentionBlock(DiTBlock):
|
||||
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
|
||||
self.block_id = block_id
|
||||
|
||||
self.self_attn = MotSelfAttention(dim, num_heads, eps)
|
||||
|
||||
|
||||
def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):
|
||||
|
||||
# 1. prepare scale parameter
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||
|
||||
scale_params_mot_ref = self.modulation + t_mod_mot.float()
|
||||
scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1)
|
||||
shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2)
|
||||
|
||||
# 2. Self-attention
|
||||
input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa)
|
||||
# original block self-attn
|
||||
attn1 = wan_block.self_attn
|
||||
q = attn1.norm_q(attn1.q(input_x))
|
||||
k = attn1.norm_k(attn1.k(input_x))
|
||||
v = attn1.v(input_x)
|
||||
q = rope_apply(q, freqs, attn1.num_heads)
|
||||
k = rope_apply(k, freqs, attn1.num_heads)
|
||||
|
||||
# mot block self-attn
|
||||
norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1)
|
||||
norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot)
|
||||
norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1)
|
||||
q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True)
|
||||
|
||||
tmp_hidden_states = flash_attention(
|
||||
torch.cat([q, q_mot], dim=-2),
|
||||
torch.cat([k, k_mot], dim=-2),
|
||||
torch.cat([v, v_mot], dim=-2),
|
||||
num_heads=attn1.num_heads)
|
||||
|
||||
attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2)
|
||||
|
||||
attn_output = attn1.o(attn_output)
|
||||
x = wan_block.gate(x, gate_msa, attn_output)
|
||||
|
||||
attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False)
|
||||
# gate
|
||||
attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1)
|
||||
attn_output_mot = attn_output_mot * gate_msa_mot_ref
|
||||
attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1)
|
||||
x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot)
|
||||
|
||||
# 3. cross-attention and feed-forward
|
||||
x = x + wan_block.cross_attn(wan_block.norm3(x), context)
|
||||
input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp)
|
||||
x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x))
|
||||
|
||||
x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot)
|
||||
# modulate
|
||||
norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1)
|
||||
norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot)
|
||||
norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1)
|
||||
input_x_mot = self.ffn(norm_x_mot_ref)
|
||||
# gate
|
||||
input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1)
|
||||
input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref
|
||||
input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1)
|
||||
x_mot = (x_mot.float() + input_x_mot).type_as(x_mot)
|
||||
|
||||
return x, x_mot
|
||||
|
||||
|
||||
class MotWanModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
|
||||
patch_size=(1, 2, 2),
|
||||
has_image_input=True,
|
||||
has_image_pos_emb=False,
|
||||
dim=5120,
|
||||
num_heads=40,
|
||||
ffn_dim=13824,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
in_dim=36,
|
||||
eps=1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.mot_layers = mot_layers
|
||||
self.freq_dim = freq_dim
|
||||
self.dim = dim
|
||||
|
||||
self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)}
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim),
|
||||
nn.GELU(approximate='tanh'),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)
|
||||
|
||||
# mot blocks
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
|
||||
for i in self.mot_layers
|
||||
])
|
||||
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
return x
|
||||
|
||||
def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):
|
||||
def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):
|
||||
# 1d rope precompute
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
|
||||
[: (dim // 2)].double() / dim))
|
||||
freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta)
|
||||
h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
|
||||
w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
|
||||
|
||||
freqs = torch.cat([
|
||||
f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1)
|
||||
return freqs
|
||||
|
||||
def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id):
|
||||
block = self.blocks[self.mot_layers_mapping[block_id]]
|
||||
x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)
|
||||
return x, x_mot
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return MotWanModelDictConverter()
|
||||
|
||||
|
||||
class MotWanModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
|
||||
rename_dict = {
|
||||
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
||||
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
||||
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
||||
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
||||
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
||||
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
||||
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
||||
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
||||
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
||||
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
||||
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
||||
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
||||
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
||||
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
||||
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
||||
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
||||
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
|
||||
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
|
||||
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
|
||||
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
|
||||
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
|
||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
||||
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
||||
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
||||
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
||||
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
||||
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
||||
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
||||
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
||||
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
||||
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
||||
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
|
||||
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
|
||||
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
|
||||
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
|
||||
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
|
||||
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
|
||||
"patch_embedding.bias": "patch_embedding.bias",
|
||||
"patch_embedding.weight": "patch_embedding.weight",
|
||||
"scale_shift_table": "head.modulation",
|
||||
"proj_out.bias": "head.head.bias",
|
||||
"proj_out.weight": "head.head.weight",
|
||||
}
|
||||
state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name}
|
||||
if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
|
||||
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
|
||||
else:
|
||||
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
|
||||
mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)}
|
||||
|
||||
state_dict_ = {}
|
||||
|
||||
for name, param in state_dict.items():
|
||||
name = name.replace("_mot_ref", "")
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
if name.split(".")[1].isdigit():
|
||||
block_id = int(name.split(".")[1])
|
||||
name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
|
||||
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
||||
if name_ in rename_dict:
|
||||
name_ = rename_dict[name_]
|
||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||
state_dict_[name_] = param
|
||||
|
||||
if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0': # vap 14B
|
||||
config = {
|
||||
"mot_layers":(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"num_heads": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_animate_adapter import WanAnimateAdapter
|
||||
from ..models.wan_video_mot import MotWanModel
|
||||
from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from ..prompters import WanPrompter
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
|
||||
@@ -46,9 +48,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.vace: VaceWanModel = None
|
||||
self.vace2: VaceWanModel = None
|
||||
self.vap: MotWanModel = None
|
||||
self.animate_adapter: WanAnimateAdapter = None
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap")
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
@@ -68,9 +71,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoPostUnit_AnimatePoseLatents(),
|
||||
WanVideoPostUnit_AnimateFacePixelValues(),
|
||||
WanVideoPostUnit_AnimateInpaint(),
|
||||
WanVideoUnit_VAP(),
|
||||
WanVideoUnit_UnifiedSequenceParallel(),
|
||||
WanVideoUnit_TeaCache(),
|
||||
WanVideoUnit_CfgMerger(),
|
||||
WanVideoUnit_LongCatVideo(),
|
||||
]
|
||||
self.post_units = [
|
||||
WanVideoPostUnit_S2V(),
|
||||
@@ -150,6 +155,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.dit is not None:
|
||||
from ..models.longcat_video_dit import LayerNorm_FP32, RMSNorm_FP32
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
enable_vram_management(
|
||||
@@ -162,6 +168,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.Conv1d: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
LayerNorm_FP32: AutoWrappedModule,
|
||||
RMSNorm_FP32: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -387,6 +395,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
vace = model_manager.fetch_model("wan_video_vace", index=2)
|
||||
pipe.vap = model_manager.fetch_model("wan_video_vap")
|
||||
if isinstance(vace, list):
|
||||
pipe.vace, pipe.vace2 = vace
|
||||
else:
|
||||
@@ -450,6 +459,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
animate_face_video: Optional[list[Image.Image]] = None,
|
||||
animate_inpaint_video: Optional[list[Image.Image]] = None,
|
||||
animate_mask_video: Optional[list[Image.Image]] = None,
|
||||
# VAP
|
||||
vap_video: Optional[list[Image.Image]] = None,
|
||||
vap_prompt: Optional[str] = " ",
|
||||
negative_vap_prompt: Optional[str] = " ",
|
||||
# Randomness
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
@@ -467,6 +480,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
sigma_shift: Optional[float] = 5.0,
|
||||
# Speed control
|
||||
motion_bucket_id: Optional[int] = None,
|
||||
# LongCat-Video
|
||||
longcat_video: Optional[list[Image.Image]] = None,
|
||||
# VAE tiling
|
||||
tiled: Optional[bool] = True,
|
||||
tile_size: Optional[tuple[int, int]] = (30, 52),
|
||||
@@ -486,10 +501,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Inputs
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
"vap_prompt": vap_prompt,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
"negative_vap_prompt": negative_vap_prompt,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
|
||||
}
|
||||
inputs_shared = {
|
||||
@@ -504,10 +521,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
||||
"sigma_shift": sigma_shift,
|
||||
"motion_bucket_id": motion_bucket_id,
|
||||
"longcat_video": longcat_video,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
||||
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
|
||||
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
|
||||
"vap_video": vap_video,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -919,6 +938,71 @@ class WanVideoUnit_VACE(PipelineUnit):
|
||||
else:
|
||||
return {"vace_context": None, "vace_scale": vace_scale}
|
||||
|
||||
class WanVideoUnit_VAP(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("text_encoder", "vae", "image_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("vap_video") is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
else:
|
||||
# 1. encode vap prompt
|
||||
pipe.load_models_to_device(["text_encoder"])
|
||||
vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "")
|
||||
vap_prompt_emb = pipe.prompter.encode_prompt(vap_prompt, positive=inputs_posi.get('positive',None), device=pipe.device)
|
||||
negative_vap_prompt_emb = pipe.prompter.encode_prompt(negative_vap_prompt, positive=inputs_nega.get('positive',None), device=pipe.device)
|
||||
inputs_posi.update({"context_vap":vap_prompt_emb})
|
||||
inputs_nega.update({"context_vap":negative_vap_prompt_emb})
|
||||
# 2. prepare vap image clip embedding
|
||||
pipe.load_models_to_device(["vae", "image_encoder"])
|
||||
vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image")
|
||||
|
||||
num_frames, height, width, mot_num = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num",1)
|
||||
|
||||
image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device)
|
||||
|
||||
vap_clip_context = pipe.image_encoder.encode_image([image_vap])
|
||||
if end_image is not None:
|
||||
vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
|
||||
if pipe.dit.has_image_pos_emb:
|
||||
vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1)
|
||||
vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inputs_shared.update({"vap_clip_feature":vap_clip_context})
|
||||
|
||||
# 3. prepare vap latents
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
|
||||
msk[:, 1:] = 0
|
||||
if end_image is not None:
|
||||
msk[:, -1:] = 1
|
||||
last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
|
||||
vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1)
|
||||
else:
|
||||
vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1)
|
||||
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
|
||||
|
||||
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
vap_video = pipe.preprocess_video(vap_video)
|
||||
vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inputs_shared.update({"vap_hidden_state":vap_latent})
|
||||
pipe.load_models_to_device([])
|
||||
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
|
||||
@@ -1028,8 +1112,8 @@ class WanVideoUnit_S2V(PipelineUnit):
|
||||
if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
|
||||
input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate")
|
||||
s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video")
|
||||
input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000)
|
||||
s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None)
|
||||
|
||||
audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)
|
||||
inputs_posi.update(audio_input_positive)
|
||||
@@ -1151,6 +1235,22 @@ class WanVideoPostUnit_AnimateInpaint(PipelineUnit):
|
||||
return {"y": y}
|
||||
|
||||
|
||||
class WanVideoUnit_LongCatVideo(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("longcat_video",),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, longcat_video):
|
||||
if longcat_video is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
longcat_video = pipe.preprocess_video(longcat_video)
|
||||
longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"longcat_latents": longcat_latents}
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
@@ -1261,6 +1361,7 @@ def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
vace: VaceWanModel = None,
|
||||
vap: MotWanModel = None,
|
||||
animate_adapter: WanAnimateAdapter = None,
|
||||
latents: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
@@ -1273,12 +1374,16 @@ def model_fn_wan_video(
|
||||
audio_embeds: Optional[torch.Tensor] = None,
|
||||
motion_latents: Optional[torch.Tensor] = None,
|
||||
s2v_pose_latents: Optional[torch.Tensor] = None,
|
||||
vap_hidden_state = None,
|
||||
vap_clip_feature = None,
|
||||
context_vap = None,
|
||||
drop_motion_frames: bool = True,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
pose_latents=None,
|
||||
face_pixel_values=None,
|
||||
longcat_latents=None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
sliding_window_stride: Optional[int] = None,
|
||||
cfg_merge: bool = False,
|
||||
@@ -1313,6 +1418,18 @@ def model_fn_wan_video(
|
||||
tensor_names=["latents", "y"],
|
||||
batch_size=2 if cfg_merge else 1
|
||||
)
|
||||
# LongCat-Video
|
||||
if isinstance(dit, LongCatVideoTransformer3DModel):
|
||||
return model_fn_longcat_video(
|
||||
dit=dit,
|
||||
latents=latents,
|
||||
timestep=timestep,
|
||||
context=context,
|
||||
longcat_latents=longcat_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
# wan2.2 s2v
|
||||
if audio_embeds is not None:
|
||||
return model_fn_wans2v(
|
||||
@@ -1369,7 +1486,7 @@ def model_fn_wan_video(
|
||||
if clip_feature is not None and dit.require_clip_embedding:
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
|
||||
# Camera control
|
||||
x = dit.patchify(x, control_camera_latents_input)
|
||||
|
||||
@@ -1394,6 +1511,25 @@ def model_fn_wan_video(
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# VAP
|
||||
if vap is not None:
|
||||
# hidden state
|
||||
x_vap = vap_hidden_state
|
||||
x_vap = vap.patchify(x_vap)
|
||||
x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous()
|
||||
# Timestep
|
||||
clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype)
|
||||
t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep))
|
||||
t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim))
|
||||
|
||||
# rope
|
||||
freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device)
|
||||
|
||||
# context
|
||||
vap_clip_embedding = vap.img_emb(vap_clip_feature)
|
||||
context_vap = vap.text_embedding(context_vap)
|
||||
context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
@@ -1423,23 +1559,45 @@ def model_fn_wan_video(
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
def create_custom_forward_vap(block, vap):
|
||||
def custom_forward(*inputs):
|
||||
return vap(block, *inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
# Block
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
if vap is not None and block_id in vap.mot_layers_mapping:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward_vap(block, vap),
|
||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward_vap(block, vap),
|
||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
|
||||
else:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
# VACE
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
@@ -1468,6 +1626,36 @@ def model_fn_wan_video(
|
||||
return x
|
||||
|
||||
|
||||
def model_fn_longcat_video(
|
||||
dit: LongCatVideoTransformer3DModel,
|
||||
latents: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
context: torch.Tensor = None,
|
||||
longcat_latents: torch.Tensor = None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
if longcat_latents is not None:
|
||||
latents[:, :, :longcat_latents.shape[2]] = longcat_latents
|
||||
num_cond_latents = longcat_latents.shape[2]
|
||||
else:
|
||||
num_cond_latents = 0
|
||||
context = context.unsqueeze(0)
|
||||
encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64)
|
||||
output = dit(
|
||||
latents,
|
||||
timestep,
|
||||
context,
|
||||
encoder_attention_mask,
|
||||
num_cond_latents=num_cond_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
output = -output
|
||||
output = output.to(latents.dtype)
|
||||
return output
|
||||
|
||||
|
||||
def model_fn_wans2v(
|
||||
dit,
|
||||
latents,
|
||||
|
||||
@@ -225,6 +225,13 @@ class ToAbsolutePath(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return os.path.join(self.base_path, data)
|
||||
|
||||
class LoadAudio(DataProcessingOperator):
|
||||
def __init__(self, sr=16000):
|
||||
self.sr = sr
|
||||
def __call__(self, data: str):
|
||||
import librosa
|
||||
input_audio, sample_rate = librosa.load(data, sr=self.sr)
|
||||
return input_audio
|
||||
|
||||
|
||||
class UnifiedDataset(torch.utils.data.Dataset):
|
||||
|
||||
@@ -475,64 +475,6 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(pipe, lora_base_model, model)
|
||||
|
||||
def disable_all_lora_layers(self, model):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'enable_adapters'):
|
||||
module.enable_adapters(False)
|
||||
|
||||
def enable_all_lora_layers(self, model):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'enable_adapters'):
|
||||
module.enable_adapters(True)
|
||||
|
||||
|
||||
class DPOLoss:
|
||||
def __init__(self, beta=2500):
|
||||
self.beta = beta
|
||||
|
||||
def sample_timestep(self, pipe):
|
||||
timestep_id = torch.randint(0, pipe.scheduler.num_train_timesteps, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return timestep
|
||||
|
||||
def training_loss_minimum(self, pipe, noise, timestep, **inputs):
|
||||
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
noise_pred = pipe.model_fn(**inputs, timestep=timestep)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
def loss(self, model, data):
|
||||
# Loss DPO: -logσ(−β(diff_policy − diff_ref))
|
||||
# Prepare inputs
|
||||
win_data = {key: data[key] for key in ["prompt", "image"]}
|
||||
lose_data = {"prompt": data["prompt"], "image": data["lose_image"]}
|
||||
inputs_win = model.forward_preprocess(win_data)
|
||||
inputs_lose = model.forward_preprocess(lose_data)
|
||||
inputs_win.pop('noise')
|
||||
inputs_lose.pop('noise')
|
||||
models = {name: getattr(model.pipe, name) for name in model.pipe.in_iteration_models}
|
||||
# sample timestep and noise
|
||||
timestep = self.sample_timestep(model.pipe)
|
||||
noise = torch.rand_like(inputs_win["latents"])
|
||||
# compute diff_policy = loss_win - loss_lose
|
||||
loss_win = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
|
||||
loss_lose = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
|
||||
diff_policy = loss_win - loss_lose
|
||||
# compute diff_ref
|
||||
# TODO: may support full model training
|
||||
model.disable_all_lora_layers(model.pipe.dit)
|
||||
# load the original model weights
|
||||
with torch.no_grad():
|
||||
loss_win_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
|
||||
loss_lose_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
|
||||
diff_ref = loss_win_ref - loss_lose_ref
|
||||
model.enable_all_lora_layers(model.pipe.dit)
|
||||
# compute loss
|
||||
loss = -1. * torch.nn.functional.logsigmoid(self.beta * (diff_ref - diff_policy)).mean()
|
||||
return loss
|
||||
|
||||
|
||||
class ModelLogger:
|
||||
@@ -661,6 +603,7 @@ def wan_parser():
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--audio_processor_config", type=str, default=None, help="Model ID with origin paths to the audio processor config, e.g., Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||
|
||||
@@ -4,6 +4,7 @@ from PIL import Image
|
||||
from einops import repeat, reduce
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||
from modelscope import snapshot_download
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -196,13 +197,24 @@ class ModelConfig:
|
||||
self.local_model_path = "./models"
|
||||
if not skip_download:
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
if self.download_resource.lower() == "modelscope":
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
elif self.download_resource.lower() == "huggingface":
|
||||
hf_snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_patterns=allow_file_pattern,
|
||||
ignore_patterns=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
else:
|
||||
raise ValueError("`download_resource` should be `modelscope` or `huggingface`.")
|
||||
|
||||
# Let rank 1, 2, ... wait for rank 0
|
||||
if use_usp:
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task, DPOLoss
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
@@ -84,29 +84,24 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
|
||||
|
||||
def forward(self, data, inputs=None, return_inputs=False):
|
||||
# DPO (DPO requires a special training loss)
|
||||
if self.task == "dpo":
|
||||
loss = DPOLoss().loss(self, data)
|
||||
return loss
|
||||
# Inputs
|
||||
if inputs is None:
|
||||
inputs = self.forward_preprocess(data)
|
||||
else:
|
||||
# Inputs
|
||||
if inputs is None:
|
||||
inputs = self.forward_preprocess(data)
|
||||
else:
|
||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||
if return_inputs: return inputs
|
||||
|
||||
# Loss
|
||||
if self.task == "sft":
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
elif self.task == "data_process":
|
||||
loss = inputs
|
||||
elif self.task == "direct_distill":
|
||||
loss = self.pipe.direct_distill_loss(**inputs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported task: {self.task}.")
|
||||
return loss
|
||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||
if return_inputs: return inputs
|
||||
|
||||
# Loss
|
||||
if self.task == "sft":
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
elif self.task == "data_process":
|
||||
loss = inputs
|
||||
elif self.task == "direct_distill":
|
||||
loss = self.pipe.direct_distill_loss(**inputs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported task: {self.task}.")
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
@@ -148,6 +143,5 @@ if __name__ == "__main__":
|
||||
"sft": launch_training_task,
|
||||
"data_process": launch_data_process_task,
|
||||
"direct_distill": launch_training_task,
|
||||
"dpo": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](dataset, model, model_logger, args=args)
|
||||
|
||||
@@ -49,7 +49,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./model_training/full/Wan2.2-S2V-14B.sh)|[code](./model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./model_training/lora/Wan2.2-S2V-14B.sh)|[code](./model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
@@ -76,7 +76,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|
||||
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./model_inference/krea-realtime-video.py)|[code](./model_training/full/krea-realtime-video.sh)|[code](./model_training/validate_full/krea-realtime-video.py)|[code](./model_training/lora/krea-realtime-video.sh)|[code](./model_training/validate_lora/krea-realtime-video.py)|
|
||||
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./model_inference/LongCat-Video.py)|[code](./model_training/full/LongCat-Video.sh)|[code](./model_training/validate_full/LongCat-Video.py)|[code](./model_training/lora/LongCat-Video.sh)|[code](./model_training/validate_lora/LongCat-Video.py)|
|
||||
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
||||
|
||||
## Model Inference
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./model_training/full/Wan2.2-S2V-14B.sh)|[code](./model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./model_training/lora/Wan2.2-S2V-14B.sh)|[code](./model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
@@ -76,6 +76,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./model_inference/krea-realtime-video.py)|[code](./model_training/full/krea-realtime-video.sh)|[code](./model_training/validate_full/krea-realtime-video.py)|[code](./model_training/lora/krea-realtime-video.sh)|[code](./model_training/validate_lora/krea-realtime-video.py)|
|
||||
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./model_inference/LongCat-Video.py)|[code](./model_training/full/LongCat-Video.sh)|[code](./model_training/validate_full/LongCat-Video.py)|[code](./model_training/lora/LongCat-Video.sh)|[code](./model_training/validate_lora/LongCat-Video.py)|
|
||||
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
|
||||
35
examples/wanvideo/model_inference/LongCat-Video.py
Normal file
35
examples/wanvideo/model_inference/LongCat-Video.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from diffsynth import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.",
|
||||
negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
|
||||
seed=0, tiled=True, num_frames=93,
|
||||
cfg_scale=2, sigma_shift=1,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
# Video-continuation (The number of frames in `longcat_video` should be 4n+1.)
|
||||
longcat_video = video[-17:]
|
||||
video = pipe(
|
||||
prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.",
|
||||
negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
|
||||
seed=1, tiled=True, num_frames=93,
|
||||
cfg_scale=2, sigma_shift=1,
|
||||
longcat_video=longcat_video,
|
||||
)
|
||||
save_video(video, "video2.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
from typing import List
|
||||
|
||||
def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]:
|
||||
if len(video_frames) == 0:
|
||||
return []
|
||||
if mode == "first":
|
||||
return video_frames[:num]
|
||||
if mode == "evenly":
|
||||
import torch as _torch
|
||||
idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist()
|
||||
return [video_frames[i] for i in idx]
|
||||
if mode == "random":
|
||||
if len(video_frames) <= num:
|
||||
return video_frames
|
||||
import random as _random
|
||||
start = _random.randint(0, len(video_frames) - num)
|
||||
return video_frames[start:start+num]
|
||||
return video_frames
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="wanvap/*", local_dir="data/example_video_dataset")
|
||||
ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'
|
||||
target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'
|
||||
|
||||
|
||||
image = Image.open(target_image_path).convert("RGB")
|
||||
ref_video = VideoData(ref_video_path, height=480, width=832)
|
||||
ref_frames = select_frames(ref_video, num=49, mode="evenly")
|
||||
|
||||
vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery."
|
||||
prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
input_image=image,
|
||||
seed=42, tiled=True,
|
||||
height=480, width=832,
|
||||
num_frames=49,
|
||||
vap_video=ref_frames,
|
||||
vap_prompt=vap_prompt,
|
||||
negative_vap_prompt=negative_prompt,
|
||||
)
|
||||
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
25
examples/wanvideo/model_inference/krea-realtime-video.py
Normal file
25
examples/wanvideo/model_inference/krea-realtime-video.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="a cat sitting on a boat",
|
||||
num_inference_steps=6, num_frames=81,
|
||||
seed=0, tiled=True,
|
||||
cfg_scale=1,
|
||||
sigma_shift=20,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
12
examples/wanvideo/model_training/full/LongCat-Video.sh
Normal file
12
examples/wanvideo/model_training/full/LongCat-Video.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "meituan-longcat/LongCat-Video:dit/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/LongCat-Video_full" \
|
||||
--trainable_models "dit"
|
||||
@@ -0,0 +1,16 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_vap.csv \
|
||||
--data_file_keys "video,vap_video" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "ByteDance/Video-As-Prompt-Wan2.1-14B:transformer/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.vap." \
|
||||
--output_path "./models/train/Video-As-Prompt-Wan2.1-14B_full" \
|
||||
--trainable_models "vap" \
|
||||
--extra_inputs "vap_video,input_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
17
examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh
Normal file
17
examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset/wans2v \
|
||||
--dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
|
||||
--data_file_keys "video,input_audio,s2v_pose_video" \
|
||||
--height 448 \
|
||||
--width 832 \
|
||||
--num_frames 81 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
|
||||
--audio_processor_config "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--trainable_models "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-S2V-14B_full" \
|
||||
--extra_inputs "input_image,input_audio,s2v_pose_video" \
|
||||
--use_gradient_checkpointing_offload
|
||||
12
examples/wanvideo/model_training/full/krea-realtime-video.sh
Normal file
12
examples/wanvideo/model_training/full/krea-realtime-video.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "krea/krea-realtime-video:krea-realtime-video-14b.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/krea-realtime-video_full" \
|
||||
--trainable_models "dit"
|
||||
14
examples/wanvideo/model_training/lora/LongCat-Video.sh
Normal file
14
examples/wanvideo/model_training/lora/LongCat-Video.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "meituan-longcat/LongCat-Video:dit/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/LongCat-Video_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "adaLN_modulation.1,attn.qkv,attn.proj,cross_attn.q_linear,cross_attn.kv_linear,cross_attn.proj,ffn.w1,ffn.w2,ffn.w3" \
|
||||
--lora_rank 32
|
||||
@@ -0,0 +1,18 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_vap.csv \
|
||||
--data_file_keys "video,vap_video" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 10 \
|
||||
--model_id_with_origin_paths "ByteDance/Video-As-Prompt-Wan2.1-14B:transformer/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Video-As-Prompt-Wan2.1-14B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "vap_video,input_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
19
examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh
Normal file
19
examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset/wans2v \
|
||||
--dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
|
||||
--data_file_keys "video,input_audio,s2v_pose_video" \
|
||||
--height 448 \
|
||||
--width 832 \
|
||||
--num_frames 81 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
|
||||
--audio_processor_config "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-S2V-14B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image,input_audio,s2v_pose_video" \
|
||||
--use_gradient_checkpointing_offload
|
||||
14
examples/wanvideo/model_training/lora/krea-realtime-video.sh
Normal file
14
examples/wanvideo/model_training/lora/krea-realtime-video.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "krea/krea-realtime-video:krea-realtime-video-14b.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/krea-realtime-video_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32
|
||||
@@ -2,7 +2,7 @@ import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, ImageCropAndResize, ToAbsolutePath
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
class WanTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
model_paths=None, model_id_with_origin_paths=None, audio_processor_config=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None,
|
||||
use_gradient_checkpointing=True,
|
||||
@@ -22,7 +22,9 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
|
||||
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||
if audio_processor_config is not None:
|
||||
audio_processor_config = ModelConfig(model_id=audio_processor_config.split(":")[0], origin_file_pattern=audio_processor_config.split(":")[1])
|
||||
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, audio_processor_config=audio_processor_config)
|
||||
|
||||
# Training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
@@ -109,12 +111,14 @@ if __name__ == "__main__":
|
||||
time_division_remainder=1,
|
||||
),
|
||||
special_operator_map={
|
||||
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16))
|
||||
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
|
||||
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
|
||||
}
|
||||
)
|
||||
model = WanTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
audio_processor_config=args.audio_processor_config,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/LongCat-Video_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_LongCat-Video.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Video-As-Prompt-Wan2.1-14B_full/epoch-1.safetensors")
|
||||
pipe.vap.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'
|
||||
target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'
|
||||
|
||||
image = Image.open(target_image_path).convert("RGB")
|
||||
ref_video = VideoData(ref_video_path, height=480, width=832)
|
||||
ref_frames = [ref_video[i] for i in range(49)]
|
||||
|
||||
vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery."
|
||||
prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
input_image=image,
|
||||
seed=42, tiled=True,
|
||||
height=480, width=832,
|
||||
num_frames=49,
|
||||
vap_video=ref_frames,
|
||||
vap_prompt=vap_prompt,
|
||||
negative_vap_prompt=negative_prompt,
|
||||
)
|
||||
save_video(video, "video_Video-As-Prompt-Wan2.1-14B.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
import librosa
|
||||
from diffsynth import VideoData, save_video_with_audio, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
|
||||
)
|
||||
|
||||
state_dict = load_state_dict("models/train/Wan2.2-S2V-14B_full/epoch-0.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict, strict=False)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
|
||||
num_frames = 81 # 4n+1
|
||||
height = 448
|
||||
width = 832
|
||||
|
||||
prompt = "a person is singing"
|
||||
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
|
||||
# s2v audio input, recommend 16kHz sampling rate
|
||||
audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
|
||||
input_audio, sample_rate = librosa.load(audio_path, sr=16000)
|
||||
# S2V pose video input
|
||||
pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'
|
||||
pose_video = VideoData(pose_video_path, height=height, width=width)
|
||||
|
||||
# Speech-to-video with pose
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=0,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_sample_rate=sample_rate,
|
||||
input_audio=input_audio,
|
||||
s2v_pose_video=pose_video,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)
|
||||
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/krea-realtime-video_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="a cat sitting on a boat",
|
||||
num_inference_steps=6, num_frames=81,
|
||||
seed=0, tiled=True,
|
||||
cfg_scale=1,
|
||||
sigma_shift=20,
|
||||
)
|
||||
save_video(video, "output.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/LongCat-Video_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_LongCat-Video.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Video-As-Prompt-Wan2.1-14B_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'
|
||||
target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'
|
||||
|
||||
image = Image.open(target_image_path).convert("RGB")
|
||||
ref_video = VideoData(ref_video_path, height=480, width=832)
|
||||
ref_frames = [ref_video[i] for i in range(49)]
|
||||
|
||||
vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery."
|
||||
prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
input_image=image,
|
||||
seed=42, tiled=True,
|
||||
height=480, width=832,
|
||||
num_frames=49,
|
||||
vap_video=ref_frames,
|
||||
vap_prompt=vap_prompt,
|
||||
negative_vap_prompt=negative_prompt,
|
||||
)
|
||||
save_video(video, "video_Video-As-Prompt-Wan2.1-14B.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
import librosa
|
||||
from diffsynth import VideoData, save_video_with_audio
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
|
||||
)
|
||||
|
||||
pipe.load_lora(pipe.dit, "models/train/Wan2.2-S2V-14B_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
|
||||
num_frames = 81 # 4n+1
|
||||
height = 448
|
||||
width = 832
|
||||
|
||||
prompt = "a person is singing"
|
||||
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
|
||||
# s2v audio input, recommend 16kHz sampling rate
|
||||
audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
|
||||
input_audio, sample_rate = librosa.load(audio_path, sr=16000)
|
||||
# Pose video input
|
||||
pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'
|
||||
pose_video = VideoData(pose_video_path, height=height, width=width)
|
||||
|
||||
# Speech-to-video with pose
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=0,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_sample_rate=sample_rate,
|
||||
input_audio=input_audio,
|
||||
s2v_pose_video=pose_video,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)
|
||||
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
|
||||
pipe.load_lora(pipe.dit, "models/train/krea-realtime-video_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="a cat sitting on a boat",
|
||||
num_inference_steps=6, num_frames=81,
|
||||
seed=0, tiled=True,
|
||||
cfg_scale=1,
|
||||
sigma_shift=20,
|
||||
)
|
||||
save_video(video, "output.mp4", fps=15, quality=5)
|
||||
Reference in New Issue
Block a user