[model][NPU] Add NPU fusion operator patch to Zimage model to improve performance

This commit is contained in:
feng0w0
2026-02-03 19:50:21 +08:00
parent ffb7a138f7
commit 051b957adb
5 changed files with 66 additions and 13 deletions

View File

@@ -6,7 +6,7 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type
from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
@@ -63,6 +63,7 @@ class ZImagePipeline(BasePipeline):
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
enable_npu_patch: bool = True,
):
# Initialize pipeline
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
@@ -84,6 +85,8 @@ class ZImagePipeline(BasePipeline):
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
# NPU patch
apply_npu_patch(enable_npu_patch)
return pipe
@@ -667,3 +670,19 @@ def model_fn_z_image_turbo(
x = rearrange(x, "C B H W -> B C H W")
x = -x
return x
def apply_npu_patch(enable_npu_patch: bool=True):
if IS_NPU_AVAILABLE and enable_npu_patch:
from ..models.general_modules import RMSNorm
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from ..models.z_image_dit import Attention
from ..core.npu_patch.npu_fused_operator import (
rms_norm_forward_npu,
rms_norm_forward_transformers_npu,
rotary_emb_Zimage_npu
)
RMSNorm.forward = rms_norm_forward_npu
Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
Attention.apply_rotary_emb = rotary_emb_Zimage_npu