diff --git a/.gitignore b/.gitignore
index 391b448..6fd0d8e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 /data
 /models
 /scripts
+/diffusers
 *.pkl
 *.safetensors
 *.pth
diff --git a/examples/flux/model_inference/FLEX.2-preview.py b/examples/flux/model_inference/FLEX.2-preview.py
index 2689679..efc8e91 100644
--- a/examples/flux/model_inference/FLEX.2-preview.py
+++ b/examples/flux/model_inference/FLEX.2-preview.py
@@ -21,12 +21,12 @@ image = pipe(
     num_inference_steps=50, embedded_guidance=3.5,
     seed=0
 )
-image.save(f"image_1.jpg")
+image.save("image_1.jpg")
 
 mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
 mask[200:400, 400:700] = 255
 mask = Image.fromarray(mask)
-mask.save(f"image_mask.jpg")
+mask.save("image_mask.jpg")
 
 inpaint_image = image
@@ -36,7 +36,7 @@ image = pipe(
     flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
     seed=4
 )
-image.save(f"image_2_new.jpg")
+image.save("image_2.jpg")
 
 control_image = Annotator("canny")(image)
 control_image.save("image_control.jpg")
@@ -47,4 +47,4 @@ image = pipe(
     flex_control_image=control_image,
     seed=4
 )
-image.save(f"image_3_new.jpg")
+image.save("image_3.jpg")
diff --git a/examples/flux/model_inference_low_vram/FLEX.2-preview.py b/examples/flux/model_inference_low_vram/FLEX.2-preview.py
index d3071f5..a4454e8 100644
--- a/examples/flux/model_inference_low_vram/FLEX.2-preview.py
+++ b/examples/flux/model_inference_low_vram/FLEX.2-preview.py
@@ -32,12 +32,12 @@ image = pipe(
     num_inference_steps=50, embedded_guidance=3.5,
     seed=0
 )
-image.save(f"image_1.jpg")
+image.save("image_1.jpg")
 
 mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
 mask[200:400, 400:700] = 255
 mask = Image.fromarray(mask)
-mask.save(f"image_mask.jpg")
+mask.save("image_mask.jpg")
 
 inpaint_image = image
@@ -47,7 +47,7 @@ image = pipe(
     flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
     seed=4
 )
-image.save(f"image_2_new.jpg")
+image.save("image_2.jpg")
 
 control_image = Annotator("canny")(image)
 control_image.save("image_control.jpg")
@@ -58,4 +58,4 @@ image = pipe(
     flex_control_image=control_image,
     seed=4
 )
-image.save(f"image_3_new.jpg")
+image.save("image_3.jpg")
diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py
index 6491515..5c1d206 100644
--- a/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py
+++ b/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py
@@ -14,12 +14,11 @@ pipe = FluxImagePipeline.from_pretrained(
         ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
     ],
 )
-pipe.enable_lora_magic()
 
 state_dict = load_state_dict("models/train/FLUX.1-dev-LoRA-Encoder_full/epoch-0.safetensors")
 pipe.lora_encoder.load_state_dict(state_dict)
 
 lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
-pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
+pipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA.
 image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
 image.save("image_FLUX.1-dev-LoRA-Encoder_full.jpg")
diff --git a/examples/flux/model_training/validate_full/Step1X-Edit.py b/examples/flux/model_training/validate_full/Step1X-Edit.py
index 054e7fb..feaac7a 100644
--- a/examples/flux/model_training/validate_full/Step1X-Edit.py
+++ b/examples/flux/model_training/validate_full/Step1X-Edit.py
@@ -8,7 +8,7 @@ pipe = FluxImagePipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
-        ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
+        ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
         ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
         ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
     ],
diff --git a/examples/flux/model_training/validate_lora/FLEX.2-preview.py b/examples/flux/model_training/validate_lora/FLEX.2-preview.py
index 6a6a60d..a905918 100644
--- a/examples/flux/model_training/validate_lora/FLEX.2-preview.py
+++ b/examples/flux/model_training/validate_lora/FLEX.2-preview.py
@@ -6,7 +6,7 @@ pipe = FluxImagePipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
-        ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
+        ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py
index d6fc920..f0f5941 100644
--- a/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py
+++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py
@@ -2,7 +2,7 @@ from PIL import Image
 import torch
 from modelscope import dataset_snapshot_download, snapshot_download
 from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
-from diffsynth.controlnets.processors import Annotator
+from diffsynth.utils.controlnet import Annotator
 
 allow_file_pattern = ["sk_model.pth", "sk_model2.pth", "dpt_hybrid-midas-501f0c75.pt", "ControlNetHED.pth", "body_pose_model.pth", "hand_pose_model.pth", "facenet.pth", "scannet.pt"]
 snapshot_download("lllyasviel/Annotators", local_dir="models/Annotators", allow_file_pattern=allow_file_pattern)
diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh
index f563fd1..e369223 100644
--- a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh
+++ b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh
@@ -3,9 +3,9 @@ accelerate launch examples/qwen_image/model_training/train.py \
   --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
   --data_file_keys "image,blockwise_controlnet_image" \
   --max_pixels 1048576 \
-  --dataset_repeat 50 \
+  --dataset_repeat 400 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \
-  --learning_rate 1e-4 \
+  --learning_rate 1e-3 \
   --num_epochs 2 \
   --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
   --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \
diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh
index 2bd2926..93313ec 100644
--- a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh
+++ b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh
@@ -3,9 +3,9 @@ accelerate launch examples/qwen_image/model_training/train.py \
   --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
   --data_file_keys "image,blockwise_controlnet_image" \
   --max_pixels 1048576 \
-  --dataset_repeat 50 \
+  --dataset_repeat 400 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \
-  --learning_rate 1e-4 \
+  --learning_rate 1e-3 \
   --num_epochs 2 \
   --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
   --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \
diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh
index b87552b..99b25ad 100644
--- a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh
+++ b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh
@@ -3,9 +3,9 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
   --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
   --data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
   --max_pixels 1048576 \
-  --dataset_repeat 50 \
+  --dataset_repeat 400 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \
-  --learning_rate 1e-4 \
+  --learning_rate 1e-3 \
   --num_epochs 2 \
   --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
   --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \
diff --git a/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py b/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py
index 089433d..10566fa 100644
--- a/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py
+++ b/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py
@@ -1,5 +1,6 @@
 # Without VRAM Management, 80G VRAM is not enough to run this example.
 # We recommend to use `examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py`.
+# CPU Offload is enabled in this example.
 import torch
 from PIL import Image
 from diffsynth.utils.data import save_video, VideoData
@@ -7,16 +8,27 @@ from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
 from modelscope import dataset_snapshot_download
 
 
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
 pipe = WanVideoPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth"),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
     ],
     tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
 )
diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py
index 6397044..180482c 100644
--- a/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py
+++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py
@@ -53,6 +53,7 @@ save_video(video, "video_1_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
 # Replace
 snapshot_download("Wan-AI/Wan2.2-Animate-14B", allow_file_pattern="relighting_lora.ckpt", local_dir="models/Wan-AI/Wan2.2-Animate-14B")
 lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", torch_dtype=torch.bfloat16, device="cuda")["state_dict"]
+lora_state_dict = {i: lora_state_dict[i].to(torch.bfloat16) for i in lora_state_dict}
 pipe.load_lora(pipe.dit, state_dict=lora_state_dict)
 input_image = Image.open("data/examples/wan/animate/replace_input_image.png")
 animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4]
diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py
index 0933c31..3474b01 100644
--- a/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py
+++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py
@@ -1,5 +1,3 @@
-# Without VRAM Management, 80G VRAM is not enough to run this example.
-# We recommend to use `examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py`.
 import torch
 from PIL import Image
 from diffsynth.utils.data import save_video, VideoData
diff --git a/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh
index c9f6f64..baf98a9 100644
--- a/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh
+++ b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh
@@ -10,4 +10,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --remove_prefix_in_ckpt "pipe.dit." \
   --output_path "./models/train/Wan2.1-FLF2V-14B-720P_full" \
   --trainable_models "dit" \
-  --extra_inputs "input_image,end_image"
\ No newline at end of file
+  --extra_inputs "input_image,end_image" \
+  --initialize_model_on_cpu
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh
index 6d257ff..492898b 100644
--- a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh
+++ b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh
@@ -10,4 +10,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --remove_prefix_in_ckpt "pipe.dit." \
   --output_path "./models/train/Wan2.1-I2V-14B-480P_full" \
   --trainable_models "dit" \
-  --extra_inputs "input_image"
\ No newline at end of file
+  --extra_inputs "input_image" \
+  --initialize_model_on_cpu
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh
index bbb2870..1d91359 100644
--- a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh
+++ b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh
@@ -12,4 +12,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --output_path "./models/train/Wan2.1-I2V-14B-720P_full" \
   --trainable_models "dit" \
   --extra_inputs "input_image" \
-  --use_gradient_checkpointing_offload
\ No newline at end of file
+  --use_gradient_checkpointing_offload \
+  --initialize_model_on_cpu
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh b/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh
index fb4d18c..10c4a5a 100644
--- a/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh
+++ b/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh
@@ -1,6 +1,6 @@
 accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
-  --dataset_base_path data/example_video_dataset/wans2v \
-  --dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
+  --dataset_base_path data/example_video_dataset \
+  --dataset_metadata_path data/example_video_dataset/metadata_s2v.csv \
   --data_file_keys "video,input_audio,s2v_pose_video" \
   --height 448 \
   --width 832 \
diff --git a/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh
index 0ee97da..ecfef32 100644
--- a/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh
+++ b/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh
@@ -15,7 +15,8 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --extra_inputs "vace_video,vace_reference_image" \
   --use_gradient_checkpointing_offload \
   --max_timestep_boundary 0.358 \
-  --min_timestep_boundary 0
+  --min_timestep_boundary 0 \
+  --initialize_model_on_cpu
 # boundary corresponds to timesteps [900, 1000]
@@ -36,5 +37,6 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --extra_inputs "vace_video,vace_reference_image" \
   --use_gradient_checkpointing_offload \
   --max_timestep_boundary 1 \
-  --min_timestep_boundary 0.358
+  --min_timestep_boundary 0.358 \
+  --initialize_model_on_cpu
 # boundary corresponds to timesteps [0, 900]
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh
index ec987a8..52b72bd 100644
--- a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh
+++ b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh
@@ -1,9 +1,10 @@
-accelerate launch examples/wanvideo/model_training/train.py \
+# 1*80G GPU cannot train Wan2.1-I2V-14B-720P LoRA
+# We tested on 8*80G GPUs
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
   --dataset_base_path data/example_video_dataset \
   --dataset_metadata_path data/example_video_dataset/metadata.csv \
   --height 720 \
   --width 1280 \
-  --num_frames 49 \
   --dataset_repeat 100 \
   --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
   --learning_rate 1e-4 \
@@ -14,4 +15,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
   --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
   --lora_rank 32 \
   --extra_inputs "input_image" \
-  --use_gradient_checkpointing_offload
\ No newline at end of file
+  --use_gradient_checkpointing_offload \
+  --initialize_model_on_cpu
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh
index 3865965..510796b 100644
--- a/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh
+++ b/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh
@@ -1,6 +1,6 @@
 accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
-  --dataset_base_path data/example_video_dataset/wans2v \
-  --dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
+  --dataset_base_path data/example_video_dataset \
+  --dataset_metadata_path data/example_video_dataset/metadata_s2v.csv \
   --data_file_keys "video,input_audio,s2v_pose_video" \
   --height 448 \
   --width 832 \
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index 85cf6fc..4973438 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -116,6 +116,7 @@ def wan_parser():
     parser.add_argument("--audio_processor_path", type=str, default=None, help="Path to the audio processor. If provided, the processor will be used for Wan2.2-S2V model.")
     parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
     parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
+    parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
     return parser
@@ -165,7 +166,7 @@ if __name__ == "__main__":
         fp8_models=args.fp8_models,
         offload_models=args.offload_models,
         task=args.task,
-        device=accelerator.device,
+        device="cpu" if args.initialize_model_on_cpu else accelerator.device,
         max_timestep_boundary=args.max_timestep_boundary,
         min_timestep_boundary=args.min_timestep_boundary,
     )
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py
index 33ac71a..38aa34c 100644
--- a/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py
+++ b/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py
@@ -5,19 +5,29 @@ from diffsynth.core import load_state_dict
 from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
 
 
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
 pipe = WanVideoPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
         ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
         ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth"),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
     ],
 )
-state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors")
+state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
 pipe.vace.load_state_dict(state_dict)
-state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_low_noise_full/epoch-1.safetensors")
+state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_low_noise_full/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
 pipe.vace2.load_state_dict(state_dict)
 
 video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
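
Note on the recurring VRAM-management pattern: the `vram_config` dictionary and the `vram_limit` argument added in the Wan2.2-VACE-Fun-A14B scripts above are the core of these changes. The condensed sketch below is assembled only from pieces that appear in this diff (the `ModelConfig` offload/onload/preparing/computation keys, `WanVideoPipeline.from_pretrained`, the `torch.cuda.mem_get_info`-based limit); the helper variable name and the inline comments are mine, and the 2 GB margin is simply the value the edited examples use, not a hard requirement.

# Condensed sketch of the offloading setup used in the Wan2.2-VACE-Fun-A14B examples above.
import torch
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

# As the key names suggest, weights are parked on the CPU in bfloat16 and only
# brought to the GPU while a module is being prepared or actually computing.
vram_config = {
    "offload_dtype": torch.bfloat16, "offload_device": "cpu",
    "onload_dtype": torch.bfloat16, "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16, "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16, "computation_device": "cuda",
}

# torch.cuda.mem_get_info returns (free_bytes, total_bytes); index [1] is total VRAM.
# On an 80 GB card this yields a limit of roughly 80 - 2 = 78 GB.
total_vram_gb = torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3)

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
    vram_limit=total_vram_gb - 2,  # keep ~2 GB of headroom, as in the edited examples
)

On the training side, the new `--initialize_model_on_cpu` flag serves the same goal: `train.py` now builds the models on `device="cpu"` instead of `accelerator.device`, so the full set of weights does not have to fit on a single GPU at initialization time.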