Merge pull request #1207 from Feng0w0/cuda_replace

[NPU]: Replace 'cuda' in the project with abstract interfaces
Zhongjie Duan (committed by GitHub), 2026-01-20 10:13:04 +08:00
13 changed files with 49 additions and 36 deletions
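
The changes below all route through a small device-abstraction module, diffsynth/core/device/npu_compatible_device.py, which exposes get_device_type(), get_torch_device(), and IS_NPU_AVAILABLE. As a rough sketch (assuming Ascend support comes from the optional torch_npu package; the module's actual contents may differ), it behaves like this:

import torch

try:
    # Importing torch_npu registers the "npu" backend on torch (Huawei Ascend).
    import torch_npu  # noqa: F401
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False


def get_device_type() -> str:
    # Preferred accelerator as a plain string, usable both as a default
    # `device` argument and as the `device_type` for torch.amp.autocast.
    if IS_NPU_AVAILABLE:
        return "npu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def get_torch_device():
    # The backend-specific device module (torch.npu or torch.cuda), so callers
    # can write get_torch_device().current_device() uniformly.
    return torch.npu if IS_NPU_AVAILABLE else torch.cuda

Note that where get_device_type() appears as a default argument in the diffs below, it is evaluated once at import time, so the default binds to whichever accelerator is visible when the module loads; callers can still pass an explicit device.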

View File

@@ -4,6 +4,7 @@ import numpy as np
 from einops import repeat, reduce
 from typing import Union
 from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core.device.npu_compatible_device import get_device_type
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
@@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module):
     def __init__(
         self,
-        device="cuda", torch_dtype=torch.float16,
+        device=get_device_type(), torch_dtype=torch.float16,
         height_division_factor=64, width_division_factor=64,
         time_division_factor=None, time_division_remainder=None,
     ):

View File

@@ -2,6 +2,8 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
 from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
 import torch
+from ..core.device.npu_compatible_device import get_device_type
+
 class DINOv3ImageEncoder(DINOv3ViTModel):
     def __init__(self):
@@ -70,7 +72,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
         }
     )
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         inputs = self.processor(images=image, return_tensors="pt")
         pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
         bool_masked_pos = None

View File

@@ -9,6 +9,7 @@ import numpy as np
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from .wan_video_dit import flash_attention
+from ..core.device.npu_compatible_device import get_device_type
 from ..core.gradient import gradient_checkpoint_forward
@@ -373,7 +374,7 @@ class FinalLayer_FP32(nn.Module):
         B, N, C = x.shape
         T, _, _ = latent_shape
-        with amp.autocast('cuda', dtype=torch.float32):
+        with amp.autocast(get_device_type(), dtype=torch.float32):
             shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1)  # [B, T, 1, C]
         x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
         x = self.linear(x)
@@ -583,7 +584,7 @@ class LongCatSingleStreamBlock(nn.Module):
         T, _, _ = latent_shape  # S != T*H*W in case of CP split on H*W.
         # compute modulation params in fp32
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
             shift_msa, scale_msa, gate_msa, \
             shift_mlp, scale_mlp, gate_mlp = \
                 self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1)  # [B, T, 1, C]
@@ -602,7 +603,7 @@ class LongCatSingleStreamBlock(nn.Module):
         else:
             x_s = attn_outputs
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
             x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C)  # [B, N, C]
         x = x.to(x_dtype)
@@ -615,7 +616,7 @@ class LongCatSingleStreamBlock(nn.Module):
         # ffn with modulation
         x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
         x_s = self.ffn(x_m)
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
             x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C)  # [B, N, C]
         x = x.to(x_dtype)
@@ -797,7 +798,7 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
         hidden_states = self.x_embedder(hidden_states)  # [B, N, C]
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
             t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1)  # [B, T, C_t]
         encoder_hidden_states = self.y_embedder(encoder_hidden_states)  # [B, 1, N_token, C]

View File

@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
         is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
         is_compileable = is_compileable and not self.generation_config.disable_compile
         if is_compileable and (
-            self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
+            self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
         ):
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
             model_forward = self.get_compiled_call(generation_config.compile_config)

View File

@@ -2,6 +2,8 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
 from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
 import torch
+from diffsynth.core.device.npu_compatible_device import get_device_type
+
 class Siglip2ImageEncoder(SiglipVisionTransformer):
     def __init__(self):
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
         }
     )
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
         pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
         output_attentions = False

View File

@@ -1,10 +1,11 @@
 import torch
 from typing import Optional, Union
 from .qwen_image_text_encoder import QwenImageTextEncoder
+from ..core.device.npu_compatible_device import get_device_type, get_torch_device
 
 class Step1xEditEmbedder(torch.nn.Module):
-    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
+    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
         super().__init__()
         self.max_length = max_length
         self.dtype = dtype
@@ -77,13 +78,13 @@ User Prompt:'''
             self.max_length,
             self.model.config.hidden_size,
             dtype=torch.bfloat16,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
         masks = torch.zeros(
             len(text_list),
             self.max_length,
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
         def split_string(s):
@@ -158,7 +159,7 @@ User Prompt:'''
             else:
                 token_list.append(token_each)
-        new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
+        new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
         new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
@@ -167,15 +168,15 @@ User Prompt:'''
         inputs.input_ids = (
             torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
             .unsqueeze(0)
-            .to("cuda")
+            .to(get_device_type())
         )
-        inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
+        inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
         outputs = self.model_forward(
             self.model,
             input_ids=inputs.input_ids,
             attention_mask=inputs.attention_mask,
-            pixel_values=inputs.pixel_values.to("cuda"),
-            image_grid_thw=inputs.image_grid_thw.to("cuda"),
+            pixel_values=inputs.pixel_values.to(get_device_type()),
+            image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
             output_hidden_states=True,
         )
@@ -188,7 +189,7 @@ User Prompt:'''
         masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
             (min(self.max_length, emb.shape[1] - 217)),
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
         return embs, masks

View File

@@ -8,7 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
 from .general_modules import RMSNorm
 from ..core.attention import attention_forward
-from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
 from ..core.gradient import gradient_checkpoint_forward
@@ -40,7 +40,7 @@ class TimestepEmbedder(nn.Module):
     @staticmethod
     def timestep_embedding(t, dim, max_period=10000):
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(get_device_type(), enabled=False):
             half = dim // 2
             freqs = torch.exp(
                 -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
@@ -105,7 +105,7 @@ class Attention(torch.nn.Module):
         # Apply RoPE
         def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast("cuda", enabled=False):
+            with torch.amp.autocast(get_device_type(), enabled=False):
                 x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
                 freqs_cis = freqs_cis.unsqueeze(2)
                 x_out = torch.view_as_real(x * freqs_cis).flatten(3)

View File

@@ -6,6 +6,7 @@ from einops import rearrange
 import numpy as np
 from typing import Union, List, Optional, Tuple
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -19,7 +20,7 @@ from ..models.z_image_text_encoder import ZImageTextEncoder
 class Flux2ImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -45,7 +46,7 @@ class Flux2ImagePipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,

View File

@@ -6,6 +6,7 @@ from einops import rearrange, repeat
 import numpy as np
 from transformers import CLIPTokenizer, T5TokenizerFast
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
 class FluxImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
         tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
         text_encoder_2,
         prompt,
         positive=True,
-        device="cuda",
+        device=get_device_type(),
         t5_sequence_length=512,
     ):
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
         text_encoder_2,
         prompt,
         positive=True,
-        device="cuda",
+        device=get_device_type(),
         t5_sequence_length=512,
     ):
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
 class InfinitYou(torch.nn.Module):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__()
         from facexlib.recognition import init_recognition_model
         from insightface.app import FaceAnalysis

View File

@@ -6,6 +6,7 @@ from einops import rearrange
 import numpy as np
 from math import prod
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -22,7 +23,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
 class QwenImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -60,7 +61,7 @@ class QwenImagePipeline(BasePipeline):
     @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
         processor_config: ModelConfig = None,

View File

@@ -11,6 +11,7 @@ from typing import Optional
 from typing_extensions import Literal
 from transformers import Wav2Vec2Processor
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
 class WanVideoPipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
         audio_processor_config: ModelConfig = None,
@@ -964,7 +965,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
         onload_model_names=("vae",)
     )
-    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
+    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
         if mask_pixel_values is None:
             msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
         else:

View File

@@ -6,6 +6,7 @@ from einops import rearrange
 import numpy as np
 from typing import Union, List, Optional, Tuple, Iterable, Dict
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..core.data.operators import ImageCropAndResize
@@ -25,7 +26,7 @@ from ..models.z_image_image2lora import ZImageImage2LoRAModel
 class ZImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -58,7 +59,7 @@ class ZImagePipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,

View File

@@ -1,12 +1,13 @@
 from typing_extensions import Literal, TypeAlias
+from diffsynth.core.device.npu_compatible_device import get_device_type
 
 Processor_id: TypeAlias = Literal[
     "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
 ]
 
 class Annotator:
-    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
+    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
         if not skip_processor:
             if processor_id == "canny":
                 from controlnet_aux.processor import CannyDetector