Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-18 22:08:13 +00:00
Merge pull request #1256 from Feng0w0/npu_fused
[model][NPU]: Add NPU fusion operator patch to the Z-Image model to improve performance
diffsynth/core/npu_patch/npu_fused_operator.py (new file, 30 lines added)
@@ -0,0 +1,30 @@
+import torch
+from ..device.npu_compatible_device import get_device_type
+try:
+    import torch_npu
+except:
+    pass
+
+
+def rms_norm_forward_npu(self, hidden_states):
+    "npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py"
+    if hidden_states.dtype != self.weight.dtype:
+        hidden_states = hidden_states.to(self.weight.dtype)
+    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
+
+
+def rms_norm_forward_transformers_npu(self, hidden_states):
+    "npu rms fused operator for transformers"
+    if hidden_states.dtype != self.weight.dtype:
+        hidden_states = hidden_states.to(self.weight.dtype)
+    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+
+
+def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
+    "npu rope fused operator for Zimage"
+    with torch.amp.autocast(get_device_type(), enabled=False):
+        freqs_cis = freqs_cis.unsqueeze(2)
+        cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
+        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
+        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
+        return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
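
For reference, the eager computation that torch_npu.npu_rms_norm fuses into a single kernel is the standard RMSNorm. Below is a minimal sketch of it (not repository code; the hidden_states/weight/eps names simply mirror the patch above):

import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize each vector by its root mean square, then apply the learned per-channel scale.
    variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
    return hidden_states * torch.rsqrt(variance + eps) * weight

The fused kernel performs the reduction and the scaling in one pass instead of materializing the intermediate tensors, and it returns a tuple whose first element is the normalized output, which is why the patched forwards index [0].
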
@@ -88,6 +88,14 @@ class Attention(torch.nn.Module):
         self.norm_q = RMSNorm(head_dim, eps=1e-5)
         self.norm_k = RMSNorm(head_dim, eps=1e-5)
 
+    # Apply RoPE
+    def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+        with torch.amp.autocast(get_device_type(), enabled=False):
+            x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+            freqs_cis = freqs_cis.unsqueeze(2)
+            x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+            return x_out.type_as(x_in)  # todo
+
     def forward(self, hidden_states, freqs_cis, attention_mask):
         query = self.to_q(hidden_states)
         key = self.to_k(hidden_states)
@@ -103,17 +111,9 @@ class Attention(torch.nn.Module):
         if self.norm_k is not None:
             key = self.norm_k(key)
 
-        # Apply RoPE
-        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast(get_device_type(), enabled=False):
-                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-                freqs_cis = freqs_cis.unsqueeze(2)
-                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-                return x_out.type_as(x_in)  # todo
-
         if freqs_cis is not None:
-            query = apply_rotary_emb(query, freqs_cis)
-            key = apply_rotary_emb(key, freqs_cis)
+            query = self.apply_rotary_emb(query, freqs_cis)
+            key = self.apply_rotary_emb(key, freqs_cis)
 
         # Cast to correct dtype
         dtype = query.dtype
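
The patched rotary_emb_Zimage_npu has to reproduce the complex-multiply apply_rotary_emb shown above. The following CPU-only sanity sketch (not repository code; torch_npu is not needed) builds the interleaved cos/sin tensors exactly as the patch does and applies a plain-PyTorch stand-in for torch_npu.npu_rotary_mul(..., rotary_mode="interleave"):

import torch

B, L, H, D = 2, 8, 4, 64  # hypothetical shapes: batch, tokens, heads, head_dim
x_in = torch.randn(B, L, H, D)
angles = torch.randn(B, L, D // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex, shape (B, L, D/2)

# Complex-multiply path, as in Attention.apply_rotary_emb.
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
ref = torch.view_as_real(x * freqs_cis.unsqueeze(2)).flatten(3)

# Interleaved cos/sin tensors, built as in rotary_emb_Zimage_npu.
cos, sin = torch.chunk(torch.view_as_real(freqs_cis.unsqueeze(2)), 2, dim=-1)
cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)  # (B, L, 1, D): each cos value repeated per channel pair
sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)

# Plain-PyTorch stand-in for the fused interleaved rotary multiply.
x_even, x_odd = x_in[..., 0::2], x_in[..., 1::2]
rotated = torch.stack((-x_odd, x_even), dim=-1).flatten(-2)  # (-x1, x0, -x3, x2, ...)
out = x_in * cos + rotated * sin

assert torch.allclose(ref, out, atol=1e-5)

The interleaved layout matches view_as_complex, which pairs adjacent channels, so duplicating each cos/sin entry across a channel pair gives the fused operator the same rotation that the complex multiply performs.
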
@@ -1,4 +1,4 @@
-import torch, math
+import torch, math, warnings
 from PIL import Image
 from typing import Union
 from tqdm import tqdm
@@ -6,7 +6,7 @@ from einops import rearrange
 import numpy as np
 from typing import Union, List, Optional, Tuple, Iterable, Dict
 
-from ..core.device.npu_compatible_device import get_device_type
+from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..core.data.operators import ImageCropAndResize
@@ -63,6 +63,7 @@ class ZImagePipeline(BasePipeline):
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,
+        enable_npu_patch: bool = True,
     ):
         # Initialize pipeline
         pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
@@ -84,6 +85,8 @@ class ZImagePipeline(BasePipeline):
 
         # VRAM Management
         pipe.vram_management_enabled = pipe.check_vram_management_state()
+        # NPU patch
+        apply_npu_patch(enable_npu_patch)
         return pipe
 
 
@@ -667,3 +670,19 @@ def model_fn_z_image_turbo(
     x = rearrange(x, "C B H W -> B C H W")
     x = -x
     return x
+
+
+def apply_npu_patch(enable_npu_patch: bool=True):
+    if IS_NPU_AVAILABLE and enable_npu_patch:
+        from ..models.general_modules import RMSNorm
+        from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
+        from ..models.z_image_dit import Attention
+        from ..core.npu_patch.npu_fused_operator import (
+            rms_norm_forward_npu,
+            rms_norm_forward_transformers_npu,
+            rotary_emb_Zimage_npu
+        )
+        warnings.warn("Replacing RMSNorm and RoPE with NPU fusion operators to improve the performance of the model on NPU. Set enable_npu_patch=False to disable this feature.")
+        RMSNorm.forward = rms_norm_forward_npu
+        Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
+        Attention.apply_rotary_emb = rotary_emb_Zimage_npu
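
A hypothetical usage sketch of the new flag follows. The import path and the classmethod name are assumptions (they are not shown in this diff); the torch_dtype, device, model_configs, and enable_npu_patch parameters come from the signature above.

import torch
from diffsynth.pipelines.z_image import ZImagePipeline  # assumed import path

pipe = ZImagePipeline.from_pretrained(  # assumed classmethod name
    torch_dtype=torch.bfloat16,
    device="npu",            # torch_npu device string
    model_configs=[...],     # fill in the Z-Image ModelConfig entries (omitted here)
    enable_npu_patch=True,   # default; set False to keep the eager RMSNorm/RoPE implementations
)

Because apply_npu_patch rebinds forward and apply_rotary_emb on the classes themselves (RMSNorm, Qwen3RMSNorm, Attention), the patch affects every instance in the process, existing and future, not just this pipeline.
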