From 209a350c0f6dcaccfcd5be3c3b73e28b8813bed0 Mon Sep 17 00:00:00 2001
From: feng0w0
Date: Thu, 15 Jan 2026 20:33:01 +0800
Subject: [PATCH] [NPU]: Replace 'cuda' in the project with abstract interfaces

---
 diffsynth/core/__init__.py                    |  1 +
 diffsynth/core/npu_patch/__init__.py          |  5 +++++
 .../core/npu_patch/npu_autocast_patch.py      | 21 +++++++++++++++++++
 diffsynth/diffusion/base_pipeline.py          |  3 ++-
 diffsynth/models/dinov3_image_encoder.py      |  4 +++-
 diffsynth/models/longcat_video_dit.py         | 21 ++++++++++++++-----
 diffsynth/models/nexus_gen_ar_model.py        |  2 +-
 diffsynth/models/siglip2_image_encoder.py     |  4 +++-
 diffsynth/models/step1x_text_encoder.py       | 19 +++++++++--------
 diffsynth/models/wan_video_dit.py             |  1 -
 diffsynth/models/z_image_dit.py               |  6 +++---
 diffsynth/pipelines/flux2_image.py            |  5 +++--
 diffsynth/pipelines/flux_image.py             | 11 +++++-----
 diffsynth/pipelines/qwen_image.py             |  5 +++--
 diffsynth/pipelines/wan_video.py              |  7 ++++---
 diffsynth/pipelines/z_image.py                |  5 +++--
 diffsynth/utils/controlnet/annotator.py       |  3 ++-
 .../utils/xfuser/xdit_context_parallel.py     |  1 -
 18 files changed, 86 insertions(+), 38 deletions(-)
 create mode 100644 diffsynth/core/npu_patch/__init__.py
 create mode 100644 diffsynth/core/npu_patch/npu_autocast_patch.py

diff --git a/diffsynth/core/__init__.py b/diffsynth/core/__init__.py
index 6c0a6c8..4d5f440 100644
--- a/diffsynth/core/__init__.py
+++ b/diffsynth/core/__init__.py
@@ -4,3 +4,4 @@ from .gradient import *
 from .loader import *
 from .vram import *
 from .device import *
+from .npu_patch import *
diff --git a/diffsynth/core/npu_patch/__init__.py b/diffsynth/core/npu_patch/__init__.py
new file mode 100644
index 0000000..eb1df93
--- /dev/null
+++ b/diffsynth/core/npu_patch/__init__.py
@@ -0,0 +1,5 @@
+from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE
+from .npu_autocast_patch import npu_autocast_patch
+
+if IS_NPU_AVAILABLE:
+    npu_autocast_patch()
diff --git a/diffsynth/core/npu_patch/npu_autocast_patch.py b/diffsynth/core/npu_patch/npu_autocast_patch.py
new file mode 100644
index 0000000..08b1caf
--- /dev/null
+++ b/diffsynth/core/npu_patch/npu_autocast_patch.py
@@ -0,0 +1,21 @@
+import torch
+from contextlib import contextmanager
+
+
+def npu_autocast_patch_wrapper(func):
+    @contextmanager
+    def wrapper(*args, **kwargs):
+        flag = False
+        if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"):
+            if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32):
+                flag = True
+        with func(*args, **kwargs) as ctx:
+            if flag:
+                torch.npu.set_autocast_enabled(True)
+            yield ctx
+    return wrapper
+
+
+def npu_autocast_patch():
+    torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast)
+    torch.autocast = npu_autocast_patch_wrapper(torch.autocast)
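The wrapper above flips the NPU autocast flag only when both the device
("npu") and the dtype (torch.float32) are requested; for every other call it
is a transparent proxy around the original context manager. A minimal sketch
of that pass-through behaviour, using a stand-in context manager
(fake_autocast is hypothetical, not part of this patch):

    import torch
    from contextlib import contextmanager
    from diffsynth.core.npu_patch.npu_autocast_patch import npu_autocast_patch_wrapper

    calls = []

    @contextmanager
    def fake_autocast(device_type, dtype=None):
        # Stand-in for torch.autocast: records its arguments, yields a token.
        calls.append((device_type, dtype))
        yield "ctx"

    wrapped = npu_autocast_patch_wrapper(fake_autocast)

    # Non-NPU call: arguments and the context object are forwarded untouched.
    with wrapped("cpu", dtype=torch.bfloat16) as ctx:
        assert ctx == "ctx"
    assert calls == [("cpu", torch.bfloat16)]

    # A ("npu", dtype=torch.float32) call would additionally run
    # torch.npu.set_autocast_enabled(True) inside the context (NPU hosts only).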
diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py
index 4fe1559..d4731fd 100644
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -4,6 +4,7 @@ import numpy as np
 from einops import repeat, reduce
 from typing import Union
 from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core.device.npu_compatible_device import get_device_type
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
@@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module):
 
     def __init__(
         self,
-        device="cuda", torch_dtype=torch.float16,
+        device=get_device_type(), torch_dtype=torch.float16,
         height_division_factor=64, width_division_factor=64,
         time_division_factor=None, time_division_remainder=None,
     ):
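Note that device=get_device_type() in a default argument is evaluated once,
at import time, which fits the use case here: the process-wide backend does
not change after startup. get_device_type itself lives in
diffsynth/core/device/npu_compatible_device.py and is not part of this diff;
a plausible sketch of its contract (an assumption, for reference only):

    import torch

    def get_device_type() -> str:
        # Hypothetical sketch -- the real helper ships separately in
        # diffsynth/core/device/npu_compatible_device.py.
        try:
            import torch_npu  # noqa: F401  (Ascend backend)
            if torch.npu.is_available():
                return "npu"
        except ImportError:
            pass
        return "cuda" if torch.cuda.is_available() else "cpu"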
diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py
index be2ee58..c394a03 100644
--- a/diffsynth/models/dinov3_image_encoder.py
+++ b/diffsynth/models/dinov3_image_encoder.py
@@ -2,6 +2,8 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
 from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
 import torch
 
+from ..core.device.npu_compatible_device import get_device_type
+
 
 class DINOv3ImageEncoder(DINOv3ViTModel):
     def __init__(self):
@@ -70,7 +72,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
             }
         )
 
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         inputs = self.processor(images=image, return_tensors="pt")
         pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
         bool_masked_pos = None
diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py
index 6d65723..ebcc9d0 100644
--- a/diffsynth/models/longcat_video_dit.py
+++ b/diffsynth/models/longcat_video_dit.py
@@ -9,6 +9,7 @@ import numpy as np
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from .wan_video_dit import flash_attention
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
 from ..core.gradient import gradient_checkpoint_forward
 
 
@@ -373,7 +374,9 @@ class FinalLayer_FP32(nn.Module):
         B, N, C = x.shape
         T, _, _ = latent_shape
 
-        with amp.autocast('cuda', dtype=torch.float32):
+        with amp.autocast(get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1)  # [B, T, 1, C]
             x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
             x = self.linear(x)
@@ -583,7 +586,9 @@ class LongCatSingleStreamBlock(nn.Module):
         T, _, _ = latent_shape  # S != T*H*W in case of CP split on H*W.
 
         # compute modulation params in fp32
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             shift_msa, scale_msa, gate_msa, \
                 shift_mlp, scale_mlp, gate_mlp = \
                 self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1)  # [B, T, 1, C]
@@ -602,7 +607,9 @@ class LongCatSingleStreamBlock(nn.Module):
         else:
             x_s = attn_outputs
 
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C)  # [B, N, C]
             x = x.to(x_dtype)
 
@@ -615,7 +622,9 @@ class LongCatSingleStreamBlock(nn.Module):
         # ffn with modulation
         x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
         x_s = self.ffn(x_m)
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C)  # [B, N, C]
             x = x.to(x_dtype)
 
@@ -797,7 +806,9 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
 
         hidden_states = self.x_embedder(hidden_states)  # [B, N, C]
 
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1)  # [B, T, C_t]
 
         encoder_hidden_states = self.y_embedder(encoder_hidden_states)  # [B, 1, N_token, C]
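The same three added lines (fp32 autocast, then re-enable autocast on NPU)
now appear five times in this file, and they duplicate what the global
npu_autocast_patch installs. If the explicit form is preferred over the
monkey-patch, a small helper could deduplicate it; a sketch built only from
names introduced by this patch (fp32_autocast is hypothetical):

    from contextlib import contextmanager

    import torch
    from torch import amp

    from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type

    @contextmanager
    def fp32_autocast():
        # fp32 autocast that stays enabled on NPU, mirroring the inline pattern.
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            if IS_NPU_AVAILABLE:
                torch.npu.set_autocast_enabled(True)
            yield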
diff --git a/diffsynth/models/nexus_gen_ar_model.py b/diffsynth/models/nexus_gen_ar_model.py
index d5a2973..b647786 100644
--- a/diffsynth/models/nexus_gen_ar_model.py
+++ b/diffsynth/models/nexus_gen_ar_model.py
@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
         is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
         is_compileable = is_compileable and not self.generation_config.disable_compile
         if is_compileable and (
-            self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
+            self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
         ):
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
             model_forward = self.get_compiled_call(generation_config.compile_config)
diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py
index 87df855..509eff4 100644
--- a/diffsynth/models/siglip2_image_encoder.py
+++ b/diffsynth/models/siglip2_image_encoder.py
@@ -2,6 +2,8 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
 from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
 import torch
 
+from diffsynth.core.device.npu_compatible_device import get_device_type
+
 
 class Siglip2ImageEncoder(SiglipVisionTransformer):
     def __init__(self):
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
             }
         )
 
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
         pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
         output_attentions = False
diff --git a/diffsynth/models/step1x_text_encoder.py b/diffsynth/models/step1x_text_encoder.py
index d0fe221..5d14423 100644
--- a/diffsynth/models/step1x_text_encoder.py
+++ b/diffsynth/models/step1x_text_encoder.py
@@ -1,10 +1,11 @@
 import torch
 from typing import Optional, Union
 from .qwen_image_text_encoder import QwenImageTextEncoder
+from ..core.device.npu_compatible_device import get_device_type, get_torch_device
 
 
 class Step1xEditEmbedder(torch.nn.Module):
-    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
+    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
         super().__init__()
         self.max_length = max_length
         self.dtype = dtype
@@ -77,13 +78,13 @@ User Prompt:'''
             self.max_length,
             self.model.config.hidden_size,
             dtype=torch.bfloat16,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
         masks = torch.zeros(
             len(text_list),
             self.max_length,
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
 
         def split_string(s):
@@ -158,7 +159,7 @@ User Prompt:'''
             else:
                 token_list.append(token_each)
 
-        new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
+        new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
 
         new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
 
@@ -167,15 +168,15 @@ User Prompt:'''
         inputs.input_ids = (
             torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
             .unsqueeze(0)
-            .to("cuda")
+            .to(get_device_type())
         )
-        inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
+        inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
         outputs = self.model_forward(
             self.model,
             input_ids=inputs.input_ids,
             attention_mask=inputs.attention_mask,
-            pixel_values=inputs.pixel_values.to("cuda"),
-            image_grid_thw=inputs.image_grid_thw.to("cuda"),
+            pixel_values=inputs.pixel_values.to(get_device_type()),
+            image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
             output_hidden_states=True,
         )
 
@@ -188,7 +189,7 @@ User Prompt:'''
         masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
             (min(self.max_length, emb.shape[1] - 217)),
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
 
         return embs, masks
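Here torch.cuda.current_device() becomes get_torch_device().current_device(),
so the helper presumably returns the backend module itself rather than a
device string. Its definition is also outside this diff; a plausible sketch
of the contract (an assumption):

    import torch

    def get_torch_device():
        # Hypothetical: return the torch backend namespace so module-level
        # calls such as current_device() dispatch to the right accelerator.
        if hasattr(torch, "npu") and torch.npu.is_available():
            return torch.npu
        return torch.cuda

    device_index = get_torch_device().current_device()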
diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index 43cd601..cfee258 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -94,7 +94,6 @@ def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
-    freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
     x_out = torch.view_as_real(x_out * freqs).flatten(2)
     return x_out.to(x.dtype)
 
diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py
index f157f38..bb49067 100644
--- a/diffsynth/models/z_image_dit.py
+++ b/diffsynth/models/z_image_dit.py
@@ -8,7 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
 from torch.nn import RMSNorm
 
 from ..core.attention import attention_forward
-from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
 from ..core.gradient import gradient_checkpoint_forward
 
 
@@ -40,7 +40,7 @@ class TimestepEmbedder(nn.Module):
     @staticmethod
     def timestep_embedding(t, dim, max_period=10000):
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(get_device_type(), enabled=False):
             half = dim // 2
             freqs = torch.exp(
                 -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
             )
@@ -105,7 +105,7 @@ class Attention(torch.nn.Module):
 
         # Apply RoPE
         def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast("cuda", enabled=False):
+            with torch.amp.autocast(get_device_type(), enabled=False):
                 x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
                 freqs_cis = freqs_cis.unsqueeze(2)
                 x_out = torch.view_as_real(x * freqs_cis).flatten(3)
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index 8b00469..5ecbb20 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -6,6 +6,7 @@ from einops import rearrange
 import numpy as np
 from typing import Union, List, Optional, Tuple
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -18,7 +19,7 @@ from ..models.flux2_vae import Flux2VAE
 
 
 class Flux2ImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -42,7 +43,7 @@ class Flux2ImagePipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,
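With the defaults switched, pipeline construction is backend-agnostic and
call sites need no Ascend-specific changes. A minimal illustration against
the Flux2ImagePipeline signature from this hunk (model loading elided):

    import torch
    from diffsynth.pipelines.flux2_image import Flux2ImagePipeline

    # device defaults to get_device_type(): "npu" on Ascend hosts,
    # otherwise whatever the helper falls back to (e.g. "cuda").
    pipe = Flux2ImagePipeline(torch_dtype=torch.bfloat16)

    # An explicit device still overrides the default:
    pipe_cpu = Flux2ImagePipeline(device="cpu", torch_dtype=torch.float32)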
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index 1ee5635..bfc53e5 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -6,6 +6,7 @@ from einops import rearrange, repeat
 import numpy as np
 from transformers import CLIPTokenizer, T5TokenizerFast
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
 
 
 class FluxImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
         tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
         text_encoder_2,
         prompt,
         positive=True,
-        device="cuda",
+        device=get_device_type(),
         t5_sequence_length=512,
     ):
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
         text_encoder_2,
         prompt,
         positive=True,
-        device="cuda",
+        device=get_device_type(),
         t5_sequence_length=512,
     ):
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
 
 
 class InfinitYou(torch.nn.Module):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__()
         from facexlib.recognition import init_recognition_model
         from insightface.app import FaceAnalysis
diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index 4bfa00e..75cfbee 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -6,6 +6,7 @@ from einops import rearrange
 import numpy as np
 from math import prod
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -22,7 +23,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
 
 
 class QwenImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -60,7 +61,7 @@ class QwenImagePipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
         processor_config: ModelConfig = None,
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index ca59d2a..866ac18 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -11,6 +11,7 @@ from typing import Optional
 from typing_extensions import Literal
 from transformers import Wav2Vec2Processor
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
 
 class WanVideoPipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
             time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
         audio_processor_config: ModelConfig = None,
@@ -960,7 +961,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
             onload_model_names=("vae",)
         )
 
-    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
+    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
         if mask_pixel_values is None:
             msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
         else:
diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py
index 9ba182a..2c5b687 100644
--- a/diffsynth/pipelines/z_image.py
+++ b/diffsynth/pipelines/z_image.py
@@ -6,6 +6,7 @@ from einops import rearrange
 import numpy as np
 from typing import Union, List, Optional, Tuple, Iterable, Dict
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..core.data.operators import ImageCropAndResize
@@ -25,7 +26,7 @@ from ..models.z_image_image2lora import ZImageImage2LoRAModel
 
 
 class ZImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -58,7 +59,7 @@ class ZImagePipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,
diff --git a/diffsynth/utils/controlnet/annotator.py b/diffsynth/utils/controlnet/annotator.py
index 06553e0..cb73738 100644
--- a/diffsynth/utils/controlnet/annotator.py
+++ b/diffsynth/utils/controlnet/annotator.py
@@ -1,12 +1,13 @@
 from typing_extensions import Literal, TypeAlias
+from diffsynth.core.device.npu_compatible_device import get_device_type
 
 Processor_id: TypeAlias = Literal[
     "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
 ]
 
 
 class Annotator:
-    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
+    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
         if not skip_processor:
             if processor_id == "canny":
                 from controlnet_aux.processor import CannyDetector
diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py
index d365cfe..4a1cd14 100644
--- a/diffsynth/utils/xfuser/xdit_context_parallel.py
+++ b/diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -50,7 +50,6 @@ def rope_apply(x, freqs, num_heads):
     sp_rank = get_sequence_parallel_rank()
     freqs = pad_freqs(freqs, s_per_rank * sp_size)
     freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
-    freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank
     x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
     return x_out.to(x.dtype)
 
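Two hunks above (wan_video_dit.py and xdit_context_parallel.py) delete the
NPU-only complex64 casts in rope_apply rather than abstracting them;
presumably the cast is no longer needed, but since it changes numerics on
Ascend it is worth a second look. Everything else activates implicitly:
diffsynth/core/__init__.py imports npu_patch, whose __init__ calls
npu_autocast_patch() when IS_NPU_AVAILABLE is true, so a downstream script
only needs the usual import:

    # Importing the core package applies the NPU autocast patch on Ascend
    # hosts; no explicit setup call is required.
    import diffsynth.core  # noqa: F401
    import torch

    # From here on, torch.autocast / torch.amp.autocast are the wrapped
    # versions on NPU, and pipeline device defaults resolve via get_device_type().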