From 8ea45b0daa69852d1d96f9563a454f200bed0033 Mon Sep 17 00:00:00 2001
From: lzws <2538048363@qq.com>
Date: Tue, 21 Oct 2025 10:34:48 +0800
Subject: [PATCH 1/3] update wans2v training

---
 README.md | 2 +-
 README_zh.md | 2 +-
 diffsynth/trainers/unified_dataset.py | 7 +++
 diffsynth/trainers/utils.py | 1 +
 download.py | 3 ++
 examples/wanvideo/README.md | 2 +-
 examples/wanvideo/README_zh.md | 2 +-
 .../model_training/full/Wan2.2-S2V-14B.sh | 17 ++++++
 .../model_training/lora/Wan2.2-S2V-14B.sh | 19 +++++++
 examples/wanvideo/model_training/train.py | 18 +++++--
 .../validate_full/Wan2.2-S2V-14B.py | 53 +++++++++++++++++++
 .../validate_lora/Wan2.2-S2V-14B.py | 51 ++++++++++++++++++
 12 files changed, 169 insertions(+), 8 deletions(-)
 create mode 100644 download.py
 create mode 100644 examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh
 create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh
 create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py
 create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py

diff --git a/README.md b/README.md
index ceac65d..b06a796 100644
--- a/README.md
+++ b/README.md
@@ -208,7 +208,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
 | Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
 |-|-|-|-|-|-|-|
 |[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
-|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
+|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
 |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
 |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
 |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
diff --git a/README_zh.md b/README_zh.md
index 2639c6f..8642ce0 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -208,7 +208,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
 |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
 |-|-|-|-|-|-|-|
 |[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
-|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
+|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
 |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
 |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
 |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
diff --git a/diffsynth/trainers/unified_dataset.py b/diffsynth/trainers/unified_dataset.py
index 0083f44..feea784 100644
--- a/diffsynth/trainers/unified_dataset.py
+++ b/diffsynth/trainers/unified_dataset.py
@@ -225,6 +225,13 @@ class ToAbsolutePath(DataProcessingOperator):
     def __call__(self, data):
         return os.path.join(self.base_path, data)
 
+class LoadAudio(DataProcessingOperator):
+    def __init__(self, sr=16000):
+        self.sr = sr
+    def __call__(self, data: str):
+        import librosa
+        input_audio, sample_rate = librosa.load(data, sr=self.sr)
+        return {'input_audio':input_audio, 'sample_rate':sample_rate}
 
 
 class UnifiedDataset(torch.utils.data.Dataset):
diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py
index 3262d15..da76509 100644
--- a/diffsynth/trainers/utils.py
+++ b/diffsynth/trainers/utils.py
@@ -603,6 +603,7 @@ def wan_parser():
     parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
     parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
     parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
+    parser.add_argument("--audio_processor_config", type=str, default=None, help="Model ID with the origin path of the audio processor config, e.g., Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/")
     parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
     parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
     parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
diff --git a/download.py b/download.py
new file mode 100644
index 0000000..be68de8
--- /dev/null
+++ b/download.py
@@ -0,0 +1,3 @@
+#Model Download
+from modelscope import snapshot_download
+model_dir = snapshot_download('Wan-AI/Wan2.2-S2V-14B',local_dir='./models/Wan-AI/Wan2.2-S2V-14B')
\ No newline at end of file
diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index a45c287..efb4274 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -49,7 +49,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
 | Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
 |-|-|-|-|-|-|-|
 |[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
-|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
+|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./model_training/full/Wan2.2-S2V-14B.sh)|[code](./model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./model_training/lora/Wan2.2-S2V-14B.sh)|[code](./model_training/validate_lora/Wan2.2-S2V-14B.py)|
 |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
 |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
 |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md
index f3d8eae..2d8cbff 100644
--- a/examples/wanvideo/README_zh.md
+++ b/examples/wanvideo/README_zh.md
@@ -49,7 +49,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
 |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
 |-|-|-|-|-|-|-|
 |[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
-|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
+|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./model_training/full/Wan2.2-S2V-14B.sh)|[code](./model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./model_training/lora/Wan2.2-S2V-14B.sh)|[code](./model_training/validate_lora/Wan2.2-S2V-14B.py)|
 |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
 |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
 |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
diff --git a/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh b/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh
new file mode 100644
index 0000000..3a9a871
--- /dev/null
+++ b/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh
@@ -0,0 +1,17 @@
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/example_video_dataset/wans2v \
+  --dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
+  --data_file_keys "video,input_audio,s2v_pose_video" \
+  --height 448 \
+  --width 832 \
+  --num_frames 81 \
+  --dataset_repeat 100 \
+  --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
+  --audio_processor_config "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
+  --learning_rate 1e-5 \
+  --num_epochs 1 \
+  --trainable_models "dit" \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Wan2.2-S2V-14B_full" \
+  --extra_inputs "input_image,input_audio,s2v_pose_video" \
+  --use_gradient_checkpointing_offload
\ No newline at end of file
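Note: both this full-training script and the LoRA script that follows read --dataset_metadata_path from data/example_video_dataset/wans2v/metadata.csv, which is not included in this patch series. A minimal sketch of a compatible metadata file is shown below; it assumes the usual UnifiedDataset layout of a "prompt" column plus the columns named in --data_file_keys, and every file name in it is a placeholder.

    # Hypothetical helper that writes a one-row metadata.csv for the wans2v example dataset.
    # The column set ("video", "prompt", "input_audio", "s2v_pose_video") is an assumption.
    import csv, os

    base = "data/example_video_dataset/wans2v"
    os.makedirs(base, exist_ok=True)
    with open(os.path.join(base, "metadata.csv"), "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["video", "prompt", "input_audio", "s2v_pose_video"])
        writer.writeheader()
        writer.writerow({
            "video": "video1.mp4",            # training clip, relative to --dataset_base_path
            "prompt": "a person is singing",  # caption for the clip
            "input_audio": "sing.MP3",        # driving audio, loaded by LoadAudio at 16 kHz
            "s2v_pose_video": "pose.mp4",     # pose-control video for s2v_pose_video
        })

Paths listed in the file are relative to --dataset_base_path and are resolved by the ToAbsolutePath operator used in train.py.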
diff --git a/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh
new file mode 100644
index 0000000..84723ec
--- /dev/null
+++ b/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh
@@ -0,0 +1,19 @@
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/example_video_dataset/wans2v \
+  --dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
+  --data_file_keys "video,input_audio,s2v_pose_video" \
+  --height 448 \
+  --width 832 \
+  --num_frames 81 \
+  --dataset_repeat 100 \
+  --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
+  --audio_processor_config "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Wan2.2-S2V-14B_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+  --lora_rank 32 \
+  --extra_inputs "input_image,input_audio,s2v_pose_video" \
+  --use_gradient_checkpointing_offload
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index f31ad69..010f581 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -2,7 +2,7 @@ import torch, os, json
 from diffsynth import load_state_dict
 from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
 from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
-from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, ImageCropAndResize, ToAbsolutePath
+from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -10,7 +10,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 class WanTrainingModule(DiffusionTrainingModule):
     def __init__(
         self,
-        model_paths=None, model_id_with_origin_paths=None,
+        model_paths=None, model_id_with_origin_paths=None, audio_processor_config=None,
         trainable_models=None,
         lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None,
         use_gradient_checkpointing=True,
@@ -22,7 +22,9 @@ class WanTrainingModule(DiffusionTrainingModule):
         super().__init__()
         # Load models
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
-        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
+        if audio_processor_config is not None:
+            audio_processor_config = ModelConfig(model_id=audio_processor_config.split(":")[0], origin_file_pattern=audio_processor_config.split(":")[1])
+        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, audio_processor_config=audio_processor_config)
 
         # Training mode
         self.switch_pipe_to_training_mode(
@@ -52,6 +54,9 @@ class WanTrainingModule(DiffusionTrainingModule):
             "height": data["video"][0].size[1],
             "width": data["video"][0].size[0],
             "num_frames": len(data["video"]),
+            "audio_embeds":None,
+            "s2v_pose_latents":None,
+            "motion_video":None,
             # Please do not modify the following parameters
             # unless you clearly know what this will cause.
             "cfg_scale": 1,
@@ -73,6 +78,9 @@ class WanTrainingModule(DiffusionTrainingModule):
                 inputs_shared["end_image"] = data["video"][-1]
             elif extra_input == "reference_image" or extra_input == "vace_reference_image":
                 inputs_shared[extra_input] = data[extra_input][0]
+            elif extra_input == "input_audio":
+                inputs_shared['input_audio'] = data['input_audio']['input_audio']
+                inputs_shared['sample_rate'] = data['input_audio']['sample_rate']
             else:
                 inputs_shared[extra_input] = data[extra_input]
 
@@ -109,12 +117,14 @@ if __name__ == "__main__":
             time_division_remainder=1,
         ),
         special_operator_map={
-            "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16))
+            "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
+            'input_audio': ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
         }
     )
     model = WanTrainingModule(
         model_paths=args.model_paths,
         model_id_with_origin_paths=args.model_id_with_origin_paths,
+        audio_processor_config=args.audio_processor_config,
         trainable_models=args.trainable_models,
         lora_base_model=args.lora_base_model,
         lora_target_modules=args.lora_target_modules,
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py
new file mode 100644
index 0000000..b69a575
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py
@@ -0,0 +1,53 @@
+import torch
+from PIL import Image
+import librosa
+from diffsynth import VideoData, save_video_with_audio, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
+    ],
+    audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
+)
+
+state_dict = load_state_dict("models/train/Wan2.2-S2V-14B_full/epoch-0.safetensors")
+pipe.dit.load_state_dict(state_dict, strict=False)
+pipe.enable_vram_management()
+
+
+num_frames = 81 # 4n+1
+height = 448
+width = 832
+
+prompt = "a person is singing"
+negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
+# S2V audio input; a 16 kHz sampling rate is recommended
+audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
+input_audio, sample_rate = librosa.load(audio_path, sr=16000)
+# S2V pose video input
+pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'
+pose_video = VideoData(pose_video_path, height=height, width=width)
+
+# Speech-to-video with pose
+video = pipe(
+    prompt=prompt,
+    input_image=input_image,
+    negative_prompt=negative_prompt,
+    seed=0,
+    num_frames=num_frames,
+    height=height,
+    width=width,
+    audio_sample_rate=sample_rate,
+    input_audio=input_audio,
+    s2v_pose_video=pose_video,
+    num_inference_steps=40,
+)
+save_video_with_audio(video[1:], "video_pose_with_audio_full.mp4", audio_path, fps=16, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py
new file mode 100644
index 0000000..f8245d1
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py
@@ -0,0 +1,51 @@
+import torch
+from PIL import Image
+import librosa
+from diffsynth import VideoData, save_video_with_audio
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda:0",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
+    ],
+    audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
+)
+
+pipe.load_lora(pipe.dit, "models/train/Wan2.2-S2V-14B_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+
+num_frames = 81 # 4n+1
+height = 448
+width = 832
+
+prompt = "a person is singing"
+negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
+# S2V audio input; a 16 kHz sampling rate is recommended
+audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
+input_audio, sample_rate = librosa.load(audio_path, sr=16000)
+# Pose video input
+pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'
+pose_video = VideoData(pose_video_path, height=height, width=width)
+
+
+# Speech-to-video with pose
+video = pipe(
+    prompt=prompt,
+    input_image=input_image,
+    negative_prompt=negative_prompt,
+    seed=0,
+    num_frames=num_frames,
+    height=height,
+    width=width,
+    audio_sample_rate=sample_rate,
+    input_audio=input_audio,
+    s2v_pose_video=pose_video,
+    num_inference_steps=40,
+)
+save_video_with_audio(video[1:], "video_pose_with_audio_lora.mp4", audio_path, fps=16, quality=5)

From b168d7aa8b1d9e96697da00efedf2bf842f443a0 Mon Sep 17 00:00:00 2001
From: lzws <2538048363@qq.com>
Date: Tue, 21 Oct 2025 10:39:30 +0800
Subject: [PATCH 2/3] update wans2v training

---
 download.py | 3 ---
 1 file changed, 3 deletions(-)
 delete mode 100644 download.py

diff --git a/download.py b/download.py
deleted file mode 100644
index be68de8..0000000
--- a/download.py
+++ /dev/null
@@ -1,3 +0,0 @@
-#Model Download
-from modelscope import snapshot_download
-model_dir = snapshot_download('Wan-AI/Wan2.2-S2V-14B',local_dir='./models/Wan-AI/Wan2.2-S2V-14B')
\ No newline at end of file
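Note: with download.py removed in PATCH 2/3, the Wan2.2-S2V-14B checkpoints can still be fetched with a one-off call before running the training scripts. A minimal sketch, taken directly from the deleted helper; the local_dir simply mirrors the ./models layout that the deleted script used:

    # One-off checkpoint download, equivalent to the removed download.py.
    from modelscope import snapshot_download

    model_dir = snapshot_download(
        "Wan-AI/Wan2.2-S2V-14B",
        local_dir="./models/Wan-AI/Wan2.2-S2V-14B",
    )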
From 30292d94111dcb14ffcae1d923f2986862c35fa9 Mon Sep 17 00:00:00 2001
From: lzws <2538048363@qq.com>
Date: Tue, 21 Oct 2025 19:59:44 +0800
Subject: [PATCH 3/3] update wan2.2-S2V training

---
 diffsynth/pipelines/wan_video_new.py | 4 ++--
 diffsynth/trainers/unified_dataset.py | 2 +-
 examples/wanvideo/model_training/train.py | 8 +-------
 .../model_training/validate_full/Wan2.2-S2V-14B.py | 2 +-
 .../model_training/validate_lora/Wan2.2-S2V-14B.py | 4 ++--
 5 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
index c9342ea..141660f 100644
--- a/diffsynth/pipelines/wan_video_new.py
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -1028,8 +1028,8 @@ class WanVideoUnit_S2V(PipelineUnit):
         if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None:
             return inputs_shared, inputs_posi, inputs_nega
         num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
-        input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate")
-        s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video")
+        input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000)
+        s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None)
         audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)
         inputs_posi.update(audio_input_positive)
 
diff --git a/diffsynth/trainers/unified_dataset.py b/diffsynth/trainers/unified_dataset.py
index feea784..c98a160 100644
--- a/diffsynth/trainers/unified_dataset.py
+++ b/diffsynth/trainers/unified_dataset.py
@@ -231,7 +231,7 @@ class LoadAudio(DataProcessingOperator):
     def __call__(self, data: str):
         import librosa
         input_audio, sample_rate = librosa.load(data, sr=self.sr)
-        return {'input_audio':input_audio, 'sample_rate':sample_rate}
+        return input_audio
 
 
 class UnifiedDataset(torch.utils.data.Dataset):
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index 010f581..643c8e2 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -54,9 +54,6 @@ class WanTrainingModule(DiffusionTrainingModule):
             "height": data["video"][0].size[1],
             "width": data["video"][0].size[0],
             "num_frames": len(data["video"]),
-            "audio_embeds":None,
-            "s2v_pose_latents":None,
-            "motion_video":None,
             # Please do not modify the following parameters
             # unless you clearly know what this will cause.
             "cfg_scale": 1,
@@ -78,9 +75,6 @@ class WanTrainingModule(DiffusionTrainingModule):
                 inputs_shared["end_image"] = data["video"][-1]
             elif extra_input == "reference_image" or extra_input == "vace_reference_image":
                 inputs_shared[extra_input] = data[extra_input][0]
-            elif extra_input == "input_audio":
-                inputs_shared['input_audio'] = data['input_audio']['input_audio']
-                inputs_shared['sample_rate'] = data['input_audio']['sample_rate']
             else:
                 inputs_shared[extra_input] = data[extra_input]
 
@@ -118,7 +112,7 @@ if __name__ == "__main__":
         ),
         special_operator_map={
             "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
-            'input_audio': ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
+            "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
         }
     )
     model = WanTrainingModule(
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py
index b69a575..2df08d2 100644
--- a/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py
+++ b/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py
@@ -50,4 +50,4 @@ video = pipe(
     s2v_pose_video=pose_video,
     num_inference_steps=40,
 )
-save_video_with_audio(video[1:], "video_pose_with_audio_full.mp4", audio_path, fps=16, quality=5)
+save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py
index f8245d1..a6166b9 100644
--- a/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py
@@ -6,7 +6,7 @@ from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
 pipe = WanVideoPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
-    device="cuda:0",
+    device="cuda",
     model_configs=[
         ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
         ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
         ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
@@ -48,4 +48,4 @@ video = pipe(
     s2v_pose_video=pose_video,
     num_inference_steps=40,
 )
-save_video_with_audio(video[1:], "video_pose_with_audio_lora.mp4", audio_path, fps=16, quality=5)
+save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)