From 1ed676b0767ae9c37a32a14de87ad870dcb48a1c Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Tue, 6 May 2025 17:53:56 +0800
Subject: [PATCH] train

---
 diffsynth/pipelines/flux_image.py |  52 ++--
 test.py                           |  99 ++++++--
 train.py                          | 408 ++++++++++++++++++++++++++++++
 3 files changed, 517 insertions(+), 42 deletions(-)
 create mode 100644 train.py

diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index 33b1140..5c459a8 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -680,6 +680,7 @@ def lets_dance_flux(
     step1x_mask=None,
     step1x_reference_latents=None,
     tea_cache: TeaCache = None,
+    use_gradient_checkpointing=False,
     **kwargs
 ):
     if tiled:
@@ -774,20 +775,32 @@ def lets_dance_flux(
         tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
     else:
         tea_cache_update = False
+
+    def create_custom_forward(module):
+        def custom_forward(*inputs):
+            return module(*inputs)
+        return custom_forward
 
     if tea_cache_update:
         hidden_states = tea_cache.update(hidden_states)
     else:
         # Joint Blocks
         for block_id, block in enumerate(dit.blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -796,14 +809,21 @@ def lets_dance_flux(
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         num_joint_blocks = len(dit.blocks)
         for block_id, block in enumerate(dit.single_blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
diff --git a/test.py b/test.py
index 8661c4d..5e0a073 100644
--- a/test.py
+++ b/test.py
@@ -2,9 +2,10 @@ from transformers import AutoConfig, AutoTokenizer
 import torch
 from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
-from diffsynth import ModelManager, FluxImagePipeline, load_state_dict
+from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
 from qwen_vl_utils import smart_resize
 from PIL import Image
+import numpy as np
 
 
 
@@ -15,6 +16,8 @@ class NexusGenQwenVLEncoder(torch.nn.Module):
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
         self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
+        self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
+        self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
 
     @staticmethod
     def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
@@ -35,34 +38,62 @@ class NexusGenQwenVLEncoder(torch.nn.Module):
             images[j] = input_image.resize((resized_width, resized_height))
         return images
 
-    def forward(self, prompt, images=None):
-        messages = [{
-            "role": "user",
-            "content": [{
-                "type": "text",
-                "text": prompt
-            },],
-        }]
-        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    def forward(self, prompt, images=None, num_img_tokens=81):
+        messages = [
+            {
+                "role": "user",
+                "content": [{
+                    "type": "text",
+                    "text": prompt
+                },],
+            },
+            {
+                "role": "assistant",
+                "content": [{
+                    "type": "text",
+                    "text": self.t2i_template if images is None else self.i2i_template
+                },],
+            }
+        ]
+        images = self.process_images(images)
+        target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
+        if images is None:
+            images = [target_image]
+        else:
+            images = images + [target_image]
+        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
         inputs = self.processor(
             text=[text],
-            images=self.process_images(images),
+            images=images,
             padding=True,
             return_tensors="pt",
         )
         inputs = inputs.to(self.model.device)
-        generation_image_grid_thw = torch.tensor([[1, 18, 18]]).to(self.model.device)
-        outputs = self.model.generate(**inputs,
-                                      max_new_tokens=1024,
-                                      return_dict_in_generate=True,
-                                      generation_image_grid_thw=generation_image_grid_thw)
-        output_image_embeddings = outputs['output_image_embeddings']
+
+        input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
+        image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
+        ground_truth_image_embeds = image_embeds[-num_img_tokens:]
+        input_image_embeds = image_embeds[:-num_img_tokens]
+
+        image_mask = inputs['input_ids'] == self.model.config.image_token_id
+        indices = image_mask.cumsum(dim=1)
+        input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
+        gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
+        input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
+        input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
+
+        position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
+                                                    inputs['image_grid_thw'],
+                                                    attention_mask=inputs['attention_mask'])
+        position_ids = position_ids.contiguous()
+        outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
+        output_image_embeddings = outputs.image_embeddings[:, :-1, :]  # shift right
+        output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
+        output_image_embeddings = output_image_embeddings.unsqueeze(0)
         return output_image_embeddings
-
-
 
 model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
 model_manager.load_models([
     "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
@@ -70,16 +101,32 @@ model_manager.load_models([
     "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
 ])
 pipe = FluxImagePipeline.from_model_manager(model_manager)
-pipe.dit.load_state_dict(load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16), strict=False)
+
+state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
+pipe.dit.load_state_dict(state_dict, strict=False)
+
+adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
+adapter.load_state_dict(state_dict, strict=False)
 
 qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
-adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
-adapter.load_state_dict(load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16), strict=False)
-
 with torch.no_grad():
-    instruction = "<|vision_start|><|image_pad|><|vision_end|> Transform the style to flat anime. Keep the color."
-    emb = qwenvl(instruction, images=[Image.open("image_3.jpg").convert('RGB')])
+    instruction = "Generate an image according to the following description: a beautiful Asian girl"
+    emb = qwenvl(instruction, images=None)
     emb = adapter(emb)
     image = pipe("", image_emb=emb)
-    image.save("image_4.jpg")
+    image.save("image_1.jpg")
+
+with torch.no_grad():
+    instruction = "<|vision_start|><|image_pad|><|vision_end|> Add sunglasses."
+    emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
+    emb = adapter(emb)
+    image = pipe("", image_emb=emb)
+    image.save("image_2.jpg")
+
+with torch.no_grad():
+    instruction = "<|vision_start|><|image_pad|><|vision_end|> Let her smile."
+    emb = qwenvl(instruction, images=[Image.open("image_2.jpg")])
+    emb = adapter(emb)
+    image = pipe("", image_emb=emb)
+    image.save("image_3.jpg")
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..aaa11e7
--- /dev/null
+++ b/train.py
@@ -0,0 +1,408 @@
+from diffsynth import ModelManager, FluxImagePipeline, load_state_dict
+from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
+from diffsynth.models.lora import FluxLoRAConverter
+import torch, os, argparse
+from diffsynth.pipelines.flux_image import lets_dance_flux
+from accelerate import Accelerator
+from tqdm import tqdm
+import torch, os, json, torchvision
+from PIL import Image
+from torchvision.transforms import v2
+from transformers import AutoConfig, AutoTokenizer
+import torch
+from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
+from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
+from qwen_vl_utils import smart_resize
+from PIL import Image
+import numpy as np
+import lightning as pl
+os.environ["TOKENIZERS_PARALLELISM"] = "True"
+
+
+
+class SingleTaskDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        base_path,
+        keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
+        height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
+    ):
+        self.base_path = base_path
+        self.keys = keys
+        self.metadata = []
+        self.bad_data = []
+        self.height = height
+        self.width = width
+        self.random = random
+        self.steps_per_epoch = steps_per_epoch
+        self.image_process = v2.Compose([
+            v2.CenterCrop(size=(height, width)),
+            v2.ToImage(),
+            v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        ])
+        if metadata_path is None:
+            self.search_for_data("", report_data_log=True)
+            self.report_data_log()
+        else:
+            with open(metadata_path, "r", encoding="utf-8-sig") as f:
+                self.metadata = json.load(f)
+
+
+    def report_data_log(self):
+        print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
+
+
+    def dump_metadata(self, path):
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(self.metadata, f, ensure_ascii=False, indent=4)
+
+
+    def parse_json_file(self, absolute_path, relative_path):
+        data_list = []
+        with open(absolute_path, "r") as f:
+            metadata = json.load(f)
+        for image_1, image_2, instruction in self.keys:
+            image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
+            image_2 = os.path.join(relative_path, metadata[image_2])
+            instruction = metadata[instruction]
+            data_list.append((image_1, image_2, instruction))
+        return data_list
+
+
+    def search_for_data(self, path, report_data_log=False):
+        now_path = os.path.join(self.base_path, path)
+        if os.path.isfile(now_path) and path.endswith(".json"):
+            try:
+                data_list = self.parse_json_file(now_path, os.path.dirname(path))
+                self.metadata.extend(data_list)
+            except:
+                self.bad_data.append(now_path)
+        elif os.path.isdir(now_path):
+            for sub_path in os.listdir(now_path):
+                self.search_for_data(os.path.join(path, sub_path))
+                if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
+                    self.report_data_log()
+
+
+    def load_image(self, image_path, skip_process=False):
+        image_path = os.path.join(self.base_path, image_path)
+        image = Image.open(image_path).convert("RGB")
+        if skip_process:
+            return image
+        width, height = image.size
+        scale = max(self.width / width, self.height / height)
+        image = torchvision.transforms.functional.resize(
+            image,
+            (round(height*scale), round(width*scale)),
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
+        )
+        image = self.image_process(image)
+        return image
+
+
+    def load_data(self, data_id):
+        image_1, image_2, instruction = self.metadata[data_id]
+        image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
+        image_2 = self.load_image(image_2)
+        return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
+
+
+    def __getitem__(self, data_id):
+        if self.random:
+            while True:
+                try:
+                    data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
+                    data = self.load_data(data_id)
+                    return data
+                except:
+                    continue
+        else:
+            return self.load_data(data_id)
+
+
+    def __len__(self):
+        return self.steps_per_epoch if self.random else len(self.metadata)
+
+
+
+class MultiTaskDataset(torch.utils.data.Dataset):
+    def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
+        self.dataset_list = dataset_list
+        self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
+        self.steps_per_epoch = steps_per_epoch
+
+
+    def __getitem__(self, data_id):
+        dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
+        data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
+        data = self.dataset_list[dataset_id][data_id]
+        return data
+
+
+    def __len__(self):
+        return self.steps_per_epoch
+
+
+
+class NexusGenQwenVLEncoder(torch.nn.Module):
+    def __init__(self, model_path, torch_dtype="auto", device="cpu"):
+        super().__init__()
+        model_config = AutoConfig.from_pretrained(model_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
+        self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
+        self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
+        self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
+
+    @staticmethod
+    def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
+        return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()
+
+    def process_images(self, images=None):
+        if images is None:
+            return None
+        # resize input to max_pixels to avoid oom
+        for j in range(len(images)):
+            input_image = images[j]
+            input_w, input_h = input_image.size
+            resized_height, resized_width = smart_resize(
+                input_h,
+                input_w,
+                max_pixels=262640,
+            )
+            images[j] = input_image.resize((resized_width, resized_height))
+        return images
+
+    def forward(self, prompt, images=None, num_img_tokens=81):
+        messages = [
+            {
+                "role": "user",
+                "content": [{
+                    "type": "text",
+                    "text": prompt
+                },],
+            },
+            {
+                "role": "assistant",
+                "content": [{
+                    "type": "text",
+                    "text": self.t2i_template if images is None else self.i2i_template
+                },],
+            }
+        ]
+        images = self.process_images(images)
+        target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
+        if images is None:
+            images = [target_image]
+        else:
+            images = images + [target_image]
+        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+        inputs = self.processor(
+            text=[text],
+            images=images,
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(self.model.device)
+
+        input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
+        image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
+        ground_truth_image_embeds = image_embeds[-num_img_tokens:]
+        input_image_embeds = image_embeds[:-num_img_tokens]
+
+        image_mask = inputs['input_ids'] == self.model.config.image_token_id
+        indices = image_mask.cumsum(dim=1)
+        input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
+        gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
+        input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
+        input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
+
+        position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
+                                                    inputs['image_grid_thw'],
+                                                    attention_mask=inputs['attention_mask'])
+        position_ids = position_ids.contiguous()
+        outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
+        output_image_embeddings = outputs.image_embeddings[:, :-1, :]  # shift right
+        output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
+        output_image_embeddings = output_image_embeddings.unsqueeze(0)
+        return output_image_embeddings
+
+
+
+class UnifiedModel(pl.LightningModule):
+    def __init__(self, flux_text_encoder_path, flux_vae_path, flux_dit_path, flux_decoder_path, qwenvl_path):
+        super().__init__()
+        # Load models
+        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+        model_manager.load_models([
+            flux_text_encoder_path,
+            flux_vae_path,
+            flux_dit_path
+        ])
+        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
+
+        state_dict = load_state_dict(flux_decoder_path, torch_dtype=torch.bfloat16)
+        self.pipe.dit.load_state_dict(state_dict, strict=False)
+
+        self.adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16)
+        self.adapter.load_state_dict(state_dict, strict=False)
+
+        self.qwenvl = NexusGenQwenVLEncoder.from_pretrained(qwenvl_path)
+
+        self.pipe.vae_decoder.requires_grad_(False)
+        self.pipe.vae_encoder.requires_grad_(False)
+        self.pipe.text_encoder_1.requires_grad_(False)
+        self.qwenvl.requires_grad_(False)
+
+        self.pipe.scheduler.set_timesteps(1000, training=True)
+
+
+    def training_step(self, batch, batch_idx):
+        # Data
+        text, image = batch["instruction"], batch["image_2"]
+        image_ref = batch["image_1"]
+        image = image.unsqueeze(0)
+
+        # Prepare input parameters
+        self.pipe.device = self.device
+        latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
+        noise = torch.randn_like(latents)
+        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
+        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
+        extra_input = self.pipe.prepare_extra_input(latents)
+        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
+        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
+
+        # Compute loss
+        if image_ref is None:
+            instruction = f"Generate an image according to the following description: {text}"
+            images_ref = None
+        else:
+            instruction = f"<|vision_start|><|image_pad|><|vision_end|> {text}"
+            images_ref = [image_ref]
+        emb = self.qwenvl(instruction, images=images_ref)
+        emb = self.adapter(emb)
+        prompt_emb = self.pipe.encode_prompt("", positive=True, image_emb=emb)
+
+        noise_pred = lets_dance_flux(
+            self.pipe.denoising_model(),
+            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
+            image_emb=emb,
+            use_gradient_checkpointing=True
+        )
+        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+        loss = loss * self.pipe.scheduler.training_weight(timestep)
+        return loss
+
+
+    def forward(self, batch):
+        return self.training_step(batch, 0)
+
+
+    def configure_optimizers(self):
+        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
+        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
+        return optimizer
+
+
+
+def search_for_last_checkpoint(path):
+    if not os.path.exists(path):
+        return None, 0
+    checkpoint_list = os.listdir(path)
+    checkpoint_list = [int(checkpoint.split("-")[1]) for checkpoint in checkpoint_list if checkpoint.startswith("epoch")]
+    if len(checkpoint_list) == 0:
+        return None, 0
+    else:
+        max_epoch_id = max(checkpoint_list)
+        return f"{path}/epoch-{max_epoch_id}/model.safetensors", max_epoch_id + 1
+
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="gradient_accumulation_steps",
+    )
+    parser.add_argument(
+        "--steps_per_epoch",
+        type=int,
+        default=100,
+        help="steps_per_epoch",
+    )
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        default="./models",
+        help="output_path",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-5,
+        help="learning_rate",
+    )
+    args = parser.parse_args()
+    return args
+
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    model = UnifiedModel(
+        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
+        "models/FLUX/FLUX.1-dev/ae.safetensors",
+        "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
+        "models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin",
+        "models/DiffSynth-Studio/Nexus-Gen",
+    )
+    # dataset and data loader
+    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
+    dataset = MultiTaskDataset(
+        dataset_list=[
+            SingleTaskDataset(
+                "data/example_dataset",
+                keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
+                height=512, width=512,
+            ),
+        ],
+        dataset_weight=(1,),
+        steps_per_epoch=args.steps_per_epoch * accelerator.num_processes,
+    )
+    train_loader = torch.utils.data.DataLoader(
+        dataset,
+        shuffle=True,
+        batch_size=1,
+        num_workers=1,
+        collate_fn=lambda x: x[0]
+    )
+    # train
+    pretrained_path, start_epoch_id = search_for_last_checkpoint(args.output_path)
+    if pretrained_path is not None:
+        print(f"pretrained_path: {pretrained_path}")
+        model.load_state_dict(load_state_dict(pretrained_path, torch_dtype=torch.bfloat16), assign=True, strict=False)
+
+    model.to(torch.bfloat16)
+    model.to(accelerator.device)
+
+    trainable_modules = filter(lambda p: p.requires_grad, model.parameters())
+    optimizer = torch.optim.AdamW(trainable_modules, lr=args.learning_rate)
+
+    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
+
+    for epoch in range(start_epoch_id, 100000):
+        for batch in tqdm(train_loader, desc=f"epoch-{epoch}", disable=not accelerator.is_local_main_process):
+            with accelerator.accumulate(model):
+                optimizer.zero_grad()
+                loss = model(batch)
+                accelerator.backward(loss)
+                optimizer.step()
+        path = args.output_path
+        os.makedirs(path, exist_ok=True)
+        accelerator.wait_for_everyone()
+        accelerator.save_model(model, f"{path}/epoch-{epoch}", max_shard_size="10GB", safe_serialization=True)
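
-- 
Usage sketch, not part of the applied diff: assuming an Accelerate environment is configured and the model and dataset paths referenced above exist, the new training script should be launchable with the standard Accelerate CLI, e.g. accelerate launch train.py --steps_per_epoch 100 --gradient_accumulation_steps 1 --learning_rate 1e-5 --output_path ./models (all four flags are defined in parse_args). Per-epoch checkpoints are then written to epoch-<n> folders under --output_path, and search_for_last_checkpoint resumes from the latest one on the next run.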