mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -147,6 +147,19 @@ class BasePipeline(torch.nn.Module):
|
||||
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
|
||||
return video
|
||||
|
||||
def output_audio_format_check(self, audio_output):
    """Normalize an audio tensor to the standard output format.

    The standard output format is a 2-D ``[C, T]`` tensor with exactly two
    channels (stereo) and dtype ``float32``.

    Args:
        audio_output: Audio tensor shaped ``[C, T]`` or ``[1, C, T]``, where
            ``C`` is 1 (mono) or 2 (stereo).

    Returns:
        A ``[2, T]`` tensor converted with ``.float()``. Mono input is
        duplicated onto both channels.

    Raises:
        ValueError: If the input is not ``[C, T]`` / ``[1, C, T]`` with
            ``C`` in ``(1, 2)``.
    """
    # output standard format: [C, T], output dtype: float()
    original_shape = tuple(audio_output.shape)

    # remove batch dim; only a singleton batch can be dropped, anything else
    # is rejected below instead of silently leaking a 3-D tensor to callers
    if audio_output.ndim == 3 and audio_output.shape[0] == 1:
        audio_output = audio_output.squeeze(0)

    if audio_output.ndim != 2 or audio_output.shape[0] not in (1, 2):
        raise ValueError(
            "The output audio should be [C, T] or [1, C, T] with C in (1, 2); "
            f"got shape {original_shape}."
        )

    # Transform to stereo
    if audio_output.shape[0] == 1:
        audio_output = audio_output.repeat(2, 1)

    return audio_output.float()
|
||||
|
||||
def load_models_to_device(self, model_names):
|
||||
if self.vram_management_enabled:
|
||||
|
||||
@@ -141,7 +141,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
|
||||
def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):
|
||||
if skip_stage:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
@@ -160,7 +159,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -231,7 +229,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
video = self.vae_output_to_video(video)
|
||||
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
|
||||
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
|
||||
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
||||
decoded_audio = self.audio_vocoder(decoded_audio)
|
||||
decoded_audio = self.output_audio_format_check(decoded_audio)
|
||||
return video, decoded_audio
|
||||
|
||||
|
||||
@@ -394,8 +393,8 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
||||
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
|
||||
output_params=("denoise_mask_video", "input_latents_video"),
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
|
||||
output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
@@ -418,20 +417,30 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
||||
return initial_latents, denoise_mask
|
||||
|
||||
def process(self, pipe: "LTX2AudioVideoPipeline", video_latents, input_images, height, width, frame_rate, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=None, input_images_strength=1.0, initial_latents=None):
    """Build frame-conditioning inputs from the given input images.

    Each image is converted to latents with ``self.get_image_latent``. The
    image at index 0 is merged into the video latents through
    ``self.apply_input_images_to_latents``; every other image is collected
    as a reference latent together with its position tensor.

    Args:
        pipe: Pipeline providing ``load_models_to_device``,
            ``video_patchifier``, ``device`` and ``torch_dtype``.
        video_latents: Video latents passed to
            ``apply_input_images_to_latents`` for the index-0 image.
        input_images: Conditioning images; ``None`` or empty disables this unit.
        height, width: Forwarded to ``get_image_latent``.
        frame_rate: Divisor applied to the shifted first coordinate channel
            of the reference positions (presumably converts frame index to
            seconds - TODO confirm).
        tiled, tile_size_in_pixels, tile_overlap_in_pixels: Forwarded to
            ``get_image_latent``.
        input_images_indexes: One frame index per image; defaults to ``[0]``.
        input_images_strength: Forwarded to ``apply_input_images_to_latents``.
        initial_latents: Forwarded to ``apply_input_images_to_latents``.

    Returns:
        ``{}`` when there are no input images. Otherwise a dict with keys
        ``input_latents_video``, ``denoise_mask_video``,
        ``ref_frames_latents`` and ``ref_frames_positions``; entries that
        were not produced are ``None``.

    Raises:
        ValueError: If the indexes are not unique, or if the number of
            indexes differs from the number of images.
    """
    if input_images is None or len(input_images) == 0:
        return {}

    # Avoid a shared mutable default; [0] keeps the previous default behavior.
    if input_images_indexes is None:
        input_images_indexes = [0]

    # Validate before loading any model so bad input fails fast.
    if len(input_images_indexes) != len(set(input_images_indexes)):
        raise ValueError("Input images must have unique indexes.")
    if len(input_images_indexes) != len(input_images):
        # zip() below would otherwise silently drop the unmatched entries.
        raise ValueError("input_images and input_images_indexes must have the same length.")

    pipe.load_models_to_device(self.onload_model_names)
    frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []}
    for img, index in zip(input_images, input_images_indexes):
        latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
        # first_frame by replacing latents
        if index == 0:
            input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
            frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
        # other frames by adding reference latents
        else:
            latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)
            video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float()
            # Shift channel 0 of the positions by the frame index, then scale by frame_rate.
            video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate
            video_positions = video_positions.to(pipe.torch_dtype)
            frame_conditions["ref_frames_latents"].append(latents)
            frame_conditions["ref_frames_positions"].append(video_positions)

    # Report "no reference frames" as None rather than as empty lists.
    if len(frame_conditions["ref_frames_latents"]) == 0:
        frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None})
    return frame_conditions
