[NPU]:Replace 'cuda' in the project with abstract interfaces

This commit is contained in:
feng0w0
2026-01-15 20:33:01 +08:00
parent f4d06ce3fc
commit 209a350c0f
18 changed files with 86 additions and 38 deletions

View File

@@ -2,6 +2,8 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch
from diffsynth.core.device.npu_compatible_device import get_device_type
class Siglip2ImageEncoder(SiglipVisionTransformer):
def __init__(self):
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
output_attentions = False