mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
[NPU]:Replace 'cuda' in the project with abstract interfaces
This commit is contained in:
@@ -2,6 +2,8 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
|
||||
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
||||
import torch
|
||||
|
||||
from diffsynth.core.device.npu_compatible_device import get_device_type
|
||||
|
||||
|
||||
class Siglip2ImageEncoder(SiglipVisionTransformer):
|
||||
def __init__(self):
|
||||
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
|
||||
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
|
||||
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
|
||||
output_attentions = False
|
||||
|
||||
Reference in New Issue
Block a user