Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
diffsynth/distributed/__init__.py | 0  (new, empty file)
@@ -91,7 +91,6 @@ def rope_apply(x, freqs, num_heads):
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
-    print(f"x_out.shape: {x_out.shape}, freqs.shape: {freqs.shape}")
     x_out = torch.view_as_real(x_out * freqs).flatten(2)
     return x_out.to(x.dtype)
 
 
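For context, the hunk above sits in the RoPE application path: pairs of channels are viewed as complex numbers and rotated by position-dependent unit phasors, and the removed line was only a shape-debugging print. A minimal self-contained sketch of the same pattern (the shapes and `freqs` layout here are assumptions for illustration, not the repository's actual signature):

# Minimal sketch of complex-valued RoPE in the style of the hunk above.
# Assumed shapes: x is (batch, seq, heads, head_dim) with even head_dim;
# freqs is complex and broadcastable to (batch, seq, heads, head_dim // 2).
import torch

def rope_apply_sketch(x, freqs):
    # View (even, odd) channel pairs as complex numbers: (b, s, h, d/2) complex128.
    x_c = torch.view_as_complex(
        x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
    # Multiplying by unit phasors exp(i * pos * theta_k) rotates each pair;
    # view_as_real + flatten turns the pairs back into real channels.
    return torch.view_as_real(x_c * freqs).flatten(3).to(x.dtype)

# Usage: rotate a (1, 16, 2, 8) activation by per-position angles.
pos = torch.arange(16, dtype=torch.float64)
theta = 1.0 / 10000 ** (torch.arange(4, dtype=torch.float64) / 4)
freqs = torch.polar(torch.ones(16, 4, dtype=torch.float64), torch.outer(pos, theta))
x = torch.randn(1, 16, 2, 8)
print(rope_apply_sketch(x, freqs.reshape(1, 16, 1, 4)).shape)  # (1, 16, 2, 8)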
@@ -5,22 +5,22 @@ import torch.distributed as dist
 
 
 # Download models
-# snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
+snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
 
 # Load models
 model_manager = ModelManager(device="cpu")
 model_manager.load_models(
     [
         [
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
-            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
+            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
         ],
-        "/demo-huabei2/models/dit/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
-        "/demo-huabei2/models/dit/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
+        "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
+        "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
     ],
     torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
 )
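The `torch_dtype=torch.float8_e4m3fn` option stores the DiT weights in FP8, roughly halving their VRAM footprint relative to bf16, with upcasting at compute time. A toy illustration of that storage/compute split (this shows the trade-off only; it is not DiffSynth-Studio's actual loader):

# Weight-only FP8 sketch: store at 1 byte/param, compute in bf16.
# Needs a PyTorch build with float8 support (>= 2.1).
import torch

w_bf16 = torch.randn(1024, 1024, dtype=torch.bfloat16)
w_fp8 = w_bf16.to(torch.float8_e4m3fn)              # half the storage of bf16
x = torch.randn(4, 1024, dtype=torch.bfloat16)
y = x @ w_fp8.to(torch.bfloat16)                    # upcast just-in-time for the matmul
print(w_fp8.element_size(), w_bf16.element_size())  # 1 vs 2 bytes per element
print((y - x @ w_bf16).abs().max())                 # small error from FP8 rounding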
@@ -41,7 +41,10 @@ initialize_model_parallel(
 )
 torch.cuda.set_device(dist.get_rank())
 
-pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=f"cuda:{dist.get_rank()}", use_usp=True if dist.get_world_size() > 1 else False)
+pipe = WanVideoPipeline.from_model_manager(model_manager,
+                                           torch_dtype=torch.bfloat16,
+                                           device=f"cuda:{dist.get_rank()}",
+                                           use_usp=True if dist.get_world_size() > 1 else False)
 pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
 
 # Text-to-video
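Since `use_usp` is enabled only when `dist.get_world_size() > 1`, this example is meant to be started through a distributed launcher, e.g. `torchrun --standalone --nproc_per_node=8 <script>.py` (the script path is illustrative); launched with a single process (`--nproc_per_node=1`) it falls back to ordinary single-GPU inference.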
@@ -11,3 +11,4 @@ sentencepiece
 protobuf
 modelscope
 ftfy
+xfuser>=0.4.2
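The new `xfuser>=0.4.2` pin supplies the USP (unified sequence parallelism) primitives behind the `use_usp=True` path above, while `modelscope` provides the `snapshot_download` used to fetch the weights.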