mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
support wan2.2-animate-14b
This commit is contained in:
16
examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh
Normal file
16
examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh
Normal file
@@ -0,0 +1,16 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_animate.csv \
|
||||
--data_file_keys "video,animate_pose_video,animate_face_video" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 81 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.animate_adapter." \
|
||||
--output_path "./models/train/Wan2.2-Animate-14B_full" \
|
||||
--trainable_models "animate_adapter" \
|
||||
--extra_inputs "input_image,animate_pose_video,animate_face_video" \
|
||||
--use_gradient_checkpointing_offload
|
||||
20
examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh
Normal file
20
examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh
Normal file
@@ -0,0 +1,20 @@
|
||||
# 1*80G GPU cannot train Wan2.2-Animate-14B LoRA
|
||||
# We tested on 8*80G GPUs
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_animate.csv \
|
||||
--data_file_keys "video,animate_pose_video,animate_face_video" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 81 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Animate-14B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image,animate_pose_video,animate_face_video" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -2,7 +2,7 @@ import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, ImageCropAndResize, ToAbsolutePath
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@@ -108,6 +108,9 @@ if __name__ == "__main__":
|
||||
time_division_factor=4,
|
||||
time_division_remainder=1,
|
||||
),
|
||||
special_operator_map={
|
||||
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16))
|
||||
}
|
||||
)
|
||||
model = WanTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Wan2.2-Animate-14B_full/epoch-1.safetensors")
|
||||
pipe.animate_adapter.load_state_dict(state_dict, strict=False)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0]
|
||||
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4]
|
||||
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4]
|
||||
video = pipe(
|
||||
prompt="视频中的人在做动作",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
animate_pose_video=animate_pose_video,
|
||||
animate_face_video=animate_face_video,
|
||||
num_frames=81, height=480, width=832,
|
||||
num_inference_steps=20, cfg_scale=1,
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Wan2.2-Animate-14B_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0]
|
||||
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4]
|
||||
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4]
|
||||
video = pipe(
|
||||
prompt="视频中的人在做动作",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
animate_pose_video=animate_pose_video,
|
||||
animate_face_video=animate_face_video,
|
||||
num_frames=81, height=480, width=832,
|
||||
num_inference_steps=20, cfg_scale=1,
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
|
||||
Reference in New Issue
Block a user