[NPU]:Replace 'cuda' in the project with abstract interfaces

This commit is contained in:
feng0w0
2026-01-15 20:33:01 +08:00
parent f4d06ce3fc
commit 209a350c0f
18 changed files with 86 additions and 38 deletions

View File

@@ -6,6 +6,7 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -18,7 +19,7 @@ from ..models.flux2_vae import Flux2VAE
class Flux2ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -42,7 +43,7 @@ class Flux2ImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit: float = None,

View File

@@ -6,6 +6,7 @@ from einops import rearrange, repeat
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
class InfinitYou(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis

View File

@@ -6,6 +6,7 @@ from einops import rearrange
import numpy as np
from math import prod
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -22,7 +23,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
class QwenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -60,7 +61,7 @@ class QwenImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,

View File

@@ -11,6 +11,7 @@ from typing import Optional
from typing_extensions import Literal
from transformers import Wav2Vec2Processor
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
class WanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
audio_processor_config: ModelConfig = None,
@@ -960,7 +961,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
onload_model_names=("vae",)
)
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else:

View File

@@ -6,6 +6,7 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
@@ -25,7 +26,7 @@ from ..models.z_image_image2lora import ZImageImage2LoRAModel
class ZImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -58,7 +59,7 @@ class ZImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,