mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
support target fps
This commit is contained in:
@@ -12,7 +12,7 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class TextVideoDataset(torch.utils.data.Dataset):
|
class TextVideoDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
|
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
|
||||||
metadata = pd.read_csv(metadata_path)
|
metadata = pd.read_csv(metadata_path)
|
||||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
self.text = metadata["text"].to_list()
|
self.text = metadata["text"].to_list()
|
||||||
@@ -23,6 +23,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
self.height = height
|
self.height = height
|
||||||
self.width = width
|
self.width = width
|
||||||
self.is_i2v = is_i2v
|
self.is_i2v = is_i2v
|
||||||
|
self.target_fps = target_fps
|
||||||
|
|
||||||
self.frame_process = v2.Compose([
|
self.frame_process = v2.Compose([
|
||||||
v2.CenterCrop(size=(height, width)),
|
v2.CenterCrop(size=(height, width)),
|
||||||
@@ -71,8 +72,15 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def load_video(self, file_path):
|
def load_video(self, file_path):
|
||||||
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
|
start_frame_id = 0
|
||||||
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
|
if self.target_fps is None:
|
||||||
|
frame_interval = self.frame_interval
|
||||||
|
else:
|
||||||
|
reader = imageio.get_reader(file_path)
|
||||||
|
fps = reader.get_meta_data()["fps"]
|
||||||
|
reader.close()
|
||||||
|
frame_interval = max(round(fps / self.target_fps), 1)
|
||||||
|
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user