import os
import warnings

import imageio
import pandas as pd
import torch
import torchvision
from accelerate import Accelerator
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
from tqdm import tqdm


class VideoDataset(torch.utils.data.Dataset):
    """Dataset of image/video samples described by a CSV metadata file.

    Each CSV row becomes one sample dict. For every column named in
    ``data_file_keys`` that is present in the row, the value is joined onto
    ``base_path`` and replaced by the loaded media: a single ``PIL.Image`` for
    image files, or a list of ``PIL.Image`` frames for video files. Media is
    center-cropped and resized either to the fixed ``height``/``width`` or,
    when ``dynamic_resolution`` is on, to the source size capped at
    ``max_pixels`` and floored to the division factors.

    ``len(dataset) == len(metadata) * repeat``; indices wrap around the rows.
    ``__getitem__`` returns ``None`` when a referenced file cannot be loaded.
    """

    def __init__(
        self,
        base_path,
        metadata_path,
        frame_interval=1,
        num_frames=81,
        dynamic_resolution=True,
        max_pixels=1920 * 1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
        data_file_keys=("video",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
        repeat=1,
    ):
        metadata = pd.read_csv(metadata_path)
        # One plain dict per CSV row; __getitem__ copies it before filling in media.
        self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
        self.base_path = base_path
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.dynamic_resolution = dynamic_resolution
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.video_file_extension = video_file_extension
        self.repeat = repeat
        # An explicit output size takes precedence over dynamic resolution.
        if height is not None and width is not None and dynamic_resolution:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` so it covers the target size, then center-crop to it.

        The scale factor is the larger of the two axis ratios, so after a
        bilinear resize both sides are at least the target and the crop only
        trims the overflowing axis.
        """
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return torchvision.transforms.functional.center_crop(image, (target_height, target_width))

    def get_height_width(self, image):
        """Return the ``(height, width)`` that ``image`` should be cropped to.

        With dynamic resolution the source size is kept, shrunk uniformly when
        its area exceeds ``max_pixels``, and each side is floored to a multiple
        of its division factor. Otherwise the fixed ``self.height`` /
        ``self.width`` are returned.
        """
        if not self.dynamic_resolution:
            return self.height, self.width
        width, height = image.size
        if width * height > self.max_pixels:
            scale = (width * height / self.max_pixels) ** 0.5
            height, width = int(height / scale), int(width / scale)
        height = height // self.height_division_factor * self.height_division_factor
        width = width // self.width_division_factor * self.width_division_factor
        return height, width

    def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
        """Read ``num_frames`` frames spaced ``interval`` apart from ``start_frame_id``.

        Returns a list of cropped/resized ``PIL.Image`` frames, or ``None`` if
        the video does not contain enough frames. The reader is always closed,
        even if decoding raises.
        """
        reader = imageio.get_reader(file_path)
        try:
            last_frame_id = start_frame_id + (num_frames - 1) * interval
            if reader.count_frames() - 1 < last_frame_id:
                return None
            frames = []
            for frame_id in range(num_frames):
                frame = Image.fromarray(reader.get_data(start_frame_id + frame_id * interval))
                frames.append(self.crop_and_resize(frame, *self.get_height_width(frame)))
            return frames
        finally:
            reader.close()

    def load_image(self, file_path):
        """Load an image as RGB and crop/resize it to the target size."""
        # The context manager releases the file handle once the pixels are converted.
        with Image.open(file_path) as raw_image:
            image = raw_image.convert("RGB")
        return self.crop_and_resize(image, *self.get_height_width(image))

    def load_video(self, file_path):
        """Load ``num_frames`` frames from the start of the video, ``frame_interval`` apart.

        Returns ``None`` if the video is too short.
        """
        return self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)

    def is_image(self, file_path):
        """Return True if the text after the last '.' is a configured image extension."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.image_file_extension

    def is_video(self, file_path):
        """Return True if the text after the last '.' is a configured video extension."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.video_file_extension

    def load_data(self, file_path):
        """Dispatch to the image or video loader by extension; ``None`` if unsupported."""
        if self.is_image(file_path):
            return self.load_image(file_path)
        if self.is_video(file_path):
            return self.load_video(file_path)
        return None

    def __getitem__(self, data_id):
        """Return the sample dict for ``data_id`` with media columns loaded.

        Returns ``None`` (after emitting a warning) if any media file cannot
        be loaded.
        """
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key not in data:
                continue
            path = os.path.join(self.base_path, data[key])
            data[key] = self.load_data(path)
            if data[key] is None:
                # Report the offending path; data[key] itself is None at this point.
                warnings.warn(f"cannot load file {path}.")
                return None
        return data

    def __len__(self):
        return len(self.data) * self.repeat


class DiffusionTrainingModule(torch.nn.Module):
    """Base ``nn.Module`` for trainable pipelines with LoRA / trainable-param helpers."""

    def __init__(self):
        super().__init__()

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on each direct child module and return ``self``."""
        for child in self.children():
            child.to(*args, **kwargs)
        return self

    def trainable_modules(self):
        """Return an iterator over parameters with ``requires_grad=True``.

        Despite the name, this yields parameters, not modules.
        """
        return filter(lambda p: p.requires_grad, self.parameters())

    def trainable_param_names(self):
        """Return the set of fully-qualified names of parameters with ``requires_grad=True``."""
        return {name for name, param in self.named_parameters() if param.requires_grad}

    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
        """Inject a LoRA adapter into ``model`` via peft and return the result.

        ``lora_alpha`` defaults to ``lora_rank`` when not given.
        """
        if lora_alpha is None:
            lora_alpha = lora_rank
        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
        return inject_adapter_in_model(lora_config, model)


def launch_training_task(
    model: DiffusionTrainingModule,
    dataset,
    learning_rate,
    num_epochs,
    output_path,
    remove_prefix=None,
):
    """Train ``model`` on ``dataset`` with Accelerate and save trainable weights each epoch.

    ``model(data)`` is called on each individual sample (batch size 1; the
    collate function unwraps the single item) and must return a scalar loss.
    After every epoch, only parameters with ``requires_grad=True`` are saved
    as safetensors to ``<output_path>/epoch-<epoch>``. If ``remove_prefix`` is
    given, it is stripped from parameter names that start with it.
    """
    # NOTE(review): VideoDataset.__getitem__ may return None; that value reaches model(data) unchanged.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
    # NOTE(review): ConstantLR's defaults (factor=1/3, total_iters=5) scale the LR by 1/3 for
    # the first 5 scheduler steps; pass factor=1.0 if a strictly constant LR is intended.
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    accelerator = Accelerator(gradient_accumulation_steps=1)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    for epoch in range(num_epochs):
        for data in tqdm(dataloader, disable=not accelerator.is_local_main_process):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
        accelerator.wait_for_everyone()
        # get_state_dict can involve collective ops (e.g. DeepSpeed ZeRO-3),
        # so every process must call it, not only the main one.
        state_dict = accelerator.get_state_dict(model)
        if accelerator.is_main_process:
            # prepare() may have wrapped the model (e.g. DDP); the helper methods
            # live on the underlying DiffusionTrainingModule.
            trainable_param_names = accelerator.unwrap_model(model).trainable_param_names()
            state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
            if remove_prefix is not None:
                state_dict = {
                    (name[len(remove_prefix):] if name.startswith(remove_prefix) else name): param
                    for name, param in state_dict.items()
                }
            os.makedirs(output_path, exist_ok=True)
            # NOTE(review): the file is safetensors-formatted but has no ".safetensors"
            # suffix; extension-based loaders may not recognize it.
            path = os.path.join(output_path, f"epoch-{epoch}")
            accelerator.save(state_dict, path, safe_serialization=True)