mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
[model][NPU] Add NPU fused-operator patch to the Z-Image model to improve performance
This commit is contained in:
@@ -13,4 +13,5 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_
|
||||
--output_path "./models/train/Z-Image-Turbo_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8
|
||||
--dataset_num_workers 8 \
|
||||
--enable_npu_patch
|
||||
|
||||
@@ -20,12 +20,13 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
offload_models=None,
|
||||
device="cpu",
|
||||
task="sft",
|
||||
enable_npu_patch=True,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||
self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||
self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch)
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
# Training mode
|
||||
@@ -94,6 +95,7 @@ def z_image_parser():
|
||||
parser = add_general_config(parser)
|
||||
parser = add_image_size_config(parser)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||
parser.add_argument("--enable_npu_patch", default=False, action="store_true", help="Whether to use npu fused operator patch to improve performance in NPU.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -136,6 +138,7 @@ if __name__ == "__main__":
|
||||
offload_models=args.offload_models,
|
||||
task=args.task,
|
||||
device=accelerator.device,
|
||||
enable_npu_patch=args.enable_npu_patch
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
|
||||
Reference in New Issue
Block a user