From 2883bc1b763523c34a2fab10a282a81fa0917288 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 15 Dec 2025 15:48:42 +0800 Subject: [PATCH 1/2] support ascend npu --- diffsynth/core/__init__.py | 1 + diffsynth/core/device/__init__.py | 1 + .../core/device/npu_compatible_device.py | 107 ++++++++++++++++++ diffsynth/core/vram/layers.py | 4 +- diffsynth/diffusion/base_pipeline.py | 7 +- diffsynth/pipelines/wan_video.py | 2 +- .../utils/xfuser/xdit_context_parallel.py | 9 +- docs/en/Pipeline_Usage/GPU_support.md | 58 ++++++++++ docs/en/Pipeline_Usage/Setup.md | 2 + docs/zh/Pipeline_Usage/GPU_support.md | 58 ++++++++++ docs/zh/Pipeline_Usage/Setup.md | 2 + 11 files changed, 242 insertions(+), 9 deletions(-) create mode 100644 diffsynth/core/device/__init__.py create mode 100644 diffsynth/core/device/npu_compatible_device.py create mode 100644 docs/en/Pipeline_Usage/GPU_support.md create mode 100644 docs/zh/Pipeline_Usage/GPU_support.md diff --git a/diffsynth/core/__init__.py b/diffsynth/core/__init__.py index 72e501f..6c0a6c8 100644 --- a/diffsynth/core/__init__.py +++ b/diffsynth/core/__init__.py @@ -3,3 +3,4 @@ from .data import * from .gradient import * from .loader import * from .vram import * +from .device import * diff --git a/diffsynth/core/device/__init__.py b/diffsynth/core/device/__init__.py new file mode 100644 index 0000000..e53364f --- /dev/null +++ b/diffsynth/core/device/__init__.py @@ -0,0 +1 @@ +from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type \ No newline at end of file diff --git a/diffsynth/core/device/npu_compatible_device.py b/diffsynth/core/device/npu_compatible_device.py new file mode 100644 index 0000000..d96b8fb --- /dev/null +++ b/diffsynth/core/device/npu_compatible_device.py @@ -0,0 +1,107 @@ +import importlib +import torch +from typing import Any + + +def is_torch_npu_available(): + return importlib.util.find_spec("torch_npu") is not None + + +IS_CUDA_AVAILABLE = torch.cuda.is_available() +IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available() + +if IS_NPU_AVAILABLE: + import torch_npu + + torch.npu.config.allow_internal_format = False + + +def get_device_type() -> str: + """Get device type based on current machine, currently only support CPU, CUDA, NPU.""" + if IS_CUDA_AVAILABLE: + device = "cuda" + elif IS_NPU_AVAILABLE: + device = "npu" + else: + device = "cpu" + + return device + + +def get_torch_device() -> Any: + """Get torch attribute based on device type, e.g. torch.cuda or torch.npu""" + device_name = get_device_type() + + try: + return getattr(torch, device_name) + except AttributeError: + print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.") + return torch.cuda + + +def get_device_id() -> int: + """Get current device id based on device type.""" + return get_torch_device().current_device() + + +def get_device_name() -> str: + """Get current device name based on device type.""" + return f"{get_device_type()}:{get_device_id()}" + + +def synchronize() -> None: + """Execute torch synchronize operation.""" + get_torch_device().synchronize() + + +def empty_cache() -> None: + """Execute torch empty cache operation.""" + get_torch_device().empty_cache() + + +def get_nccl_backend() -> str: + """Return distributed communication backend type based on device type.""" + if IS_CUDA_AVAILABLE: + return "nccl" + elif IS_NPU_AVAILABLE: + return "hccl" + else: + raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.") + + +def enable_high_precision_for_bf16(): + """ + Set high accumulation dtype for matmul and reduction. + """ + if IS_CUDA_AVAILABLE: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + if IS_NPU_AVAILABLE: + torch.npu.matmul.allow_tf32 = False + torch.npu.matmul.allow_bf16_reduced_precision_reduction = False + + +def parse_device_type(device): + if isinstance(device, str): + if device.startswith("cuda"): + return "cuda" + elif device.startswith("npu"): + return "npu" + else: + return "cpu" + elif isinstance(device, torch.device): + return device.type + + +def parse_nccl_backend(device_type): + if device_type == "cuda": + return "nccl" + elif device_type == "npu": + return "hccl" + else: + raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.") + + +def get_available_device_type(): + return get_device_type() diff --git a/diffsynth/core/vram/layers.py b/diffsynth/core/vram/layers.py index 852a8f3..01ade0e 100644 --- a/diffsynth/core/vram/layers.py +++ b/diffsynth/core/vram/layers.py @@ -2,6 +2,7 @@ import torch, copy from typing import Union from .initialization import skip_model_initialization from .disk_map import DiskMap +from ..device import parse_device_type class AutoTorchModule(torch.nn.Module): @@ -32,6 +33,7 @@ class AutoTorchModule(torch.nn.Module): ) self.state = 0 self.name = "" + self.computation_device_type = parse_device_type(self.computation_device) def set_dtype_and_device( self, @@ -61,7 +63,7 @@ class AutoTorchModule(torch.nn.Module): return r def check_free_vram(self): - gpu_mem_state = torch.cuda.mem_get_info(self.computation_device) + gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device) used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3) return used_memory < self.vram_limit diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index 2bec693..0140497 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -3,7 +3,7 @@ import torch import numpy as np from einops import repeat, reduce from typing import Union -from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig +from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type from ..utils.lora import GeneralLoRALoader from ..models.model_loader import ModelPool from ..utils.controlnet import ControlNetInput @@ -68,6 +68,7 @@ class BasePipeline(torch.nn.Module): # The device and torch_dtype is used for the storage of intermediate variables, not models. self.device = device self.torch_dtype = torch_dtype + self.device_type = parse_device_type(device) # The following parameters are used for shape check. self.height_division_factor = height_division_factor self.width_division_factor = width_division_factor @@ -154,7 +155,7 @@ class BasePipeline(torch.nn.Module): for module in model.modules(): if hasattr(module, "offload"): module.offload() - torch.cuda.empty_cache() + getattr(torch, self.device_type).empty_cache() # onload models for name, model in self.named_children(): if name in model_names: @@ -176,7 +177,7 @@ class BasePipeline(torch.nn.Module): def get_vram(self): - return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3) + return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3) def get_module(self, model, name): if "." in name: diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index fa43db1..5e55abe 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -126,7 +126,7 @@ class WanVideoPipeline(BasePipeline): pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) if use_usp: from ..utils.xfuser import initialize_usp - initialize_usp() + initialize_usp(device) model_pool = pipe.download_and_load_models(model_configs, vram_limit) # Fetch models diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 1173313..b7fa72d 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -5,19 +5,20 @@ from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention +from ...core.device import parse_nccl_backend, parse_device_type -def initialize_usp(): +def initialize_usp(device_type): import torch.distributed as dist from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment - dist.init_process_group(backend="nccl", init_method="env://") + dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://") init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) initialize_model_parallel( sequence_parallel_degree=dist.get_world_size(), ring_degree=1, ulysses_degree=dist.get_world_size(), ) - torch.cuda.set_device(dist.get_rank()) + getattr(torch, device_type).set_device(dist.get_rank()) def sinusoidal_embedding_1d(dim, position): @@ -141,5 +142,5 @@ def usp_attn_forward(self, x, freqs): x = x.flatten(2) del q, k, v - torch.cuda.empty_cache() + getattr(torch, parse_device_type(x.device)).empty_cache() return self.o(x) \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/GPU_support.md b/docs/en/Pipeline_Usage/GPU_support.md new file mode 100644 index 0000000..2f206eb --- /dev/null +++ b/docs/en/Pipeline_Usage/GPU_support.md @@ -0,0 +1,58 @@ +# GPU/NPU Support + +`DiffSynth-Studio` supports various GPUs and NPUs. This document explains how to run model inference and training on these devices. + +Before you begin, please follow the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md) to install the required GPU/NPU dependencies. + +## NVIDIA GPU + +All sample code provided by this project supports NVIDIA GPUs by default, requiring no additional modifications. + +## AMD GPU + +AMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions. + +## Ascend NPU + +When using Ascend NPU, you need to replace `"cuda"` with `"npu"` in your code. + +For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU: + +```diff +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, +- "preparing_device": "cuda", ++ "preparing_device": "npu", + "computation_dtype": torch.bfloat16, +- "computation_device": "cuda", ++ "computation_device": "npu", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, +- device="cuda", ++ device="npu", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, ++ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2, +) + +video = pipe( + prompt="Documentary-style photography: a lively puppy running swiftly across lush green grass. The puppy has brownish-yellow fur, upright ears, and an alert, joyful expression. Sunlight bathes its body, making the fur appear exceptionally soft and shiny. The background is an open field with occasional wildflowers, and faint blue sky with scattered white clouds in the distance. Strong perspective captures the motion of the running puppy and the vitality of the surrounding grass. Mid-shot, side-moving viewpoint.", + negative_prompt="Overly vibrant colors, overexposed, static, blurry details, subtitles, artistic style, painting, still image, overall grayish tone, worst quality, low quality, JPEG artifacts, ugly, distorted, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, fused fingers, motionless scene, cluttered background, three legs, many people in background, walking backward", + seed=0, tiled=True, +) +save_video(video, "video.mp4", fps=15, quality=5) +``` diff --git a/docs/en/Pipeline_Usage/Setup.md b/docs/en/Pipeline_Usage/Setup.md index 20e25d9..c95e5b7 100644 --- a/docs/en/Pipeline_Usage/Setup.md +++ b/docs/en/Pipeline_Usage/Setup.md @@ -36,6 +36,8 @@ Ascend NPU support is provided via the `torch-npu` package. Taking version `2.1. pip install torch-npu==2.1.0.post17 ``` +When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu). + ## Other Installation Issues If you encounter issues during installation, they may be caused by upstream dependencies. Please refer to the documentation for these packages: diff --git a/docs/zh/Pipeline_Usage/GPU_support.md b/docs/zh/Pipeline_Usage/GPU_support.md new file mode 100644 index 0000000..3ba76fc --- /dev/null +++ b/docs/zh/Pipeline_Usage/GPU_support.md @@ -0,0 +1,58 @@ +# GPU/NPU 支持 + +`DiffSynth-Studio` 支持多种 GPU/NPU,本文介绍如何在这些设备上运行模型推理和训练。 + +在开始前,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)安装好 GPU/NPU 相关的依赖包。 + +## NVIDIA GPU + +本项目提供的所有样例代码默认支持 NVIDIA GPU,无需额外修改。 + +## AMD GPU + +AMD 提供了基于 ROCm 的 torch 包,所以大多数模型无需修改代码即可运行,少数模型由于依赖特定的 cuda 指令无法运行。 + +## Ascend NPU + +使用 Ascend NPU 时,需把代码中的 `"cuda"` 改为 `"npu"`。 + +例如,Wan2.1-T2V-1.3B 的推理代码: + +```diff +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, +- "preparing_device": "cuda", ++ "preparing_device": "npu", + "computation_dtype": torch.bfloat16, +- "computation_device": "cuda", ++ "preparing_device": "npu", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, +- device="cuda", ++ device="npu", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, ++ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2, +) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video.mp4", fps=15, quality=5) +``` diff --git a/docs/zh/Pipeline_Usage/Setup.md b/docs/zh/Pipeline_Usage/Setup.md index d9efc1f..0c99840 100644 --- a/docs/zh/Pipeline_Usage/Setup.md +++ b/docs/zh/Pipeline_Usage/Setup.md @@ -36,6 +36,8 @@ Ascend NPU 通过 `torch-npu` 包提供支持,以 `2.1.0.post17` 版本(本 pip install torch-npu==2.1.0.post17 ``` +使用 Ascend NPU 时,请将 Python 代码中的 `"cuda"` 改为 `"npu"`,详见[NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md#ascend-npu)。 + ## 其他安装问题 如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档: From 7c6905a4322bf6691cc50353912fbc80adbe1a61 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 15 Dec 2025 15:50:12 +0800 Subject: [PATCH 2/2] support ascend npu --- docs/en/README.md | 1 + docs/zh/README.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/en/README.md b/docs/en/README.md index 17c8e88..39ae439 100644 --- a/docs/en/README.md +++ b/docs/en/README.md @@ -31,6 +31,7 @@ This section introduces the basic usage of `DiffSynth-Studio`, including how to * [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) * [Model Training](/docs/en/Pipeline_Usage/Model_Training.md) * [Environment Variables](/docs/en/Pipeline_Usage/Environment_Variables.md) +* [GPU/NPU Support](/docs/en/Pipeline_Usage/GPU_support.md) ## Section 2: Model Details diff --git a/docs/zh/README.md b/docs/zh/README.md index b0b2310..edcef50 100644 --- a/docs/zh/README.md +++ b/docs/zh/README.md @@ -31,6 +31,7 @@ graph LR; * [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md) * [模型训练](/docs/zh/Pipeline_Usage/Model_Training.md) * [环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md) +* [GPU/NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md) ## Section 2: 模型详解