support target fps

This commit is contained in:
Artiprocher
2025-03-18 17:30:13 +08:00
parent b1fabbc6b0
commit 4b2b3dda94

View File

@@ -12,7 +12,7 @@ import numpy as np
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
@@ -23,6 +23,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
self.height = height
self.width = width
self.is_i2v = is_i2v
self.target_fps = target_fps
self.frame_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
@@ -71,8 +72,15 @@ class TextVideoDataset(torch.utils.data.Dataset):
def load_video(self, file_path):
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
start_frame_id = 0
if self.target_fps is None:
frame_interval = self.frame_interval
else:
reader = imageio.get_reader(file_path)
fps = reader.get_meta_data()["fps"]
reader.close()
frame_interval = max(round(fps / self.target_fps), 1)
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
return frames