[NPU]:Replace 'cuda' in the project with abstract interfaces

This commit is contained in:
feng0w0
2026-01-15 20:04:54 +08:00
parent f4d06ce3fc
commit a3c2744a43
16 changed files with 86 additions and 36 deletions

View File

@@ -4,3 +4,4 @@ from .gradient import *
from .loader import *
from .vram import *
from .device import *
from .npu_patch import *

View File

@@ -0,0 +1,5 @@
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE
from .npu_autocast_patch import npu_autocast_patch
if IS_NPU_AVAILABLE:
npu_autocast_patch()

View File

@@ -0,0 +1,21 @@
import torch
from contextlib import contextmanager
def npu_autocast_patch_wrapper(func):
@contextmanager
def wrapper(*args, **kwargs):
flag = False
if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"):
if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32):
flag = True
with func(*args, **kwargs) as ctx:
if flag:
torch.npu.set_autocast_enabled(True)
yield ctx
return wrapper
def npu_autocast_patch():
torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast)
torch.autocast = npu_autocast_patch_wrapper(torch.autocast)

View File

@@ -4,6 +4,7 @@ import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
@@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module):
def __init__(
self,
device="cuda", torch_dtype=torch.float16,
device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):

View File

@@ -2,6 +2,8 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
from ..core.device.npu_compatible_device import get_device_type
class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
@@ -70,7 +72,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None

View File

@@ -9,6 +9,7 @@ import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
from ..core.gradient import gradient_checkpoint_forward
@@ -373,7 +374,9 @@ class FinalLayer_FP32(nn.Module):
B, N, C = x.shape
T, _, _ = latent_shape
with amp.autocast('cuda', dtype=torch.float32):
with amp.autocast(get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
@@ -583,7 +586,9 @@ class LongCatSingleStreamBlock(nn.Module):
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
# compute modulation params in fp32
with amp.autocast(device_type='cuda', dtype=torch.float32):
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +607,9 @@ class LongCatSingleStreamBlock(nn.Module):
else:
x_s = attn_outputs
with amp.autocast(device_type='cuda', dtype=torch.float32):
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
@@ -615,7 +622,9 @@ class LongCatSingleStreamBlock(nn.Module):
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
with amp.autocast(device_type='cuda', dtype=torch.float32):
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
@@ -797,7 +806,9 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type='cuda', dtype=torch.float32):
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]

View File

@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
is_compileable = is_compileable and not self.generation_config.disable_compile
if is_compileable and (
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
):
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)

View File

@@ -2,6 +2,8 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch
from diffsynth.core.device.npu_compatible_device import get_device_type
class Siglip2ImageEncoder(SiglipVisionTransformer):
def __init__(self):
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
output_attentions = False

View File

@@ -1,10 +1,11 @@
import torch
from typing import Optional, Union
from .qwen_image_text_encoder import QwenImageTextEncoder
from ..core.device.npu_compatible_device import get_device_type, get_torch_device
class Step1xEditEmbedder(torch.nn.Module):
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
super().__init__()
self.max_length = max_length
self.dtype = dtype
@@ -77,13 +78,13 @@ User Prompt:'''
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
device=get_torch_device().current_device(),
)
masks = torch.zeros(
len(text_list),
self.max_length,
dtype=torch.long,
device=torch.cuda.current_device(),
device=get_torch_device().current_device(),
)
def split_string(s):
@@ -158,7 +159,7 @@ User Prompt:'''
else:
token_list.append(token_each)
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
@@ -167,15 +168,15 @@ User Prompt:'''
inputs.input_ids = (
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
.unsqueeze(0)
.to("cuda")
.to(get_device_type())
)
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
outputs = self.model_forward(
self.model,
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
pixel_values=inputs.pixel_values.to("cuda"),
image_grid_thw=inputs.image_grid_thw.to("cuda"),
pixel_values=inputs.pixel_values.to(get_device_type()),
image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
output_hidden_states=True,
)
@@ -188,7 +189,7 @@ User Prompt:'''
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
(min(self.max_length, emb.shape[1] - 217)),
dtype=torch.long,
device=torch.cuda.current_device(),
device=get_torch_device().current_device(),
)
return embs, masks

View File

@@ -8,7 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
from torch.nn import RMSNorm
from ..core.attention import attention_forward
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
from ..core.gradient import gradient_checkpoint_forward
@@ -40,7 +40,7 @@ class TimestepEmbedder(nn.Module):
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
with torch.amp.autocast("cuda", enabled=False):
with torch.amp.autocast(get_device_type(), enabled=False):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
@@ -105,7 +105,7 @@ class Attention(torch.nn.Module):
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)

View File

@@ -6,6 +6,7 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -18,7 +19,7 @@ from ..models.flux2_vae import Flux2VAE
class Flux2ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -42,7 +43,7 @@ class Flux2ImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit: float = None,

View File

@@ -6,6 +6,7 @@ from einops import rearrange, repeat
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
class InfinitYou(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis

View File

@@ -6,6 +6,7 @@ from einops import rearrange
import numpy as np
from math import prod
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -22,7 +23,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
class QwenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -60,7 +61,7 @@ class QwenImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,

View File

@@ -11,6 +11,7 @@ from typing import Optional
from typing_extensions import Literal
from transformers import Wav2Vec2Processor
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
class WanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
audio_processor_config: ModelConfig = None,
@@ -960,7 +961,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
onload_model_names=("vae",)
)
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else:

View File

@@ -6,6 +6,7 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
@@ -25,7 +26,7 @@ from ..models.z_image_image2lora import ZImageImage2LoRAModel
class ZImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -58,7 +59,7 @@ class ZImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,

View File

@@ -1,12 +1,13 @@
from typing_extensions import Literal, TypeAlias
from diffsynth.core.device.npu_compatible_device import get_device_type
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
if not skip_processor:
if processor_id == "canny":
from controlnet_aux.processor import CannyDetector