mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
Merge pull request #843 from wuutiing/main
add read gifs as video support
This commit is contained in:
@@ -154,7 +154,7 @@ class VideoDataset(torch.utils.data.Dataset):
|
|||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
data_file_keys=("video",),
|
data_file_keys=("video",),
|
||||||
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||||
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
|
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm", "gif"),
|
||||||
repeat=1,
|
repeat=1,
|
||||||
args=None,
|
args=None,
|
||||||
):
|
):
|
||||||
@@ -259,8 +259,53 @@ class VideoDataset(torch.utils.data.Dataset):
|
|||||||
num_frames -= 1
|
num_frames -= 1
|
||||||
return num_frames
|
return num_frames
|
||||||
|
|
||||||
|
def _load_gif(self, file_path):
    """Load a GIF file as a list of processed frames.

    Every frame is converted to RGB and passed through
    ``self.crop_and_resize`` using the size returned by
    ``self.get_height_width``. If the per-frame delays are not all equal,
    the frames are resampled on a fixed time grid whose step is the
    smallest positive delay, so the result has a constant frame interval
    (effective frame rate = 1000 / smallest positive delay).

    The frame count is then limited to ``self.num_frames``; when fewer
    frames are available, trailing frames are dropped until the count
    satisfies ``count % self.time_division_factor ==
    self.time_division_remainder`` (or only one frame is left).

    Args:
        file_path: Path of the GIF file to read.

    Returns:
        A list with the objects returned by ``self.crop_and_resize``,
        one per sampled frame.
    """
    delays, frames = [], []
    # The context manager releases the file handle once all frames are
    # decoded; ``convert`` returns an independent image, so the collected
    # frames are expected to stay usable after the file is closed.
    with Image.open(file_path) as gif_img:
        frame_index = 0
        while True:
            # Per-frame display time in milliseconds; fall back to 100 ms
            # when the frame carries no duration.
            delays.append(gif_img.info.get("duration", 100))
            rgb_frame = gif_img.convert("RGB")
            cropped_frame = self.crop_and_resize(
                rgb_frame, *self.get_height_width(rgb_frame)
            )
            frames.append(cropped_frame)
            frame_index += 1
            try:
                gif_img.seek(frame_index)
            except EOFError:
                # Pillow raises EOFError when seeking past the last frame.
                break

    # Resample only when the delays differ; a GIF with uniform delays is
    # already sampled at a stable interval.
    if any(delay != delays[0] for delay in delays):
        minimal_interval = min(delay for delay in delays if delay > 0)

        # boundaries[k] is the start time of frame k and boundaries[k + 1]
        # is its end time (running sum of the delays).
        boundaries = [0]
        for delay in delays:
            boundaries.append(boundaries[-1] + delay)
        total_duration = boundaries[-1]

        resampled = []
        cursor = 0
        for tick in range(total_duration // minimal_interval):
            current_time = minimal_interval * tick
            # Sample times only grow, so the cursor only moves forward.
            # Frames with zero delay cover an empty interval and are
            # skipped. current_time < total_duration keeps cursor in range.
            while boundaries[cursor + 1] <= current_time:
                cursor += 1
            resampled.append(frames[cursor])
        frames = resampled

    num_frames = len(frames)
    if num_frames > self.num_frames:
        num_frames = self.num_frames
    else:
        # Drop trailing frames until the count matches the required
        # remainder modulo the time division factor.
        while (
            num_frames > 1
            and num_frames % self.time_division_factor != self.time_division_remainder
        ):
            num_frames -= 1
    return frames[:num_frames]
|
||||||
|
|
||||||
def load_video(self, file_path):
|
def load_video(self, file_path):
|
||||||
|
if file_path.lower().endswith(".gif"):
|
||||||
|
return self._load_gif(file_path)
|
||||||
reader = imageio.get_reader(file_path)
|
reader = imageio.get_reader(file_path)
|
||||||
num_frames = self.get_num_frames(reader)
|
num_frames = self.get_num_frames(reader)
|
||||||
frames = []
|
frames = []
|
||||||
|
|||||||
Reference in New Issue
Block a user