mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
[NPU]:Replace 'cuda' in the project with abstract interfaces
This commit is contained in:
@@ -11,6 +11,7 @@ from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
from transformers import Wav2Vec2Processor
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
|
||||
|
||||
class WanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
|
||||
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
audio_processor_config: ModelConfig = None,
|
||||
@@ -960,7 +961,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
|
||||
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
|
||||
if mask_pixel_values is None:
|
||||
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user