From e014cad82067c1a2e6db8815979f6f6adcdb9368 Mon Sep 17 00:00:00 2001 From: twu Date: Thu, 21 Aug 2025 09:01:48 +0000 Subject: [PATCH 1/3] add read gifs as video support --- diffsynth/trainers/utils.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 065b687..b55b258 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -154,7 +154,7 @@ class VideoDataset(torch.utils.data.Dataset): height_division_factor=16, width_division_factor=16, data_file_keys=("video",), image_file_extension=("jpg", "jpeg", "png", "webp"), - video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), + video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm", "gif"), repeat=1, args=None, ): @@ -259,8 +259,41 @@ class VideoDataset(torch.utils.data.Dataset): num_frames -= 1 return num_frames - + def _load_gif(self, file_path): + gif_img = Image.open(file_path) + frame_count = 0 + delays, frames = [], [] + while True: + delay = gif_img.info.get('duration', 100) # ms + delays.append(delay) + rgb_frame = gif_img.convert("RGB") + croped_frame = self.crop_and_resize(rgb_frame, *self.get_height_width(rgb_frame)) + frames.append(croped_frame) + frame_count += 1 + try: + gif_img.seek(frame_count) + except: + break + # delays canbe used to calculate framerates + # i guess it is better to sample images with stable interval, + # and using minimal_interval as the interval, + # and framerate = 1000 / minimal_interval + if any((delays[0] != i) for i in delays): + minimal_interval = min([i for i in delays if i > 0]) + # make a ((start,end),frameid) struct + start_end_idx_map = [((sum(delays[:i]), sum(delays[:i+1])), i) for i in range(len(delays))] + _frames = [] + for i in range(sum(delays) // minimal_interval): + current_time = minimal_interval * i + for ((start, end), frame_idx) in start_end_idx_map: + if start <= current_time < end: + _frames.append(frames[frame_idx]) + frames = _frames + return frames + def load_video(self, file_path): + if file_path.lower().endswith(".gif"): + return self._load_gif(file_path) reader = imageio.get_reader(file_path) num_frames = self.get_num_frames(reader) frames = [] From e3f47a799b23b54ab23a8d85afc5ec51d2d97a7c Mon Sep 17 00:00:00 2001 From: twu Date: Thu, 21 Aug 2025 09:13:45 +0000 Subject: [PATCH 2/3] make it more efficient to locate where to sample the frame --- diffsynth/trainers/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index b55b258..c4b5d92 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -283,11 +283,16 @@ class VideoDataset(torch.utils.data.Dataset): # make a ((start,end),frameid) struct start_end_idx_map = [((sum(delays[:i]), sum(delays[:i+1])), i) for i in range(len(delays))] _frames = [] + # according gemini-code-assist, make it more efficient to locate + # where to sample the frame + last_match = 0 for i in range(sum(delays) // minimal_interval): current_time = minimal_interval * i - for ((start, end), frame_idx) in start_end_idx_map: + for idx, ((start, end), frame_idx) in enumerate(start_end_idx_map[last_match:]): if start <= current_time < end: _frames.append(frames[frame_idx]) + last_match = idx + last_match + break frames = _frames return frames From f6418004bb441598c83f5dd60520f1ae91a247d6 Mon Sep 17 00:00:00 2001 From: twu Date: Fri, 22 Aug 2025 03:00:35 +0000 Subject: [PATCH 3/3] as numframe limit is impled in reader, add that --- diffsynth/trainers/utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index c4b5d92..c5e0f19 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -294,6 +294,13 @@ class VideoDataset(torch.utils.data.Dataset): last_match = idx + last_match break frames = _frames + num_frames = len(frames) + if num_frames > self.num_frames: + num_frames = self.num_frames + else: + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + frames = frames[:num_frames] return frames def load_video(self, file_path):