Merge pull request #1 from Eigensystem/fjr

fix some bugs
This commit is contained in:
Jinzhe Pan
2025-03-17 17:07:07 +08:00
committed by GitHub
5 changed files with 14 additions and 11 deletions

View File

View File

@@ -91,7 +91,6 @@ def rope_apply(x, freqs, num_heads):
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
print(f"x_out.shape: {x_out.shape}, freqs.shape: {freqs.shape}")
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)

View File

@@ -5,22 +5,22 @@ import torch.distributed as dist
# Download models
# snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
[
"/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
],
"/demo-huabei2/models/dit/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
"/demo-huabei2/models/dit/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
],
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
)
@@ -41,7 +41,10 @@ initialize_model_parallel(
)
torch.cuda.set_device(dist.get_rank())
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=f"cuda:{dist.get_rank()}", use_usp=True if dist.get_world_size() > 1 else False)
pipe = WanVideoPipeline.from_model_manager(model_manager,
torch_dtype=torch.bfloat16,
device=f"cuda:{dist.get_rank()}",
use_usp=True if dist.get_world_size() > 1 else False)
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
# Text-to-video

View File

@@ -11,3 +11,4 @@ sentencepiece
protobuf
modelscope
ftfy
xfuser>=0.4.2