mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
@@ -166,10 +166,13 @@ class ModelConfig:
|
|||||||
offload_device: Optional[Union[str, torch.device]] = None
|
offload_device: Optional[Union[str, torch.device]] = None
|
||||||
offload_dtype: Optional[torch.dtype] = None
|
offload_dtype: Optional[torch.dtype] = None
|
||||||
|
|
||||||
def download_if_necessary(self, local_model_path="./models", skip_download=False):
|
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
|
||||||
if self.path is None:
|
if self.path is None:
|
||||||
if self.model_id is None or self.origin_file_pattern is None:
|
if self.model_id is None or self.origin_file_pattern is None:
|
||||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
||||||
|
if use_usp:
|
||||||
|
import torch.distributed as dist
|
||||||
|
skip_download = dist.get_rank() != 0
|
||||||
if not skip_download:
|
if not skip_download:
|
||||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
|
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
@@ -179,6 +182,9 @@ class ModelConfig:
|
|||||||
ignore_file_pattern=downloaded_files,
|
ignore_file_pattern=downloaded_files,
|
||||||
local_files_only=False
|
local_files_only=False
|
||||||
)
|
)
|
||||||
|
if use_usp:
|
||||||
|
import torch.distributed as dist
|
||||||
|
dist.barrier(device_ids=[dist.get_rank()])
|
||||||
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
|
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
|
||||||
if isinstance(self.path, list) and len(self.path) == 1:
|
if isinstance(self.path, list) and len(self.path) == 1:
|
||||||
self.path = self.path[0]
|
self.path = self.path[0]
|
||||||
@@ -425,19 +431,21 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
|
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
|
||||||
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
|
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
|
||||||
|
|
||||||
|
# Initialize pipeline
|
||||||
|
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
if use_usp: pipe.initialize_usp()
|
||||||
|
|
||||||
# Download and load models
|
# Download and load models
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
for model_config in model_configs:
|
for model_config in model_configs:
|
||||||
model_config.download_if_necessary(local_model_path, skip_download=skip_download)
|
model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp)
|
||||||
model_manager.load_model(
|
model_manager.load_model(
|
||||||
model_config.path,
|
model_config.path,
|
||||||
device=model_config.offload_device or device,
|
device=model_config.offload_device or device,
|
||||||
torch_dtype=model_config.offload_dtype or torch_dtype
|
torch_dtype=model_config.offload_dtype or torch_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize pipeline
|
# Load models
|
||||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
|
||||||
if use_usp: pipe.initialize_usp()
|
|
||||||
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
||||||
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
||||||
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
||||||
@@ -1148,17 +1156,20 @@ def model_fn_wan_video(
|
|||||||
else:
|
else:
|
||||||
x = block(x, context, t_mod, freqs)
|
x = block(x, context, t_mod, freqs)
|
||||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
||||||
|
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
|
x = x + current_vace_hint * vace_scale
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
tea_cache.store(x)
|
tea_cache.store(x)
|
||||||
|
|
||||||
if reference_latents is not None:
|
|
||||||
x = x[:, reference_latents.shape[1]:]
|
|
||||||
f -= 1
|
|
||||||
|
|
||||||
x = dit.head(x, t)
|
x = dit.head(x, t)
|
||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
# Remove reference latents
|
||||||
|
if reference_latents is not None:
|
||||||
|
x = x[:, reference_latents.shape[1]:]
|
||||||
|
f -= 1
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
x = dit.unpatchify(x, (f, h, w))
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ Wan supports multiple acceleration techniques, including:
|
|||||||
* **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command:
|
* **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
pip install xfuser>=0.4.3
|
pip install "xfuser[flash-attn]>=0.4.3"
|
||||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ Wan 支持多种加速方案,包括
|
|||||||
* 统一序列并行:基于 [xDiT](https://github.com/xdit-project/xDiT) 实现的序列并行,请参考[示例代码](./acceleration/unified_sequence_parallel.py),使用以下命令运行:
|
* 统一序列并行:基于 [xDiT](https://github.com/xdit-project/xDiT) 实现的序列并行,请参考[示例代码](./acceleration/unified_sequence_parallel.py),使用以下命令运行:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
pip install xfuser>=0.4.3
|
pip install "xfuser[flash-attn]>=0.4.3"
|
||||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user