qwen-image split training

This commit is contained in:
Artiprocher
2025-09-02 16:44:14 +08:00
parent 260e32217f
commit b6da77e468
7 changed files with 221 additions and 14 deletions

View File

@@ -174,9 +174,12 @@ class QwenImagePipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device="cuda",
)
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
if self.text_encoder is not None:
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
if self.dit is not None:
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
if self.vae is not None:
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):

View File

@@ -214,7 +214,7 @@ class LoadTorchPickle(DataProcessingOperator):
self.map_location = map_location
def __call__(self, data):
return torch.load(data, map_location=self.map_location)
return torch.load(data, map_location=self.map_location, weights_only=False)
@@ -306,7 +306,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
def __getitem__(self, data_id):
if self.load_from_cache:
data = self.cached_data[data_id % len(self.data)].copy()
data = self.cached_data[data_id % len(self.cached_data)]
data = self.cached_data_operator(data)
else:
data = self.data[data_id % len(self.data)].copy()

View File

@@ -417,6 +417,13 @@ class DiffusionTrainingModule(torch.nn.Module):
state_dict_[name] = param
state_dict = state_dict_
return state_dict
def transfer_data_to_device(self, data, device):
    """Move every ``torch.Tensor`` value in *data* onto *device*, in place.

    Non-tensor values are left untouched. The same dict object is
    returned so the call can be chained.
    """
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            data[key] = value.to(device)
    return data
@@ -484,7 +491,10 @@ def launch_training_task(
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(data)
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps)
@@ -494,16 +504,24 @@ def launch_training_task(
model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
def launch_data_process_task(
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
num_workers: int = 8,
):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
accelerator = Accelerator()
model, dataloader = accelerator.prepare(model, dataloader)
os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
for data_id, data in enumerate(tqdm(dataloader)):
with torch.no_grad():
inputs = model.forward_preprocess(data)
inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
for data_id, data in tqdm(enumerate(dataloader)):
with accelerator.accumulate(model):
with torch.no_grad():
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
os.makedirs(folder, exist_ok=True)
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
data = model(data)
torch.save(data, save_path)