diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 9086af2..4a7a2a2 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -166,10 +166,13 @@ class ModelConfig: offload_device: Optional[Union[str, torch.device]] = None offload_dtype: Optional[torch.dtype] = None - def download_if_necessary(self, local_model_path="./models", skip_download=False): + def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False): if self.path is None: if self.model_id is None or self.origin_file_pattern is None: raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") + if use_usp: + import torch.distributed as dist + skip_download = dist.get_rank() != 0 if not skip_download: downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) snapshot_download( @@ -179,6 +182,9 @@ class ModelConfig: ignore_file_pattern=downloaded_files, local_files_only=False ) + if use_usp: + import torch.distributed as dist + dist.barrier(device_ids=[dist.get_rank()]) self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) if isinstance(self.path, list) and len(self.path) == 1: self.path = self.path[0] @@ -425,19 +431,21 @@ class WanVideoPipeline(BasePipeline): print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.") model_config.model_id = redirect_dict[model_config.origin_file_pattern] + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + if use_usp: pipe.initialize_usp() + # Download and load models model_manager = ModelManager() for model_config in model_configs: - model_config.download_if_necessary(local_model_path, skip_download=skip_download) + model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp) model_manager.load_model( model_config.path, device=model_config.offload_device or device, torch_dtype=model_config.offload_dtype or torch_dtype ) - # Initialize pipeline - pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) - if use_usp: pipe.initialize_usp() + # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") pipe.dit = model_manager.fetch_model("wan_video_dit") pipe.vae = model_manager.fetch_model("wan_video_vae") @@ -1148,17 +1156,20 @@ def model_fn_wan_video( else: x = block(x, context, t_mod, freqs) if vace_context is not None and block_id in vace.vace_layers_mapping: - x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + x = x + current_vace_hint * vace_scale if tea_cache is not None: tea_cache.store(x) - if reference_latents is not None: - x = x[:, reference_latents.shape[1]:] - f -= 1 - x = dit.head(x, t) if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 x = dit.unpatchify(x, (f, h, w)) return x diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 3d71a70..a774773 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -173,7 +173,7 @@ Wan supports multiple acceleration techniques, including: * **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command: ```shell -pip install xfuser>=0.4.3 +pip install "xfuser[flash-attn]>=0.4.3" torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py ``` diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index dd3b598..ba2c59c 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -175,7 +175,7 @@ Wan 支持多种加速方案,包括 * 统一序列并行:基于 [xDiT](https://github.com/xdit-project/xDiT) 实现的序列并行,请参考[示例代码](./acceleration/unified_sequence_parallel.py),使用以下命令运行: ```shell -pip install xfuser>=0.4.3 +pip install "xfuser[flash-attn]>=0.4.3" torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py ```