Merge branch 'main' into wan-itv-train

This commit is contained in:
Zhongjie Duan
2025-03-13 15:24:59 +08:00
committed by GitHub
58 changed files with 7620 additions and 745 deletions

View File

@@ -9,6 +9,10 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
torch_dtype=torch.float32, # Image Encoder is loaded with float32
)
model_manager.load_models(
[
[
@@ -20,14 +24,13 @@ model_manager.load_models(
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors",
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors",
],
"models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
"models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth",
],
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
# Download example image
dataset_snapshot_download(