Merge pull request #1256 from Feng0w0/npu_fused

[model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance
This commit is contained in:
Zhongjie Duan
2026-02-09 20:08:44 +08:00
committed by GitHub
5 changed files with 67 additions and 14 deletions

View File

@@ -0,0 +1,30 @@
import torch
from ..device.npu_compatible_device import get_device_type
# torch_npu is an optional dependency: a failed import is tolerated so that
# this module can still be imported, and the fused operators below are only
# usable when the import succeeded.
# Catch Exception instead of using a bare ``except`` so that KeyboardInterrupt
# and SystemExit are not silently swallowed during import.
try:
    import torch_npu
except Exception:
    # Keep the name bound so the module always defines ``torch_npu``.
    torch_npu = None
def rms_norm_forward_npu(self, hidden_states):
    """NPU fused RMSNorm, intended as a replacement for ``RMSNorm.forward``.

    Targets the ``RMSNorm`` module in diffsynth/models/general_modules.py.

    Args:
        hidden_states: Input tensor. It is cast to ``self.weight.dtype`` when
            its dtype differs from the weight's dtype.

    Returns:
        The first element of the result of ``torch_npu.npu_rms_norm`` called
        with the (possibly cast) input, ``self.weight`` and ``self.eps``.
    """
    # The fused operator is fed input and weight in the same dtype.
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    # Only element 0 of the operator's result is returned to the caller.
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
def rms_norm_forward_transformers_npu(self, hidden_states):
    """NPU fused RMSNorm forward for transformers-style norm modules.

    Same as the fused RMSNorm for project modules, except that the epsilon is
    read from ``self.variance_epsilon``.

    Args:
        hidden_states: Input tensor; cast to the weight's dtype when it differs.

    Returns:
        The first element of the result of ``torch_npu.npu_rms_norm``.
    """
    needs_cast = hidden_states.dtype != self.weight.dtype
    if needs_cast:
        # Bring the input to the weight's dtype before the fused call.
        hidden_states = hidden_states.to(self.weight.dtype)
    fused_outputs = torch_npu.npu_rms_norm(
        hidden_states,
        self.weight,
        self.variance_epsilon,
    )
    return fused_outputs[0]
def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
    """NPU fused rotary position embedding (RoPE) for Zimage.

    Args:
        x_in: Tensor to rotate; passed unchanged to ``torch_npu.npu_rotary_mul``.
        freqs_cis: Complex tensor (it is passed to ``torch.view_as_real``).
            It must be 3-D here, since after ``unsqueeze(2)`` and
            ``view_as_real`` it is expanded with five sizes.

    Returns:
        The result of ``torch_npu.npu_rotary_mul`` in "interleave" mode, cast
        with ``.to(x_in)`` to the dtype and device of ``x_in``.
    """
    # NOTE(review): every statement is assumed to sit inside the
    # autocast-disabled region — confirm against the original file layout.
    with torch.amp.autocast(get_device_type(), enabled=False):
        # Insert a new axis at position 2 of freqs_cis.
        freqs_cis = freqs_cis.unsqueeze(2)
        # view_as_real appends a trailing axis of size 2 (real, imag);
        # chunking it yields the real part as cos and the imaginary part as sin.
        cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
        # Repeat each value twice along the last axis and merge the last two
        # axes, so every value appears in two adjacent positions.
        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
        return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)

View File

