From d8b250607aac7764a5230c208b593ea5ae020c32 Mon Sep 17 00:00:00 2001
From: feifeibear
Date: Mon, 17 Mar 2025 09:04:51 +0000
Subject: [PATCH] polish code

---
 diffsynth/models/wan_video_dit.py          |  1 -
 examples/wanvideo/wan_14b_text_to_video.py | 23 ++++++++++++----------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index 618e7d8..a2c55e1 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -91,7 +91,6 @@ def rope_apply(x, freqs, num_heads):
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
-    print(f"x_out.shape: {x_out.shape}, freqs.shape: {freqs.shape}")
     x_out = torch.view_as_real(x_out * freqs).flatten(2)
     return x_out.to(x.dtype)
 
 
diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py
index d67e1d5..dcb2f29 100644
--- a/examples/wanvideo/wan_14b_text_to_video.py
+++ b/examples/wanvideo/wan_14b_text_to_video.py
@@ -5,22 +5,22 @@
 import torch.distributed as dist
 
 # Download models
-# snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
+snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
 
 # Load models
 model_manager = ModelManager(device="cpu")
 model_manager.load_models(
     [
         [
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
         ],
-        "/demo-huabei2/models/dit/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
-        "/demo-huabei2/models/dit/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
+        "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
+        "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
     ],
     torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
 )
@@ -41,7 +41,10 @@ initialize_model_parallel(
 )
 torch.cuda.set_device(dist.get_rank())
 
-pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=f"cuda:{dist.get_rank()}", use_usp=True if dist.get_world_size() > 1 else False)
+pipe = WanVideoPipeline.from_model_manager(model_manager,
+                                           torch_dtype=torch.bfloat16,
+                                           device=f"cuda:{dist.get_rank()}",
+                                           use_usp=True if dist.get_world_size() > 1 else False)
 pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
 
 # Text-to-video
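
Note on the first hunk (commentary, not part of the diff): rope_apply applies
rotary position embeddings by viewing adjacent channel pairs as complex
numbers and multiplying by precomputed unit rotations, which is why the
removed debug print compared x_out.shape against freqs.shape. Below is a
minimal, self-contained sketch of that pattern; rope_apply_sketch and the
freqs layout are illustrative assumptions, not DiffSynth's exact API.

    import torch

    def rope_apply_sketch(x, freqs, num_heads):
        # Sketch only; mirrors the complex-multiplication RoPE pattern.
        # x: (batch, seq, num_heads * head_dim), head_dim even.
        dtype = x.dtype
        b, s, _ = x.shape
        x = x.view(b, s, num_heads, -1)
        # Reinterpret adjacent value pairs as complex numbers:
        # (..., head_dim) -> (..., head_dim // 2), complex128.
        x_c = torch.view_as_complex(
            x.to(torch.float64).reshape(b, s, num_heads, -1, 2))
        # freqs holds unit rotations e^{i*theta}, broadcastable to x_c;
        # the complex product rotates each pair by its position's angle.
        return torch.view_as_real(x_c * freqs).flatten(2).to(dtype)

    # Usage: rotate a (1, seq, num_heads * head_dim) activation.
    head_dim, seq = 8, 4
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
    angles = torch.outer(torch.arange(seq, dtype=torch.float32), inv_freq)
    freqs = torch.polar(torch.ones_like(angles), angles)  # (seq, head_dim//2)
    y = rope_apply_sketch(torch.randn(1, seq, 2 * head_dim),
                          freqs[:, None, :], num_heads=2)  # (1, 4, 16)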