Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
[NPU]: Replace 'cuda' in the project with abstract interfaces
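The hunks below import get_device_type, get_torch_device, and IS_NPU_AVAILABLE from a core/device/npu_compatible_device module that is not itself part of this diff. As a rough, non-authoritative sketch of what such a helper could look like (assuming the torch_npu package registers the torch.npu namespace when installed):

    import torch

    # Hypothetical sketch of diffsynth/core/device/npu_compatible_device.py;
    # the real module is not shown in this commit.
    try:
        import torch_npu  # Ascend backend; registers torch.npu when present
        IS_NPU_AVAILABLE = torch.npu.is_available()
    except ImportError:
        IS_NPU_AVAILABLE = False

    def get_device_type() -> str:
        # Device string used wherever the code previously hard-coded "cuda".
        return "npu" if IS_NPU_AVAILABLE else "cuda"

    def get_torch_device():
        # Backend module, so get_torch_device().current_device() works on either accelerator.
        return torch.npu if IS_NPU_AVAILABLE else torch.cuda

With an abstraction along these lines, each hunk below swaps a literal "cuda" for the corresponding helper call.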
@@ -2,6 +2,8 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
 from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
 import torch

+from ..core.device.npu_compatible_device import get_device_type
+

 class DINOv3ImageEncoder(DINOv3ViTModel):
     def __init__(self):
@@ -70,7 +72,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
             }
         )

-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         inputs = self.processor(images=image, return_tensors="pt")
         pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
         bool_masked_pos = None

@@ -9,6 +9,7 @@ import numpy as np
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from .wan_video_dit import flash_attention
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
 from ..core.gradient import gradient_checkpoint_forward


@@ -373,7 +374,9 @@ class FinalLayer_FP32(nn.Module):
         B, N, C = x.shape
         T, _, _ = latent_shape

-        with amp.autocast('cuda', dtype=torch.float32):
+        with amp.autocast(get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
             x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
             x = self.linear(x)
@@ -583,7 +586,9 @@ class LongCatSingleStreamBlock(nn.Module):
         T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.

         # compute modulation params in fp32
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             shift_msa, scale_msa, gate_msa, \
             shift_mlp, scale_mlp, gate_mlp = \
                 self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +607,9 @@ class LongCatSingleStreamBlock(nn.Module):
         else:
             x_s = attn_outputs

-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
         x = x.to(x_dtype)

@@ -615,7 +622,9 @@ class LongCatSingleStreamBlock(nn.Module):
         # ffn with modulation
         x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
         x_s = self.ffn(x_m)
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
         x = x.to(x_dtype)

@@ -797,7 +806,9 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):

         hidden_states = self.x_embedder(hidden_states) # [B, N, C]

-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]

         encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]

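The autocast hunks above all follow the same pattern: the hard-coded 'cuda' device type becomes get_device_type(), and on Ascend the autocast state is re-enabled explicitly via torch.npu.set_autocast_enabled(True). A minimal standalone sketch of that pattern, assuming the helpers sketched earlier and an available CUDA or NPU device:

    import torch
    from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type

    device = get_device_type()            # "npu" on Ascend hosts, "cuda" otherwise (assumed)
    x = torch.randn(2, 4, device=device)  # same call site on both backends

    with torch.amp.autocast(device_type=device, dtype=torch.float32):
        if IS_NPU_AVAILABLE:
            torch.npu.set_autocast_enabled(True)  # mirrors the guard added in the hunks above
        y = (x.float() * 2.0).sum()       # fp32 modulation math stays in fp32 under autocast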
@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
         is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
         is_compileable = is_compileable and not self.generation_config.disable_compile
         if is_compileable and (
-            self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
+            self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
         ):
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
             model_forward = self.get_compiled_call(generation_config.compile_config)

@@ -2,6 +2,8 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
 from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
 import torch

+from diffsynth.core.device.npu_compatible_device import get_device_type
+

 class Siglip2ImageEncoder(SiglipVisionTransformer):
     def __init__(self):
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
             }
         )

-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
         pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
         output_attentions = False

@@ -1,10 +1,11 @@
 import torch
 from typing import Optional, Union
 from .qwen_image_text_encoder import QwenImageTextEncoder
+from ..core.device.npu_compatible_device import get_device_type, get_torch_device


 class Step1xEditEmbedder(torch.nn.Module):
-    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
+    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
         super().__init__()
         self.max_length = max_length
         self.dtype = dtype
@@ -77,13 +78,13 @@ User Prompt:'''
             self.max_length,
             self.model.config.hidden_size,
             dtype=torch.bfloat16,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
         masks = torch.zeros(
             len(text_list),
             self.max_length,
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )

         def split_string(s):
@@ -158,7 +159,7 @@ User Prompt:'''
             else:
                 token_list.append(token_each)

-            new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
+            new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())

             new_txt_ids = new_txt_ids.to(old_inputs_ids.device)

@@ -167,15 +168,15 @@ User Prompt:'''
             inputs.input_ids = (
                 torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
                 .unsqueeze(0)
-                .to("cuda")
+                .to(get_device_type())
             )
-            inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
+            inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
             outputs = self.model_forward(
                 self.model,
                 input_ids=inputs.input_ids,
                 attention_mask=inputs.attention_mask,
-                pixel_values=inputs.pixel_values.to("cuda"),
-                image_grid_thw=inputs.image_grid_thw.to("cuda"),
+                pixel_values=inputs.pixel_values.to(get_device_type()),
+                image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
                 output_hidden_states=True,
             )

@@ -188,7 +189,7 @@ User Prompt:'''
             masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
                 (min(self.max_length, emb.shape[1] - 217)),
                 dtype=torch.long,
-                device=torch.cuda.current_device(),
+                device=get_torch_device().current_device(),
             )

         return embs, masks

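In the Step1xEditEmbedder hunks above, torch.cuda.current_device() becomes get_torch_device().current_device() and every .to("cuda") becomes .to(get_device_type()), so buffer allocation and tensor movement resolve to the active backend at runtime. A hedged usage sketch (placeholder sizes, not the real model dimensions; requires a CUDA or NPU device):

    import torch
    from diffsynth.core.device.npu_compatible_device import get_device_type, get_torch_device

    max_length, hidden_size = 640, 8    # hidden_size here is a placeholder, not the model's

    # Padded embedding/mask buffers land on the active accelerator's current device.
    embs = torch.zeros(1, max_length, hidden_size, dtype=torch.bfloat16,
                       device=get_torch_device().current_device())
    masks = torch.zeros(1, max_length, dtype=torch.long,
                        device=get_torch_device().current_device())

    # Token ids move with the abstract device string instead of a hard-coded "cuda".
    token_ids = torch.arange(max_length).unsqueeze(0).to(get_device_type())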
@@ -94,7 +94,6 @@ def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
-    freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
     x_out = torch.view_as_real(x_out * freqs).flatten(2)
     return x_out.to(x.dtype)


@@ -8,7 +8,7 @@ from torch.nn.utils.rnn import pad_sequence

 from torch.nn import RMSNorm
 from ..core.attention import attention_forward
-from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
 from ..core.gradient import gradient_checkpoint_forward


@@ -40,7 +40,7 @@ class TimestepEmbedder(nn.Module):

     @staticmethod
     def timestep_embedding(t, dim, max_period=10000):
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(get_device_type(), enabled=False):
             half = dim // 2
             freqs = torch.exp(
                 -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
@@ -105,7 +105,7 @@ class Attention(torch.nn.Module):

         # Apply RoPE
         def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast("cuda", enabled=False):
+            with torch.amp.autocast(get_device_type(), enabled=False):
                 x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
                 freqs_cis = freqs_cis.unsqueeze(2)
                 x_out = torch.view_as_real(x * freqs_cis).flatten(3)