diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index a8efc23..8d447e8 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -44,8 +44,9 @@ class WanVideoPipeline(BasePipeline): self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None + self.vace2: VaceWanModel = None self.in_iteration_models = ("dit", "motion_controller", "vace") - self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), @@ -359,6 +360,10 @@ class WanVideoPipeline(BasePipeline): pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") - pipe.vace = model_manager.fetch_model("wan_video_vace") + vace = model_manager.fetch_model("wan_video_vace") + if isinstance(vace, list): + pipe.vace, pipe.vace2 = vace + else: + pipe.vace = vace pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder") # Size division factor @@ -481,6 +486,7 @@ class WanVideoPipeline(BasePipeline): if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: self.load_models_to_device(self.in_iteration_models_2) models["dit"] = self.dit2 + models["vace"] = self.vace2 # Timestep timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) @@ -534,11 +540,12 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit): def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): length = (num_frames - 1) // 4 + 1 if vace_reference_image is not None: - length += 1 + f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 + length += f shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, 
width // pipe.vae.upsampling_factor) noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) if vace_reference_image is not None: - noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) + noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) return {"noise": noise} @@ -849,11 +856,22 @@ class WanVideoUnit_VACE(PipelineUnit): if vace_reference_image is None: pass else: + if not isinstance(vace_reference_image,list): + vace_reference_image = [vace_reference_image] + + vace_reference_image = pipe.preprocess_video(vace_reference_image) + + bs, c, f, h, w = vace_reference_image.shape + new_vace_ref_images = [] + for j in range(f): + new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) + vace_reference_image = new_vace_ref_images + vace_reference_image = pipe.preprocess_video([vace_reference_image]) vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_reference_latents = torch.concat((*vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) - vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) return {"vace_context": vace_context, "vace_scale": vace_scale} diff --git a/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py b/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py new file mode 100644 index 0000000..a768192 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py @@ -0,0 +1,54 @@ +import 
torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video1_14b.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video2_14b.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = 
pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video3_14b.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh new file mode 100644 index 0000000..0ee97da --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh @@ -0,0 +1,40 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh new file mode 100644 index 0000000..93b38cf --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh @@ -0,0 +1,43 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." \ + --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py new file mode 100644 index 0000000..e566dba --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + 
ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors") +pipe.vace.load_state_dict(state_dict) +state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_low_noise_full/epoch-1.safetensors") +pipe.vace2.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(17)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=video, vace_reference_image=reference_image, num_frames=17, + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-VACE-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B-lora.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B-lora.py new file mode 100644 index 0000000..b6e6aff --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B-lora.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = 
WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.vace, "models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora/epoch-4.safetensors", alpha=1) +pipe.load_lora(pipe.vace2, "models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(17)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=video, vace_reference_image=reference_image, num_frames=17, + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5)