mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
[NPU]:Replace 'cuda' in the project with abstract interfaces
This commit is contained in:
@@ -6,6 +6,7 @@ from einops import rearrange, repeat
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
|
||||
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
|
||||
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
|
||||
text_encoder_2,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda",
|
||||
device=get_device_type(),
|
||||
t5_sequence_length=512,
|
||||
):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
|
||||
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
text_encoder_2,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda",
|
||||
device=get_device_type(),
|
||||
t5_sequence_length=512,
|
||||
):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
|
||||
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
|
||||
|
||||
|
||||
class InfinitYou(torch.nn.Module):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__()
|
||||
from facexlib.recognition import init_recognition_model
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
Reference in New Issue
Block a user