import os
from typing import Optional

import torch
from tqdm import tqdm

from accelerate import Accelerator

from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: Optional[int] = None,
    num_epochs: int = 1,
    args=None,
):
    """Train ``model`` on ``dataset`` under ``accelerator``, checkpointing via ``model_logger``.

    Args:
        accelerator: Accelerate handle used for device placement, gradient
            accumulation, and distributed coordination.
        dataset: yields single samples; the dataloader's ``collate_fn`` unwraps
            the one-element batch, so the effective batch size is 1 per step.
        model: training module exposing ``trainable_modules()`` and a forward
            pass that returns a scalar loss.
        model_logger: receives per-step, per-epoch, and end-of-training callbacks.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_workers: dataloader worker processes.
        save_steps: if set, ``model_logger`` handles step-based checkpointing and
            the per-epoch save is skipped.
        num_epochs: number of passes over the dataset.
        args: optional namespace; when given, its ``learning_rate``,
            ``weight_decay``, ``dataset_num_workers``, ``save_steps`` and
            ``num_epochs`` override the keyword arguments above.
    """
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(
        model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers
    )
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                # NOTE(review): zero_grad at the top of the accumulate block
                # relies on the Accelerate-wrapped optimizer skipping it on
                # accumulation micro-steps — confirm against the installed
                # accelerate version.
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    # Cached datasets return pre-processed inputs; forward
                    # with an empty data dict. (`load_from_cache` is a
                    # project dataset attribute — presumed boolean.)
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                # ConstantLR makes this effectively a no-op, but stepping the
                # scheduler per optimizer step keeps the loop correct if a
                # decaying schedule is substituted.
                scheduler.step()
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)


def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args=None,
):
    """Run ``model`` over ``dataset`` without gradients and cache each output to disk.

    Each sample's result is saved to
    ``<output_path>/<process_index>/<data_id>.pth`` so concurrent processes
    write to disjoint folders.

    Args:
        accelerator: Accelerate handle (device placement, ``process_index``).
        dataset: yields single samples; ``collate_fn`` unwraps the batch.
        model: module whose forward pass produces the object to cache.
        model_logger: only its ``output_path`` is used here, as the cache root.
        num_workers: dataloader worker processes.
        args: optional namespace; when given, ``args.dataset_num_workers``
            overrides ``num_workers``.
    """
    if args is not None:
        num_workers = args.dataset_num_workers

    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers
    )
    model, dataloader = accelerator.prepare(model, dataloader)

    for data_id, data in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(model):
            with torch.no_grad():
                # One sub-folder per process to avoid filename collisions.
                folder = os.path.join(
                    model_logger.output_path, str(accelerator.process_index)
                )
                os.makedirs(folder, exist_ok=True)
                # Reuse `folder` instead of rebuilding the same path from parts.
                save_path = os.path.join(folder, f"{data_id}.pth")
                data = model(data)
                torch.save(data, save_path)