From e014cad82067c1a2e6db8815979f6f6adcdb9368 Mon Sep 17 00:00:00 2001 From: twu Date: Thu, 21 Aug 2025 09:01:48 +0000 Subject: [PATCH] add read gifs as video support --- diffsynth/trainers/utils.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 065b687..b55b258 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -154,7 +154,7 @@ class VideoDataset(torch.utils.data.Dataset): height_division_factor=16, width_division_factor=16, data_file_keys=("video",), image_file_extension=("jpg", "jpeg", "png", "webp"), - video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), + video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm", "gif"), repeat=1, args=None, ): @@ -259,8 +259,41 @@ class VideoDataset(torch.utils.data.Dataset): num_frames -= 1 return num_frames - + def _load_gif(self, file_path): + gif_img = Image.open(file_path) + frame_count = 0 + delays, frames = [], [] + while True: + delay = gif_img.info.get('duration', 100) # ms + delays.append(delay) + rgb_frame = gif_img.convert("RGB") + croped_frame = self.crop_and_resize(rgb_frame, *self.get_height_width(rgb_frame)) + frames.append(croped_frame) + frame_count += 1 + try: + gif_img.seek(frame_count) + except: + break + # delays canbe used to calculate framerates + # i guess it is better to sample images with stable interval, + # and using minimal_interval as the interval, + # and framerate = 1000 / minimal_interval + if any((delays[0] != i) for i in delays): + minimal_interval = min([i for i in delays if i > 0]) + # make a ((start,end),frameid) struct + start_end_idx_map = [((sum(delays[:i]), sum(delays[:i+1])), i) for i in range(len(delays))] + _frames = [] + for i in range(sum(delays) // minimal_interval): + current_time = minimal_interval * i + for ((start, end), frame_idx) in start_end_idx_map: + if start <= current_time < end: + _frames.append(frames[frame_idx]) + frames = _frames + return frames + def load_video(self, file_path): + if file_path.lower().endswith(".gif"): + return self._load_gif(file_path) reader = imageio.get_reader(file_path) num_frames = self.get_num_frames(reader) frames = []