diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py
index 8d37277..d627dab 100644
--- a/diffsynth/trainers/utils.py
+++ b/diffsynth/trainers/utils.py
@@ -11,8 +11,9 @@ class VideoDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         base_path=None, metadata_path=None,
-        frame_interval=1, num_frames=81,
-        dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None,
+        num_frames=81,
+        time_division_factor=4, time_division_remainder=1,
+        max_pixels=1920*1080, height=None, width=None,
         height_division_factor=16, width_division_factor=16,
         data_file_keys=("video",),
         image_file_extension=("jpg", "jpeg", "png", "webp"),
@@ -25,17 +26,15 @@ class VideoDataset(torch.utils.data.Dataset):
             metadata_path = args.dataset_metadata_path
             height = args.height
             width = args.width
+            max_pixels = args.max_pixels
             num_frames = args.num_frames
             data_file_keys = args.data_file_keys.split(",")
             repeat = args.dataset_repeat
-
-        metadata = pd.read_csv(metadata_path)
-        self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
 
         self.base_path = base_path
-        self.frame_interval = frame_interval
         self.num_frames = num_frames
-        self.dynamic_resolution = dynamic_resolution
+        self.time_division_factor = time_division_factor
+        self.time_division_remainder = time_division_remainder
         self.max_pixels = max_pixels
         self.height = height
         self.width = width
@@ -46,9 +45,43 @@ class VideoDataset(torch.utils.data.Dataset):
         self.video_file_extension = video_file_extension
         self.repeat = repeat
 
-        if height is not None and width is not None and dynamic_resolution == True:
+        if height is not None and width is not None:
             print("Height and width are fixed. Setting `dynamic_resolution` to False.")
             self.dynamic_resolution = False
+        elif height is None and width is None:
+            print("Height and width are none. Setting `dynamic_resolution` to True.")
+            self.dynamic_resolution = True
+
+        if metadata_path is None:
+            print("No metadata. Trying to generate it.")
+            metadata = self.generate_metadata(base_path)
+            print(f"{len(metadata)} lines in metadata.")
+        else:
+            metadata = pd.read_csv(metadata_path)
+        self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
+
+
+    def generate_metadata(self, folder):
+        video_list, prompt_list = [], []
+        file_set = set(os.listdir(folder))
+        for file_name in file_set:
+            if "." not in file_name:
+                continue
+            file_ext_name = file_name.split(".")[-1].lower()
+            file_base_name = file_name[:-len(file_ext_name)-1]
+            if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension:
+                continue
+            prompt_file_name = file_base_name + ".txt"
+            if prompt_file_name not in file_set:
+                continue
+            with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
+                prompt = f.read().strip()
+            video_list.append(file_name)
+            prompt_list.append(prompt)
+
+        metadata = pd.DataFrame()
+        metadata["video"] = video_list
+        metadata["prompt"] = prompt_list
+        return metadata
 
 
     def crop_and_resize(self, image, target_height, target_width):
@@ -75,15 +108,22 @@ class VideoDataset(torch.utils.data.Dataset):
             height, width = self.height, self.width
         return height, width
+
+
+    def get_num_frames(self, reader):
+        num_frames = self.num_frames
+        if int(reader.count_frames()) < num_frames:
+            num_frames = int(reader.count_frames())
+            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
+                num_frames -= 1
+        return num_frames
 
 
-    def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
+    def load_video(self, file_path):
         reader = imageio.get_reader(file_path)
-        if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
-            reader.close()
-            return None
+        num_frames = self.get_num_frames(reader)
         frames = []
         for frame_id in range(num_frames):
-            frame = reader.get_data(start_frame_id + frame_id * interval)
+            frame = reader.get_data(frame_id)
             frame = Image.fromarray(frame)
             frame = self.crop_and_resize(frame, *self.get_height_width(frame))
             frames.append(frame)
@@ -95,11 +135,6 @@ class VideoDataset(torch.utils.data.Dataset):
         image = Image.open(file_path).convert("RGB")
         image = self.crop_and_resize(image, *self.get_height_width(image))
         return image
-
-
-    def load_video(self, file_path):
-        frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
-        return frames
 
 
     def is_image(self, file_path):
@@ -182,34 +217,50 @@ class DiffusionTrainingModule(torch.nn.Module):
 
 
-def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=1e-4, num_epochs=1, output_path="./models", remove_prefix_in_ckpt=None, args=None):
-    if args is not None:
-        learning_rate = args.learning_rate
-        num_epochs = args.num_epochs
-        output_path = args.output_path
-        remove_prefix_in_ckpt = args.remove_prefix_in_ckpt
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
-    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
-    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
+class ModelLogger:
+    def __init__(self, output_path, remove_prefix_in_ckpt=None):
+        self.output_path = output_path
+        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
+
-    accelerator = Accelerator(gradient_accumulation_steps=1)
+    def on_step_end(self, loss):
+        pass
+
+
+    def on_epoch_end(self, accelerator, model, epoch_id):
+        accelerator.wait_for_everyone()
+        if accelerator.is_main_process:
+            state_dict = accelerator.get_state_dict(model)
+            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
+            os.makedirs(self.output_path, exist_ok=True)
+            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
+            accelerator.save(state_dict, path, safe_serialization=True)
+
+
+def launch_training_task(
+    dataset: torch.utils.data.Dataset,
+    model: DiffusionTrainingModule,
+    model_logger: ModelLogger,
+    optimizer: torch.optim.Optimizer,
+    scheduler: torch.optim.lr_scheduler.LRScheduler,
+    num_epochs: int = 1,
+    gradient_accumulation_steps: int = 1,
+):
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
+    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
     model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
-    for epoch in range(num_epochs):
+    for epoch_id in range(num_epochs):
         for data in tqdm(dataloader):
             with accelerator.accumulate(model):
                 optimizer.zero_grad()
                 loss = model(data)
                 accelerator.backward(loss)
                 optimizer.step()
-                scheduler.step()
-        accelerator.wait_for_everyone()
-        if accelerator.is_main_process:
-            state_dict = accelerator.get_state_dict(model)
-            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt)
-            os.makedirs(output_path, exist_ok=True)
-            path = os.path.join(output_path, f"epoch-{epoch}.safetensors")
-            accelerator.save(state_dict, path, safe_serialization=True)
+                model_logger.on_step_end(loss)
+                scheduler.step()
+        model_logger.on_epoch_end(accelerator, model, epoch_id)
 
 
@@ -228,8 +279,9 @@ def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_pat
 
 def wan_parser():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.")
-    parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Path to the metadata file of the dataset.")
+    parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
+    parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
+    parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution.")
     parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
     parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
     parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
@@ -247,5 +299,6 @@ def wan_parser():
     parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
     parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
     parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
     return parser
 
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index 871bbd6..877c5b8 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -1,6 +1,6 @@
 import torch, os, json
 from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
-from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser
+from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -107,4 +107,14 @@ if __name__ == "__main__":
         use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
         extra_inputs=args.extra_inputs,
     )
-    launch_training_task(model, dataset, args=args)
+    model_logger = ModelLogger(
+        args.output_path,
+        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
+    )
+    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
+    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
+    launch_training_task(
+        dataset, model, model_logger, optimizer, scheduler,
+        num_epochs=args.num_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+    )
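Usage note (illustrative, not part of the patch): after this refactor the caller owns the optimizer, scheduler, and checkpoint logging, so swapping any of them is a local change at the call site. A minimal sketch of the new wiring, assuming `model` (a DiffusionTrainingModule) and `dataset` (a VideoDataset) are constructed as in examples/wanvideo/model_training/train.py; the learning rate, prefix string, and accumulation value below are placeholders:

    import torch
    from diffsynth.trainers.utils import ModelLogger, launch_training_task

    # Checkpoints are written to <output_path>/epoch-<N>.safetensors by
    # ModelLogger.on_epoch_end; the prefix string is a placeholder for
    # whatever prefix the training module prepends to parameter names.
    model_logger = ModelLogger("./models", remove_prefix_in_ckpt="pipe.dit.")

    # The caller now constructs these instead of launch_training_task.
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

    launch_training_task(
        dataset, model, model_logger, optimizer, scheduler,
        num_epochs=1,
        gradient_accumulation_steps=4,  # previously hard-coded to 1
    )

When --dataset_metadata_path is omitted, generate_metadata builds the metadata from --dataset_base_path directly: each image or video must sit next to a .txt prompt file with the same base name (e.g. clip_001.mp4 paired with clip_001.txt), and files without a matching prompt are skipped without warning. Videos shorter than --num_frames are truncated by get_num_frames to the longest prefix length n with n % time_division_factor == time_division_remainder (4 and 1 by default).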