mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support target fps
This commit is contained in:
@@ -12,7 +12,7 @@ import numpy as np
|
||||
|
||||
|
||||
class TextVideoDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
|
||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
@@ -23,6 +23,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.is_i2v = is_i2v
|
||||
self.target_fps = target_fps
|
||||
|
||||
self.frame_process = v2.Compose([
|
||||
v2.CenterCrop(size=(height, width)),
|
||||
@@ -71,8 +72,15 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
def load_video(self, file_path):
|
||||
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
|
||||
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
|
||||
start_frame_id = 0
|
||||
if self.target_fps is None:
|
||||
frame_interval = self.frame_interval
|
||||
else:
|
||||
reader = imageio.get_reader(file_path)
|
||||
fps = reader.get_meta_data()["fps"]
|
||||
reader.close()
|
||||
frame_interval = max(round(fps / self.target_fps), 1)
|
||||
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
|
||||
return frames
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user