Merge pull request #1256 from Feng0w0/npu_fused

[model][NPU]: Add NPU fused operator patch to the Z-Image model to improve performance
This commit is contained in:
Zhongjie Duan
2026-02-09 20:08:44 +08:00
committed by GitHub
5 changed files with 67 additions and 14 deletions

View File

@@ -13,4 +13,5 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_
--output_path "./models/train/Z-Image-Turbo_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--dataset_num_workers 8
--dataset_num_workers 8 \
--enable_npu_patch

View File

@@ -20,12 +20,13 @@ class ZImageTrainingModule(DiffusionTrainingModule):
offload_models=None,
device="cpu",
task="sft",
enable_npu_patch=True,
):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode
@@ -94,6 +95,7 @@ def z_image_parser():
parser = add_general_config(parser)
parser = add_image_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--enable_npu_patch", default=False, action="store_true", help="Whether to use npu fused operator patch to improve performance in NPU.")
return parser
@@ -136,6 +138,7 @@ if __name__ == "__main__":
offload_models=args.offload_models,
task=args.task,
device=accelerator.device,
enable_npu_patch=args.enable_npu_patch
)
model_logger = ModelLogger(
args.output_path,