mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #1256 from Feng0w0/npu_fused
[model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance
This commit is contained in:
30
diffsynth/core/npu_patch/npu_fused_operator.py
Normal file
30
diffsynth/core/npu_patch/npu_fused_operator.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
from ..device.npu_compatible_device import get_device_type
|
||||
try:
|
||||
import torch_npu
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def rms_norm_forward_npu(self, hidden_states):
|
||||
"npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py"
|
||||
if hidden_states.dtype != self.weight.dtype:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
|
||||
|
||||
|
||||
def rms_norm_forward_transformers_npu(self, hidden_states):
|
||||
"npu rms fused operator for transformers"
|
||||
if hidden_states.dtype != self.weight.dtype:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
|
||||
|
||||
|
||||
def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
|
||||
"npu rope fused operator for Zimage"
|
||||
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
|
||||
cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
|
||||
sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
|
||||
return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
|
||||
@@ -88,6 +88,14 @@ class Attention(torch.nn.Module):
|
||||
self.norm_q = RMSNorm(head_dim, eps=1e-5)
|
||||
self.norm_k = RMSNorm(head_dim, eps=1e-5)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
|
||||
def forward(self, hidden_states, freqs_cis, attention_mask):
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
@@ -103,17 +111,9 @@ class Attention(torch.nn.Module):
|
||||
if self.norm_k is not None:
|
||||
key = self.norm_k(key)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
query = self.apply_rotary_emb(query, freqs_cis)
|
||||
key = self.apply_rotary_emb(key, freqs_cis)
|
||||
|
||||
# Cast to correct dtype
|
||||
dtype = query.dtype
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch, math
|
||||
import torch, math, warnings
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
@@ -6,7 +6,7 @@ from einops import rearrange
|
||||
import numpy as np
|
||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..core.data.operators import ImageCropAndResize
|
||||
@@ -63,6 +63,7 @@ class ZImagePipeline(BasePipeline):
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
vram_limit: float = None,
|
||||
enable_npu_patch: bool = True,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
@@ -84,6 +85,8 @@ class ZImagePipeline(BasePipeline):
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
# NPU patch
|
||||
apply_npu_patch(enable_npu_patch)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -667,3 +670,19 @@ def model_fn_z_image_turbo(
|
||||
x = rearrange(x, "C B H W -> B C H W")
|
||||
x = -x
|
||||
return x
|
||||
|
||||
|
||||
def apply_npu_patch(enable_npu_patch: bool=True):
|
||||
if IS_NPU_AVAILABLE and enable_npu_patch:
|
||||
from ..models.general_modules import RMSNorm
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
|
||||
from ..models.z_image_dit import Attention
|
||||
from ..core.npu_patch.npu_fused_operator import (
|
||||
rms_norm_forward_npu,
|
||||
rms_norm_forward_transformers_npu,
|
||||
rotary_emb_Zimage_npu
|
||||
)
|
||||
warnings.warn("Replacing RMSNorm and Rope with NPU fusion operators to improve the performance of the model on NPU.Set enable_npu_patch=False to disable this feature.")
|
||||
RMSNorm.forward = rms_norm_forward_npu
|
||||
Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
|
||||
Attention.apply_rotary_emb = rotary_emb_Zimage_npu
|
||||
|
||||
@@ -13,4 +13,5 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_
|
||||
--output_path "./models/train/Z-Image-Turbo_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8
|
||||
--dataset_num_workers 8 \
|
||||
--enable_npu_patch
|
||||
|
||||
@@ -20,12 +20,13 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
offload_models=None,
|
||||
device="cpu",
|
||||
task="sft",
|
||||
enable_npu_patch=True,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||
self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||
self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch)
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
# Training mode
|
||||
@@ -94,6 +95,7 @@ def z_image_parser():
|
||||
parser = add_general_config(parser)
|
||||
parser = add_image_size_config(parser)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||
parser.add_argument("--enable_npu_patch", default=False, action="store_true", help="Whether to use npu fused operator patch to improve performance in NPU.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -136,6 +138,7 @@ if __name__ == "__main__":
|
||||
offload_models=args.offload_models,
|
||||
task=args.task,
|
||||
device=accelerator.device,
|
||||
enable_npu_patch=args.enable_npu_patch
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
|
||||
Reference in New Issue
Block a user