@@ -88,6 +88,14 @@ class Attention(torch.nn.Module):
self.norm_q = RMSNorm(head_dim, eps=1e-5) self.norm_q = RMSNorm(head_dim, eps=1e-5)
self.norm_k = RMSNorm(head_dim, eps=1e-5) self.norm_k = RMSNorm(head_dim, eps=1e-5)
# Apply RoPE
def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``x_in`` via complex multiplication.

    Args:
        x_in: Input tensor; its last axis is split into pairs that are viewed
            as complex numbers.
        freqs_cis: Tensor multiplied with the complex view of ``x_in`` after a
            new axis is inserted at position 2.

    Returns:
        The rotated tensor, cast back to the dtype of ``x_in``.
    """
    # NOTE(review): every statement is assumed to sit inside the
    # autocast-disabled region — confirm against the original file layout.
    with torch.amp.autocast(get_device_type(), enabled=False):
        # Convert to float32 and pair up the last axis as (real, imag).
        x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        # Multiply in the complex domain, then merge axes 3 onward back into one.
        x_out = torch.view_as_real(x * freqs_cis).flatten(3)
        return x_out.type_as(x_in)  # todo
def forward(self, hidden_states, freqs_cis, attention_mask): def forward(self, hidden_states, freqs_cis, attention_mask):
query = self.to_q(hidden_states) query = self.to_q(hidden_states)
key = self.to_k(hidden_states) key = self.to_k(hidden_states)
@@ -103,17 +111,9 @@ class Attention(torch.nn.Module):
if self.norm_k is not None: if self.norm_k is not None:
key = self.norm_k(key) key = self.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None: if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis) query = self.apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis) key = self.apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype # Cast to correct dtype
dtype = query.dtype dtype = query.dtype

View File

@@ -1,4 +1,4 @@
import torch, math import torch, math, warnings
from PIL import Image from PIL import Image
from typing import Union from typing import Union
from tqdm import tqdm from tqdm import tqdm
@@ -6,7 +6,7 @@ from einops import rearrange
import numpy as np import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
from ..diffusion import FlowMatchScheduler from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize from ..core.data.operators import ImageCropAndResize
@@ -63,6 +63,7 @@ class ZImagePipeline(BasePipeline):
model_configs: list[ModelConfig] = [], model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None, vram_limit: float = None,
enable_npu_patch: bool = True,
): ):
# Initialize pipeline # Initialize pipeline
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype) pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
@@ -84,6 +85,8 @@ class ZImagePipeline(BasePipeline):
# VRAM Management # VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state() pipe.vram_management_enabled = pipe.check_vram_management_state()
# NPU patch
apply_npu_patch(enable_npu_patch)
return pipe return pipe
@@ -667,3 +670,19 @@ def model_fn_z_image_turbo(
x = rearrange(x, "C B H W -> B C H W") x = rearrange(x, "C B H W -> B C H W")
x = -x x = -x
return x return x
def apply_npu_patch(enable_npu_patch: bool=True):
    """Replace RMSNorm and RoPE implementations with NPU fused operators.

    Does nothing unless ``IS_NPU_AVAILABLE`` is truthy and
    ``enable_npu_patch`` is truthy. Otherwise it emits a warning and assigns
    the fused functions onto the classes themselves, so the replacement
    applies to every instance of those classes.

    Args:
        enable_npu_patch: Set to False to leave the original implementations.
    """
    if not (IS_NPU_AVAILABLE and enable_npu_patch):
        return

    # Imports are deferred so they only run when the patch is actually applied.
    from ..models.general_modules import RMSNorm
    from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
    from ..models.z_image_dit import Attention
    from ..core.npu_patch.npu_fused_operator import (
        rms_norm_forward_npu,
        rms_norm_forward_transformers_npu,
        rotary_emb_Zimage_npu
    )

    warnings.warn("Replacing RMSNorm and Rope with NPU fusion operators to improve the performance of the model on NPU.Set enable_npu_patch=False to disable this feature.")

    # (class to patch, attribute name, fused replacement)
    replacements = (
        (RMSNorm, "forward", rms_norm_forward_npu),
        (Qwen3RMSNorm, "forward", rms_norm_forward_transformers_npu),
        (Attention, "apply_rotary_emb", rotary_emb_Zimage_npu),
    )
    for patched_cls, attr_name, fused_fn in replacements:
        setattr(patched_cls, attr_name, fused_fn)

View File

@@ -13,4 +13,5 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_
--output_path "./models/train/Z-Image-Turbo_full" \ --output_path "./models/train/Z-Image-Turbo_full" \
--trainable_models "dit" \ --trainable_models "dit" \
--use_gradient_checkpointing \ --use_gradient_checkpointing \
--dataset_num_workers 8 --dataset_num_workers 8 \
--enable_npu_patch

View File

@@ -20,12 +20,13 @@ class ZImageTrainingModule(DiffusionTrainingModule):
offload_models=None, offload_models=None,
device="cpu", device="cpu",
task="sft", task="sft",
enable_npu_patch=True,
): ):
super().__init__() super().__init__()
# Load models # Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode # Training mode
@@ -94,6 +95,7 @@ def z_image_parser():
parser = add_general_config(parser) parser = add_general_config(parser)
parser = add_image_size_config(parser) parser = add_image_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--enable_npu_patch", default=False, action="store_true", help="Whether to use npu fused operator patch to improve performance in NPU.")
return parser return parser
@@ -136,6 +138,7 @@ if __name__ == "__main__":
offload_models=args.offload_models, offload_models=args.offload_models,
task=args.task, task=args.task,
device=accelerator.device, device=accelerator.device,
enable_npu_patch=args.enable_npu_patch
) )
model_logger = ModelLogger( model_logger = ModelLogger(
args.output_path, args.output_path,