Merge pull request #933 from lzws/main

update wan2.2-VACE-Fun-A14B
Zhongjie Duan, committed by GitHub on 2025-09-24 09:56:37 +08:00
6 changed files with 224 additions and 5 deletions


@@ -44,8 +44,9 @@ class WanVideoPipeline(BasePipeline):
         self.vae: WanVideoVAE = None
         self.motion_controller: WanMotionControllerModel = None
         self.vace: VaceWanModel = None
+        self.vace2: VaceWanModel = None
         self.in_iteration_models = ("dit", "motion_controller", "vace")
-        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
+        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2")
         self.unit_runner = PipelineUnitRunner()
         self.units = [
             WanVideoUnit_ShapeChecker(),
@@ -359,6 +360,10 @@ class WanVideoPipeline(BasePipeline):
         pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
         pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
-        pipe.vace = model_manager.fetch_model("wan_video_vace")
+        vace = model_manager.fetch_model("wan_video_vace")
+        if isinstance(vace, list):
+            pipe.vace, pipe.vace2 = vace
+        else:
+            pipe.vace = vace
         pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
# Size division factor
@@ -481,6 +486,7 @@ class WanVideoPipeline(BasePipeline):
             if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
                 self.load_models_to_device(self.in_iteration_models_2)
                 models["dit"] = self.dit2
+                models["vace"] = self.vace2
             # Timestep
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
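This one-line addition fixes an expert-switching bug: crossing the DiT boundary previously swapped only models["dit"], so the low-noise DiT ran with the high-noise VACE weights. A minimal standalone sketch of the intended pairing (names simplified; not the pipeline's real API):

# Hedged sketch: both members of each (dit, vace) expert pair switch together.
def pick_models(timestep, boundary, high_pair, low_pair):
    # high_pair/low_pair are (dit, vace) tuples; boundary is in timesteps [0, 1000]
    return low_pair if timestep < boundary else high_pair

dit, vace = pick_models(450, 900, ("dit", "vace"), ("dit2", "vace2"))
assert (dit, vace) == ("dit2", "vace2")  # a low-noise step uses the low-noise pair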
@@ -534,11 +540,12 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
     def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
         length = (num_frames - 1) // 4 + 1
         if vace_reference_image is not None:
-            length += 1
+            f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1
+            length += f
         shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
         noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
         if vace_reference_image is not None:
-            noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
+            noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
         return {"noise": noise}
@@ -849,11 +856,22 @@ class WanVideoUnit_VACE(PipelineUnit):
         if vace_reference_image is None:
             pass
         else:
-            vace_reference_image = pipe.preprocess_video([vace_reference_image])
+            if not isinstance(vace_reference_image, list):
+                vace_reference_image = [vace_reference_image]
+            vace_reference_image = pipe.preprocess_video(vace_reference_image)
+            bs, c, f, h, w = vace_reference_image.shape
+            new_vace_ref_images = []
+            for j in range(f):
+                new_vace_ref_images.append(vace_reference_image[0, :, j:j+1])
+            vace_reference_image = new_vace_ref_images
             vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
-            vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
+            vace_reference_latents = torch.concat((*vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
             vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
-            vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
+            vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)
             vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
         return {"vace_context": vace_context, "vace_scale": vace_scale}


@@ -0,0 +1,54 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
)
# Depth video -> Video
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
# Prompt: "Two cute orange cats put on boxing gloves and fight on a boxing ring."
# The negative prompt is the standard Wan quality blacklist (overexposure, blur, subtitles,
# gray tones, worst/low quality, deformed hands/faces/limbs, static frames, cluttered
# backgrounds, three legs, crowds, walking backwards).
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_video=control_video,
    seed=1, tiled=True
)
save_video(video, "video1_14b.mp4", fps=15, quality=5)
# Reference image -> Video
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
    seed=1, tiled=True
)
save_video(video, "video2_14b.mp4", fps=15, quality=5)
# Depth video + Reference image -> Video
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_video=control_video,
    vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
    seed=1, tiled=True
)
save_video(video, "video3_14b.mp4", fps=15, quality=5)


@@ -0,0 +1,40 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full" \
  --trainable_models "vace" \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full" \
  --trainable_models "vace" \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358
# boundary corresponds to timesteps [0, 900]
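The 0.358 boundary is not arbitrary: assuming Wan's flow-matching sigma shift of 5.0, and that the boundary is a fraction of the unshifted sigma range measured from the noisy end, it reproduces the timestep-900 switch point noted in the comments. A sanity-check sketch under those assumptions:

# Hedged arithmetic check (assumes sigma_shift = 5.0 and that boundary x
# covers unshifted sigmas in [1 - x, 1], i.e. the high-noise end).
shift = 5.0
t_switch = 0.9                                       # timestep 900 out of 1000
sigma = t_switch / (shift - (shift - 1) * t_switch)  # invert s' = shift*s / (1 + (shift-1)*s)
print(round(sigma, 3), round(1 - sigma, 3))          # 0.643 0.357 ~= 0.358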


@@ -0,0 +1,43 @@
accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora" \
  --lora_base_model "vace" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]

accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora" \
  --lora_base_model "vace" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358
# boundary corresponds to timesteps [0, 900]


@@ -0,0 +1,33 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors")
pipe.vace.load_state_dict(state_dict)
state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_low_noise_full/epoch-1.safetensors")
pipe.vace2.load_state_dict(state_dict)
pipe.enable_vram_management()
video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
video = [video[i] for i in range(17)]
reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
    prompt="from sunset to night, a small town, light, house, river",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_video=video, vace_reference_image=reference_image, num_frames=17,
    seed=1, tiled=True
)
save_video(video, "video_Wan2.2-VACE-A14B.mp4", fps=15, quality=5)


@@ -0,0 +1,31 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.load_lora(pipe.vace, "models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
pipe.load_lora(pipe.vace2, "models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
video = [video[i] for i in range(17)]
reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
    prompt="from sunset to night, a small town, light, house, river",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_video=video, vace_reference_image=reference_image, num_frames=17,
    seed=1, tiled=True
)
save_video(video, "video_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5)