support num_workers, save_steps, find_unused_parameters

Author: mi804
Date: 2025-08-06 10:52:59 +08:00
Parent: 8d2f6ad32e
Commit: 6bae70eee0
13 changed files with 71 additions and 16 deletions


@@ -4,6 +4,7 @@ from PIL import Image
 import pandas as pd
 from tqdm import tqdm
 from accelerate import Accelerator
+from accelerate.utils import DistributedDataParallelKwargs
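
Note: DistributedDataParallelKwargs is accelerate's kwargs-handler plugin; every field set on it (here only find_unused_parameters) is forwarded to torch.nn.parallel.DistributedDataParallel when Accelerator wraps the model. A minimal standalone sketch, independent of this diff:

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# The handler's fields are passed through to torch's DistributedDataParallel
# when accelerator.prepare() wraps the model.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])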
@@ -369,34 +370,42 @@ class ModelLogger:
     def on_step_end(self, loss):
         pass
-    def on_epoch_end(self, accelerator, model, epoch_id):
+    def on_model_save(self, accelerator, model, step_id=None, epoch_id=None):
         accelerator.wait_for_everyone()
         if accelerator.is_main_process:
             state_dict = accelerator.get_state_dict(model)
             state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
             state_dict = self.state_dict_converter(state_dict)
             os.makedirs(self.output_path, exist_ok=True)
-            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
+            if step_id is not None:
+                path = os.path.join(self.output_path, f"step-{step_id}.safetensors")
+            else:
+                path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
             accelerator.save(state_dict, path, safe_serialization=True)
 
 def launch_training_task(
     dataset: torch.utils.data.Dataset,
     model: DiffusionTrainingModule,
     model_logger: ModelLogger,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LRScheduler,
+    num_workers: int = 8,
+    save_steps: int = None,
     num_epochs: int = 1,
     gradient_accumulation_steps: int = 1,
+    find_unused_parameters: bool = False,
 ):
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
-    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
+    accelerator = Accelerator(
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
+    )
     model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
     for epoch_id in range(num_epochs):
-        for data in tqdm(dataloader):
+        for step_id, data in enumerate(tqdm(dataloader)):
             with accelerator.accumulate(model):
                 optimizer.zero_grad()
                 loss = model(data)
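
A sketch of a call site exercising the new keyword arguments; dataset, model, model_logger, optimizer and scheduler are placeholders assumed to be built elsewhere, not part of this commit:

launch_training_task(
    dataset, model, model_logger, optimizer, scheduler,
    num_workers=4,                # DataLoader worker processes
    save_steps=500,               # checkpoint every 500 global steps
    num_epochs=2,
    gradient_accumulation_steps=1,
    find_unused_parameters=True,  # forwarded to DDP via kwargs_handlers
)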
@@ -404,8 +413,16 @@ def launch_training_task(
                 optimizer.step()
                 model_logger.on_step_end(loss)
                 scheduler.step()
-        model_logger.on_epoch_end(accelerator, model, epoch_id)
+                global_steps = epoch_id * len(dataloader) + step_id + 1
+                # save every `save_steps` steps
+                if save_steps is not None and global_steps % save_steps == 0:
+                    model_logger.on_model_save(accelerator, model, step_id=global_steps)
+        # save the model at the end of each epoch if save_steps is None
+        if save_steps is None:
+            model_logger.on_model_save(accelerator, model, epoch_id=epoch_id)
+    # save the final model if save_steps is not None
+    if save_steps is not None:
+        model_logger.on_model_save(accelerator, model, step_id=global_steps)
 
 def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
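
The global step counter is recomputed from the epoch and in-epoch indices rather than tracked separately. For example, with len(dataloader) == 100, epoch_id == 1 and step_id == 49:

# 100 steps from epoch 0 plus 50 steps into epoch 1 -> global step 150;
# with save_steps=50 this writes "step-150.safetensors".
global_steps = 1 * 100 + 49 + 1  # == 150
assert global_steps % 50 == 0

Note that the trailing save after the loop reuses the last global_steps value, so if training ends exactly on a save boundary the final write simply overwrites the checkpoint already saved at that step.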
@@ -446,6 +463,9 @@ def wan_parser():
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
     parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
     parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
+    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
+    parser.add_argument("--save_steps", type=int, default=None, help="Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.")
+    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
     return parser
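
To see how the three new flags parse, here is a standalone replica of just the added add_argument calls (the real wan_parser() may also enforce required arguments that a bare invocation would trip over):

import argparse

# Replicates only the three flags added in this hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--find_unused_parameters", default=False, action="store_true")
parser.add_argument("--save_steps", type=int, default=None)
parser.add_argument("--dataset_num_workers", type=int, default=0)

args = parser.parse_args(["--save_steps", "500", "--find_unused_parameters"])
assert args.save_steps == 500 and args.find_unused_parameters
assert args.dataset_num_workers == 0  # default when the flag is omitted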
@@ -474,6 +494,9 @@ def flux_parser():
     parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
     parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
+    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
+    parser.add_argument("--save_steps", type=int, default=None, help="Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.")
+    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
     return parser
@@ -503,4 +526,7 @@ def qwen_image_parser():
     parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
     parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
+    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
+    parser.add_argument("--save_steps", type=int, default=None, help="Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.")
+    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
     return parser
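
One naming wrinkle worth flagging: the CLI flag is --dataset_num_workers while the function parameter is num_workers, so the training script has to map one onto the other. A hedged sketch of that wiring (dataset, model, model_logger, optimizer and scheduler are again placeholders built elsewhere):

args = qwen_image_parser().parse_args()
launch_training_task(
    dataset, model, model_logger, optimizer, scheduler,
    num_workers=args.dataset_num_workers,  # note the name mismatch
    save_steps=args.save_steps,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    find_unused_parameters=args.find_unused_parameters,
)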