diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 22ea31e..5f02117 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -154,7 +154,7 @@ class VideoDataset(torch.utils.data.Dataset): height_division_factor=16, width_division_factor=16, data_file_keys=("video",), image_file_extension=("jpg", "jpeg", "png", "webp"), - video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), + video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm", "gif"), repeat=1, args=None, ): @@ -259,8 +259,53 @@ class VideoDataset(torch.utils.data.Dataset): num_frames -= 1 return num_frames - + def _load_gif(self, file_path): + gif_img = Image.open(file_path) + frame_count = 0 + delays, frames = [], [] + while True: + delay = gif_img.info.get('duration', 100) # ms + delays.append(delay) + rgb_frame = gif_img.convert("RGB") + croped_frame = self.crop_and_resize(rgb_frame, *self.get_height_width(rgb_frame)) + frames.append(croped_frame) + frame_count += 1 + try: + gif_img.seek(frame_count) + except: + break + # delays canbe used to calculate framerates + # i guess it is better to sample images with stable interval, + # and using minimal_interval as the interval, + # and framerate = 1000 / minimal_interval + if any((delays[0] != i) for i in delays): + minimal_interval = min([i for i in delays if i > 0]) + # make a ((start,end),frameid) struct + start_end_idx_map = [((sum(delays[:i]), sum(delays[:i+1])), i) for i in range(len(delays))] + _frames = [] + # according gemini-code-assist, make it more efficient to locate + # where to sample the frame + last_match = 0 + for i in range(sum(delays) // minimal_interval): + current_time = minimal_interval * i + for idx, ((start, end), frame_idx) in enumerate(start_end_idx_map[last_match:]): + if start <= current_time < end: + _frames.append(frames[frame_idx]) + last_match = idx + last_match + break + frames = _frames + num_frames = len(frames) + if num_frames > self.num_frames: + num_frames = self.num_frames + else: + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + frames = frames[:num_frames] + return frames + def load_video(self, file_path): + if file_path.lower().endswith(".gif"): + return self._load_gif(file_path) reader = imageio.get_reader(file_path) num_frames = self.get_num_frames(reader) frames = []