mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:39:43 +00:00
Ltx2.3 i2v training and sample frames with fixed fps (#1339)
* add 2.3 i2v training scripts * add frame resampling by fixed fps * LoadVideo: add compatibility for not fix_frame_rate * refactor frame resampler * minor fix
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import math
|
||||
import torch, torchvision, imageio, os
|
||||
import imageio.v3 as iio
|
||||
from PIL import Image
|
||||
import torchaudio
|
||||
|
||||
|
||||
class DataProcessingPipeline:
|
||||
@@ -105,27 +107,59 @@ class ToList(DataProcessingOperator):
|
||||
return [data]
|
||||
|
||||
|
||||
class LoadVideo(DataProcessingOperator):
|
||||
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||
class FrameSamplerByRateMixin:
    """Mixin that decides which frames to pull out of a video reader,
    optionally resampling them to a fixed target frame rate.

    When ``fix_frame_rate`` is False the sampler works directly on the raw
    frame indices of the clip.  When it is True, the output sequence is
    resampled so it plays at ``frame_rate`` fps regardless of the clip's
    native rate: the usable frame budget becomes ``floor(duration * frame_rate)``
    and each output index is mapped back to the nearest raw frame.
    """

    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_rate=24, fix_frame_rate=False):
        # Maximum number of frames to sample from a clip.
        self.num_frames = num_frames
        # The final frame count must satisfy
        # ``count % time_division_factor == time_division_remainder``
        # (presumably a temporal-compression constraint of the video VAE —
        # TODO confirm against the model config).
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # Target playback rate, used only when ``fix_frame_rate`` is True.
        self.frame_rate = frame_rate
        self.fix_frame_rate = fix_frame_rate

    def get_reader(self, data: str):
        """Open ``data`` (a video path or URI) with imageio and return the reader."""
        return imageio.get_reader(data)

    def get_available_num_frames(self, reader):
        """Return how many frames can be sampled from ``reader``.

        Without a fixed frame rate this is simply the raw frame count.
        With a fixed frame rate it is ``floor(duration * frame_rate)`` —
        the number of frames the clip yields when resampled to the target
        rate.  Always returns an ``int`` (the non-fixed branch previously
        returned whatever ``count_frames()`` produced).
        """
        if not self.fix_frame_rate:
            return int(reader.count_frames())
        meta_data = reader.get_meta_data()
        total_original_frames = int(reader.count_frames())
        # Prefer the container's own duration; fall back to frames / fps.
        # Keep the fallback lazy so a present "duration" never triggers a
        # division by a missing/zero "fps".
        if "duration" in meta_data:
            duration = meta_data["duration"]
        else:
            duration = total_original_frames / meta_data["fps"]
        return int(math.floor(duration * self.frame_rate))

    def get_num_frames(self, reader):
        """Return the number of frames to sample.

        ``num_frames`` is capped by the clip's available length, then reduced
        until it satisfies the temporal divisibility constraint
        (``count % time_division_factor == time_division_remainder``).
        """
        num_frames = self.num_frames
        total_frames = self.get_available_num_frames(reader)
        if total_frames < num_frames:
            num_frames = total_frames
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def map_single_frame_id(self, new_sequence_id: int, raw_frame_rate: float, total_raw_frames: int) -> int:
        """Map an output-sequence index to a raw frame index.

        With a fixed frame rate, output frame *i* corresponds to time
        ``i / frame_rate``; the nearest raw frame at ``raw_frame_rate`` is
        chosen and clamped into ``[0, total_raw_frames - 1]``.  Without a
        fixed frame rate the index passes through unchanged.
        """
        if not self.fix_frame_rate:
            return new_sequence_id
        target_time_in_seconds = new_sequence_id / self.frame_rate
        frame_id = int(round(target_time_in_seconds * raw_frame_rate))
        return min(frame_id, total_raw_frames - 1)
|
||||
|
||||
|
||||
class LoadVideo(DataProcessingOperator, FrameSamplerByRateMixin):
    """Operator that loads a video and samples its frames, optionally
    resampling them to a fixed target frame rate via FrameSamplerByRateMixin."""

    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x, frame_rate=24, fix_frame_rate=False):
        # Explicit mixin init (rather than super()): only the sampler's state
        # is set up here, independent of DataProcessingOperator's MRO position.
        FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
        # frame_processor is built into the video loader for efficiency:
        # it is applied to each frame as it is decoded.
        self.frame_processor = frame_processor
|
||||
|
||||
def __call__(self, data: str):
|
||||
reader = imageio.get_reader(data)
|
||||
reader = self.get_reader(data)
|
||||
raw_frame_rate = reader.get_meta_data()['fps']
|
||||
num_frames = self.get_num_frames(reader)
|
||||
total_raw_frames = reader.count_frames()
|
||||
frames = []
|
||||
for frame_id in range(num_frames):
|
||||
frame_id = self.map_single_frame_id(frame_id, raw_frame_rate, total_raw_frames)
|
||||
frame = reader.get_data(frame_id)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.frame_processor(frame)
|
||||
@@ -149,7 +183,7 @@ class LoadGIF(DataProcessingOperator):
|
||||
self.time_division_remainder = time_division_remainder
|
||||
# frame_processor is built into the video loader for high efficiency.
|
||||
self.frame_processor = frame_processor
|
||||
|
||||
|
||||
def get_num_frames(self, path):
|
||||
num_frames = self.num_frames
|
||||
images = iio.imread(path, mode="RGB")
|
||||
@@ -220,14 +254,17 @@ class LoadAudio(DataProcessingOperator):
|
||||
return input_audio
|
||||
|
||||
|
||||
class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
    """Operator that loads an audio track with torchaudio, trimmed to match
    the duration of the frame sequence chosen by FrameSamplerByRateMixin
    (``num_frames / frame_rate`` seconds).

    NOTE(review): this replaces the previous ``duration=5`` constructor
    argument — the audio length is now derived from the sampled frame count,
    so any caller still passing ``duration=`` must be updated.
    """

    def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):
        FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
|
||||
|
||||
def __call__(self, data: str):
|
||||
import torchaudio
|
||||
reader = self.get_reader(data)
|
||||
num_frames = self.get_num_frames(reader)
|
||||
duration = num_frames / self.frame_rate
|
||||
waveform, sample_rate = torchaudio.load(data)
|
||||
target_samples = int(self.duration * sample_rate)
|
||||
target_samples = int(duration * sample_rate)
|
||||
current_samples = waveform.shape[-1]
|
||||
if current_samples > target_samples:
|
||||
waveform = waveform[..., :target_samples]
|
||||
|
||||
@@ -42,6 +42,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
num_frames=81, time_division_factor=4, time_division_remainder=1,
|
||||
frame_rate=24, fix_frame_rate=False,
|
||||
):
|
||||
return RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||
@@ -53,6 +54,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
|
||||
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
frame_rate=frame_rate, fix_frame_rate=fix_frame_rate,
|
||||
)),
|
||||
])),
|
||||
])
|
||||
|
||||
@@ -436,7 +436,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
|
||||
return frame_conditions
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_images, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, use_two_stage_pipeline=False):
|
||||
if input_images is None or len(input_images) == 0:
|
||||
return {}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user