mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
[NPU]:Replace 'cuda' in the project with abstract interfaces
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
from .qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..core.device.npu_compatible_device import get_device_type, get_torch_device
|
||||
|
||||
|
||||
class Step1xEditEmbedder(torch.nn.Module):
|
||||
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
||||
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.dtype = dtype
|
||||
@@ -77,13 +78,13 @@ User Prompt:'''
|
||||
self.max_length,
|
||||
self.model.config.hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=torch.cuda.current_device(),
|
||||
device=get_torch_device().current_device(),
|
||||
)
|
||||
masks = torch.zeros(
|
||||
len(text_list),
|
||||
self.max_length,
|
||||
dtype=torch.long,
|
||||
device=torch.cuda.current_device(),
|
||||
device=get_torch_device().current_device(),
|
||||
)
|
||||
|
||||
def split_string(s):
|
||||
@@ -158,7 +159,7 @@ User Prompt:'''
|
||||
else:
|
||||
token_list.append(token_each)
|
||||
|
||||
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
|
||||
new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
|
||||
|
||||
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
|
||||
|
||||
@@ -167,15 +168,15 @@ User Prompt:'''
|
||||
inputs.input_ids = (
|
||||
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
|
||||
.unsqueeze(0)
|
||||
.to("cuda")
|
||||
.to(get_device_type())
|
||||
)
|
||||
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
|
||||
inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
|
||||
outputs = self.model_forward(
|
||||
self.model,
|
||||
input_ids=inputs.input_ids,
|
||||
attention_mask=inputs.attention_mask,
|
||||
pixel_values=inputs.pixel_values.to("cuda"),
|
||||
image_grid_thw=inputs.image_grid_thw.to("cuda"),
|
||||
pixel_values=inputs.pixel_values.to(get_device_type()),
|
||||
image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
@@ -188,7 +189,7 @@ User Prompt:'''
|
||||
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
|
||||
(min(self.max_length, emb.shape[1] - 217)),
|
||||
dtype=torch.long,
|
||||
device=torch.cuda.current_device(),
|
||||
device=get_torch_device().current_device(),
|
||||
)
|
||||
|
||||
return embs, masks
|
||||
|
||||
Reference in New Issue
Block a user