From 4e9db263b0ec953dc3f12833cd343d09b20441dc Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Tue, 27 Jan 2026 11:24:43 +0800 Subject: [PATCH] [feature]:Add adaptation of all models to zero3 --- .../core/gradient/gradient_checkpoint.py | 2 + diffsynth/core/loader/model.py | 49 ++++++++++++++----- diffsynth/diffusion/logger.py | 4 +- diffsynth/diffusion/runner.py | 3 +- diffsynth/models/wan_video_vae.py | 36 +++++++------- diffsynth/pipelines/flux2_image.py | 2 +- .../full/accelerate_config_zero3.yaml | 23 +++++++++ .../full/accelerate_config_zero3.yaml | 23 +++++++++ .../npu_training/FLUX.2-dev-Lora-NPU.sh | 36 ++++++++++++++ .../npu_training/FLUX.2-klein-9B-NPU.sh | 34 +++++++++++++ examples/flux2/model_training/train.py | 3 +- .../full/accelerate_config_zero3.yaml | 23 +++++++++ .../npu_training/Qwen-Image-Edit-2509-NPU.sh | 16 ++++++ .../full/accelerate_config_zero3.yaml | 23 +++++++++ .../full/accelerate_config_zero3.yaml | 23 +++++++++ 15 files changed, 266 insertions(+), 34 deletions(-) create mode 100644 examples/flux/model_training/full/accelerate_config_zero3.yaml create mode 100644 examples/flux2/model_training/full/accelerate_config_zero3.yaml create mode 100644 examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh create mode 100644 examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh create mode 100644 examples/qwen_image/model_training/full/accelerate_config_zero3.yaml create mode 100644 examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh create mode 100644 examples/wanvideo/model_training/full/accelerate_config_zero3.yaml create mode 100644 examples/z_image/model_training/full/accelerate_config_zero3.yaml diff --git a/diffsynth/core/gradient/gradient_checkpoint.py b/diffsynth/core/gradient/gradient_checkpoint.py index b356415..d252573 100644 --- a/diffsynth/core/gradient/gradient_checkpoint.py +++ b/diffsynth/core/gradient/gradient_checkpoint.py @@ -21,6 +21,7 @@ def 
gradient_checkpoint_forward( *args, **kwargs, use_reentrant=False, + determinism_check="none" ) elif use_gradient_checkpointing: model_output = torch.utils.checkpoint.checkpoint( @@ -28,6 +29,7 @@ def gradient_checkpoint_forward( *args, **kwargs, use_reentrant=False, + determinism_check="none" ) else: model_output = model(*args, **kwargs) diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py index 56fa7d3..5d9b052 100644 --- a/diffsynth/core/loader/model.py +++ b/diffsynth/core/loader/model.py @@ -3,21 +3,24 @@ from ..vram.disk_map import DiskMap from ..vram.layers import enable_vram_management from .file import load_state_dict import torch +from contextlib import contextmanager +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils import ContextManagers -def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): +def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, + use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): config = {} if config is None else config - # Why do we use `skip_model_initialization`? - # It skips the random initialization of model parameters, - # thereby speeding up model loading and avoiding excessive memory usage. - with skip_model_initialization(): + with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)): model = model_class(**config) # What is `module_map`? # This is a module mapping table for VRAM management. 
if module_map is not None: - devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]] + devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], + vram_config["computation_device"]] device = [d for d in devices if d != "disk"][0] - dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] + dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], + vram_config["computation_dtype"]] dtype = [d for d in dtypes if d != "disk"][0] if vram_config["offload_device"] != "disk": state_dict = DiskMap(path, device, torch_dtype=dtype) @@ -26,10 +29,12 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic else: state_dict = {i: state_dict[i] for i in state_dict} model.load_state_dict(state_dict, assign=True) - model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, + vram_limit=vram_limit) else: disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) - model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, + vram_limit=vram_limit) else: # Why do we use `DiskMap`? 
# Sometimes a model file contains multiple models, @@ -46,7 +51,11 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic state_dict = state_dict_converter(state_dict) else: state_dict = {i: state_dict[i] for i in state_dict} - model.load_state_dict(state_dict, assign=True) + if is_deepspeed_zero3_enabled(): + from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model + _load_state_dict_into_zero3_model(model, state_dict) + else: + model.load_state_dict(state_dict, assign=True) # Why do we call `to()`? # Because some models override the behavior of `to()`, # especially those from libraries like Transformers. @@ -56,7 +65,8 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic return model -def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None): +def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", + state_dict_converter=None, module_map=None): if isinstance(path, str): path = [path] config = {} if config is None else config @@ -77,3 +87,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor } enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80) return model + + +def get_init_context(torch_dtype, device): + if is_deepspeed_zero3_enabled(): + from transformers.modeling_utils import set_zero3_state + import deepspeed + # Why do we use "deepspeed.zero.Init"? + # Model weights can be partitioned on the CPU side, + # and the partitioned weights are then loaded onto the compute device. + init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()] + else: + # Why do we use `skip_model_initialization`? 
+ # It skips the random initialization of model parameters, + # thereby speeding up model loading and avoiding excessive memory usage. + init_contexts = [skip_model_initialization()] + + return init_contexts diff --git a/diffsynth/diffusion/logger.py b/diffsynth/diffusion/logger.py index 6d2792f..ab6bdb9 100644 --- a/diffsynth/diffusion/logger.py +++ b/diffsynth/diffusion/logger.py @@ -18,8 +18,8 @@ class ModelLogger: def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id): accelerator.wait_for_everyone() + state_dict = accelerator.get_state_dict(model) if accelerator.is_main_process: - state_dict = accelerator.get_state_dict(model) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = self.state_dict_converter(state_dict) os.makedirs(self.output_path, exist_ok=True) @@ -34,8 +34,8 @@ class ModelLogger: def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name): accelerator.wait_for_everyone() + state_dict = accelerator.get_state_dict(model) if accelerator.is_main_process: - state_dict = accelerator.get_state_dict(model) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = self.state_dict_converter(state_dict) os.makedirs(self.output_path, exist_ok=True) diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py index f6e2263..6e26035 100644 --- a/diffsynth/diffusion/runner.py +++ b/diffsynth/diffusion/runner.py @@ -27,7 +27,7 @@ def launch_training_task( optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) - + model.to(device=accelerator.device) model, optimizer, dataloader, scheduler = 
accelerator.prepare(model, optimizer, dataloader, scheduler) for epoch_id in range(num_epochs): @@ -59,6 +59,7 @@ def launch_data_process_task( num_workers = args.dataset_num_workers dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers) + model.to(device=accelerator.device) model, dataloader = accelerator.prepare(model, dataloader) for data_id, data in enumerate(tqdm(dataloader)): diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index d24e29d..e43e6a8 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -171,7 +171,7 @@ class Resample(nn.Module): torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x feat_idx[0] += 1 - return x + return x, feat_cache, feat_idx def init_weight(self, conv): conv_weight = conv.weight @@ -298,7 +298,7 @@ class ResidualBlock(nn.Module): feat_idx[0] += 1 else: x = layer(x) - return x + h + return x + h, feat_cache, feat_idx class AttentionBlock(nn.Module): @@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module): for module in self.downsamples: x = module(x, feat_cache, feat_idx) - return x + self.avg_shortcut(x_copy) + return x + self.avg_shortcut(x_copy), feat_cache, feat_idx class Up_ResidualBlock(nn.Module): @@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module): x_shortcut = self.avg_shortcut(x, first_chunk) return x_main + x_shortcut else: - return x_main + return x_main, feat_cache, feat_idx class Encoder3d(nn.Module): @@ -586,14 +586,14 @@ class Encoder3d(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) ## middle for layer in self.middle: if check_is_instance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) @@ 
-614,7 +614,7 @@ class Encoder3d(nn.Module): feat_idx[0] += 1 else: x = layer(x) - return x + return x, feat_cache, feat_idx class Encoder3d_38(nn.Module): @@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) ## middle for layer in self.middle: if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) @@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module): else: x = layer(x) - return x + return x, feat_cache, feat_idx class Decoder3d(nn.Module): @@ -807,14 +807,14 @@ class Decoder3d(nn.Module): ## middle for layer in self.middle: if check_is_instance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) ## upsamples for layer in self.upsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) @@ -835,7 +835,7 @@ class Decoder3d(nn.Module): feat_idx[0] += 1 else: x = layer(x) - return x + return x, feat_cache, feat_idx @@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module): for layer in self.middle: if check_is_instance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) ## upsamples for layer in self.upsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx, first_chunk) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk) else: x = layer(x) @@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module): feat_idx[0] += 1 else: x = layer(x) - return x + return x, feat_cache, feat_idx def count_conv3d(model): @@ -990,11 +990,11 @@ class 
VideoVAE_(nn.Module): for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder(x[:, :, :1, :, :], + out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: - out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) out = torch.cat([out, out_], 2) diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index d5dc35b..c68dcb9 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -348,7 +348,7 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit): attention_mask = torch.cat(all_attention_masks, dim=0).to(device) # Forward pass through the model - with torch.inference_mode(): + with torch.no_grad(): output = text_encoder( input_ids=input_ids, attention_mask=attention_mask, diff --git a/examples/flux/model_training/full/accelerate_config_zero3.yaml b/examples/flux/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/flux/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux2/model_training/full/accelerate_config_zero3.yaml b/examples/flux2/model_training/full/accelerate_config_zero3.yaml new file mode 
100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/flux2/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh b/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh new file mode 100644 index 0000000..ed678f2 --- /dev/null +++ b/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh @@ -0,0 +1,36 @@ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export CPU_AFFINITY_CONF=1 + +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blo
cks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --task "sft:data_process" + +accelerate launch --config_file examples/flux2/model_training/full/accelerate_config_zero3.yaml examples/flux2/model_training/train.py \ + --dataset_base_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:transformer/*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-dev-LoRA-splited" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out
,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --initialize_model_on_cpu \ + --task "sft:train" diff --git a/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh b/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh new file mode 100644 index 0000000..57755ac --- /dev/null +++ b/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh @@ -0,0 +1,34 @@ +# This script is tested on 8*910B(NPU) +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export CPU_AFFINITY_CONF=1 + +accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths 
"black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-klein-9B_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing + +# Edit +# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ +# --data_file_keys "image,edit_image" \ +# --extra_inputs "edit_image" \ +# --max_pixels 1048576 \ +# --dataset_repeat 50 \ +# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ +# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ +# --learning_rate 1e-5 \ +# --num_epochs 2 \ +# --remove_prefix_in_ckpt "pipe.dit." 
\ +# --output_path "./models/train/FLUX.2-klein-9B_full" \ +# --trainable_models "dit" \ +# --use_gradient_checkpointing diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index ea727b8..6101687 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -85,6 +85,7 @@ def flux2_parser(): parser = add_general_config(parser) parser = add_image_size_config(parser) parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") return parser @@ -126,7 +127,7 @@ if __name__ == "__main__": fp8_models=args.fp8_models, offload_models=args.offload_models, task=args.task, - device=accelerator.device, + device="cpu" if args.initialize_model_on_cpu else accelerator.device, ) model_logger = ModelLogger( args.output_path, diff --git a/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml b/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh b/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh new file mode 
100644 index 0000000..24a7a58 --- /dev/null +++ b/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh @@ -0,0 +1,16 @@ +# This script was tested using zero3 and on 8*910B(NPU) +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Edit-2509_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml b/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/z_image/model_training/full/accelerate_config_zero3.yaml 
b/examples/z_image/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/z_image/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false