|
||||
|
||||
|
||||
@@ -553,6 +562,9 @@ def model_fn_ltx2(
|
||||
# First Frame Conditioning
|
||||
input_latents_video=None,
|
||||
denoise_mask_video=None,
|
||||
# Other Frames Conditioning
|
||||
ref_frames_latents=None,
|
||||
ref_frames_positions=None,
|
||||
# In-Context Conditioning
|
||||
in_context_video_latents=None,
|
||||
in_context_video_positions=None,
|
||||
@@ -568,17 +580,24 @@ def model_fn_ltx2(
|
||||
video_latents = video_patchifier.patchify(video_latents)
|
||||
seq_len_video = video_latents.shape[1]
|
||||
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||
if denoise_mask_video is not None:
|
||||
# First frame conditioning by replacing the video latents
|
||||
if input_latents_video is not None:
|
||||
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
|
||||
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
|
||||
video_timesteps = denoise_mask_video * video_timesteps
|
||||
|
||||
if in_context_video_latents is not None:
|
||||
in_context_video_latents = video_patchifier.patchify(in_context_video_latents)
|
||||
in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0.
|
||||
video_latents = torch.cat([video_latents, in_context_video_latents], dim=1)
|
||||
video_positions = torch.cat([video_positions, in_context_video_positions], dim=2)
|
||||
video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1)
|
||||
|
||||
# Reference-frame / in-context conditioning by concatenating the reference latents to the video latents (with zero timesteps)
|
||||
total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
|
||||
total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
|
||||
total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []
|
||||
total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else []
|
||||
if len(total_ref_latents) > 0:
|
||||
for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions):
|
||||
ref_frames_latent = video_patchifier.patchify(ref_frames_latent)
|
||||
ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0.
|
||||
video_latents = torch.cat([video_latents, ref_frames_latent], dim=1)
|
||||
video_positions = torch.cat([video_positions, ref_frames_position], dim=2)
|
||||
video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1)
|
||||
|
||||
if audio_latents is not None:
|
||||
_, c_a, _, mel_bins = audio_latents.shape
|
||||
|
||||
@@ -43,13 +43,10 @@ def _write_audio(
|
||||
) -> None:
|
||||
if samples.ndim == 1:
|
||||
samples = samples[:, None]
|
||||
|
||||
if samples.shape[1] != 2 and samples.shape[0] == 2:
|
||||
samples = samples.T
|
||||
|
||||
if samples.shape[1] != 2:
|
||||
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
|
||||
|
||||
if samples.shape[0] == 1:
|
||||
samples = samples.repeat(2, 1)
|
||||
assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
|
||||
samples = samples.T
|
||||
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
|
||||
if samples.dtype != torch.int16:
|
||||
samples = torch.clip(samples, -1.0, 1.0)
|
||||
@@ -69,10 +66,17 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate:
|
||||
"""
|
||||
Prepare the audio stream for writing.
|
||||
"""
|
||||
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||
audio_stream = container.add_stream("aac")
|
||||
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
|
||||
if supported_sample_rates:
|
||||
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
|
||||
if best_rate != audio_sample_rate:
|
||||
print(f"Using closest supported audio sample rate: {best_rate}")
|
||||
else:
|
||||
best_rate = audio_sample_rate
|
||||
audio_stream.codec_context.sample_rate = best_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||
audio_stream.codec_context.time_base = Fraction(1, best_rate)
|
||||
return audio_stream
|
||||
|
||||
def write_video_audio_ltx2(
|
||||
@@ -80,8 +84,31 @@ def write_video_audio_ltx2(
|
||||
audio: torch.Tensor | None,
|
||||
output_path: str,
|
||||
fps: int = 24,
|
||||
audio_sample_rate: int | None = 24000,
|
||||
audio_sample_rate: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Writes a sequence of images and an audio tensor to a video file.
|
||||
|
||||
This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream
|
||||
and multiplex a PyTorch tensor as the audio stream into the output container.
|
||||
|
||||
Args:
|
||||
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
|
||||
The length of this list determines the total duration of the video based on the FPS.
|
||||
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
|
||||
The shape is typically (channels, samples). If no audio is required, pass None.
|
||||
channels can be 1 or 2. 1 for mono, 2 for stereo.
|
||||
output_path (str): The file path (including extension) where the output video will be saved.
|
||||
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
|
||||
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
|
||||
If the audio tensor is provided and this is None, the function attempts to infer the rate
|
||||
based on the audio tensor's length and the video duration.
|
||||
Raises:
|
||||
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
|
||||
"""
|
||||
duration = len(video) / fps
|
||||
if audio_sample_rate is None:
|
||||
audio_sample_rate = int(audio.shape[-1] / duration)
|
||||
|
||||
width, height = video[0].size
|
||||
container = av.open(output_path, mode="w")
|
||||
|
||||
@@ -26,7 +26,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
@@ -41,12 +41,9 @@ negative_prompt = (
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset")
|
||||
first_frame = Image.open("data/example_video_dataset/ltx2/first_frame.png").convert("RGB").resize((width, height))
|
||||
last_frame = Image.open("data/example_video_dataset/ltx2/last_frame.png").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
@@ -57,7 +54,7 @@ video, audio = pipe(
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
input_images=[image],
|
||||
input_images=[first_frame],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
)
|
||||
@@ -66,5 +63,26 @@ write_video_audio_ltx2(
|
||||
audio=audio,
|
||||
output_path='ltx2.3_twostage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
pipe.clear_lora()
|
||||
|
||||
# This example uses the first and last frames for demonstration. However, you can condition on any frames by setting input_images and input_images_indexes. Note that every index in input_images_indexes must be within the range of num_frames.
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=42,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
input_images=[first_frame, last_frame],
|
||||
input_images_indexes=[0, num_frames-1],
|
||||
input_images_strength=1.0,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_twostage_i2av_first_last.mp4',
|
||||
fps=24,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user