diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py
index 45a026f..cdf0d67 100644
--- a/examples/wanvideo/train_wan_t2v.py
+++ b/examples/wanvideo/train_wan_t2v.py
@@ -12,7 +12,7 @@ import numpy as np
 
 
 class TextVideoDataset(torch.utils.data.Dataset):
-    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
+    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
         metadata = pd.read_csv(metadata_path)
         self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
         self.text = metadata["text"].to_list()
@@ -23,6 +23,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
         self.height = height
         self.width = width
         self.is_i2v = is_i2v
+        self.target_fps = target_fps
 
         self.frame_process = v2.Compose([
             v2.CenterCrop(size=(height, width)),
@@ -71,8 +72,15 @@ class TextVideoDataset(torch.utils.data.Dataset):
 
 
     def load_video(self, file_path):
-        start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
-        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
+        start_frame_id = 0
+        if self.target_fps is None:
+            frame_interval = self.frame_interval
+        else:
+            reader = imageio.get_reader(file_path)
+            fps = reader.get_meta_data()["fps"]
+            reader.close()
+            frame_interval = max(round(fps / self.target_fps), 1)
+        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
         return frames
 
 