mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Ltx2.3 a2v& retake video and audio (#1346)
* temp commit * support ltx2 a2v * support ltx2.3 retake video and audio * add news * minor fix
This commit is contained in:
@@ -58,6 +58,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
LTX2AudioVideoUnit_ShapeChecker(),
|
||||
LTX2AudioVideoUnit_PromptEmbedder(),
|
||||
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||
LTX2AudioVideoUnit_VideoRetakeEmbedder(),
|
||||
LTX2AudioVideoUnit_AudioRetakeEmbedder(),
|
||||
LTX2AudioVideoUnit_InputAudioEmbedder(),
|
||||
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
||||
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||
@@ -67,8 +69,10 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
LTX2AudioVideoUnit_SwitchStage2(),
|
||||
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||
LTX2AudioVideoUnit_LatentsUpsampler(),
|
||||
LTX2AudioVideoUnit_SetScheduleStage2(),
|
||||
LTX2AudioVideoUnit_VideoRetakeEmbedder(),
|
||||
LTX2AudioVideoUnit_AudioRetakeEmbedder(),
|
||||
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||
LTX2AudioVideoUnit_SetScheduleStage2(),
|
||||
]
|
||||
self.model_fn = model_fn_ltx2
|
||||
|
||||
@@ -156,7 +160,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
)
|
||||
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
|
||||
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio,
|
||||
inpaint_mask=inputs_shared.get("denoise_mask_audio", None), input_latents=inputs_shared.get("input_latents_audio", None), **inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -173,6 +178,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
# In-Context Video Control
|
||||
in_context_videos: Optional[list[list[Image.Image]]] = None,
|
||||
in_context_downsample_factor: Optional[int] = 2,
|
||||
# Video-to-video
|
||||
retake_video: Optional[list[Image.Image]] = None,
|
||||
retake_video_regions: Optional[list[tuple[float, float]]] = None,
|
||||
# Audio-to-video
|
||||
retake_audio: Optional[torch.Tensor] = None,
|
||||
audio_sample_rate: Optional[int] = 48000,
|
||||
retake_audio_regions: Optional[list[tuple[float, float]]] = None,
|
||||
# Randomness
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
@@ -210,6 +222,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
}
|
||||
inputs_shared = {
|
||||
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
||||
"retake_video": retake_video, "retake_video_regions": retake_video_regions,
|
||||
"retake_audio": (retake_audio, audio_sample_rate) if retake_audio is not None else None, "retake_audio_regions": retake_audio_regions,
|
||||
"in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
|
||||
@@ -354,17 +368,13 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
|
||||
if input_video is None:
|
||||
if input_video is None or not pipe.scheduler.training:
|
||||
return {"video_latents": video_noise}
|
||||
else:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_video = pipe.preprocess_video(input_video)
|
||||
input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if pipe.scheduler.training:
|
||||
return {"video_latents": input_latents, "input_latents": input_latents}
|
||||
else:
|
||||
raise NotImplementedError("Video-to-video not implemented yet.")
|
||||
|
||||
return {"video_latents": input_latents, "input_latents": input_latents}
|
||||
|
||||
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
@@ -375,7 +385,7 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
|
||||
if input_audio is None:
|
||||
if input_audio is None or not pipe.scheduler.training:
|
||||
return {"audio_latents": audio_noise}
|
||||
else:
|
||||
input_audio, sample_rate = input_audio
|
||||
@@ -384,16 +394,83 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
||||
audio_input_latents = pipe.audio_vae_encoder(input_audio)
|
||||
audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
|
||||
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
|
||||
if pipe.scheduler.training:
|
||||
return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
|
||||
else:
|
||||
raise NotImplementedError("Audio-to-video not supported.")
|
||||
return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_VideoRetakeEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("retake_video", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "video_positions", "retake_video_regions"),
|
||||
output_params=("input_latents_video", "denoise_mask_video"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, retake_video, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_positions, retake_video_regions=None):
|
||||
if retake_video is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
resized_video = [frame.resize((width, height)) for frame in retake_video]
|
||||
input_video = pipe.preprocess_video(resized_video)
|
||||
input_latents_video = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
b, c, f, h, w = input_latents_video.shape
|
||||
denoise_mask_video = torch.zeros((b, 1, f, h, w), device=input_latents_video.device, dtype=input_latents_video.dtype)
|
||||
if retake_video_regions is not None and len(retake_video_regions) > 0:
|
||||
for start_time, end_time in retake_video_regions:
|
||||
t_start, t_end = video_positions[0, 0].unbind(dim=-1)
|
||||
in_region = (t_end >= start_time) & (t_start <= end_time)
|
||||
in_region = pipe.video_patchifier.unpatchify_video(in_region.unsqueeze(0).unsqueeze(-1), f, h, w)
|
||||
denoise_mask_video = torch.where(in_region, torch.ones_like(denoise_mask_video), denoise_mask_video)
|
||||
|
||||
return {"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit):
|
||||
"""
|
||||
Functionality of audio2video, audio retaking.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("retake_audio", "seed", "rand_device", "retake_audio_regions"),
|
||||
output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio"),
|
||||
onload_model_names=("audio_vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, retake_audio, seed, rand_device, retake_audio_regions=None):
|
||||
if retake_audio is None:
|
||||
return {}
|
||||
else:
|
||||
input_audio, sample_rate = retake_audio
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
input_latents_audio = pipe.audio_vae_encoder(input_audio)
|
||||
audio_latent_shape = AudioLatentShape.from_torch_shape(input_latents_audio.shape)
|
||||
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
|
||||
# Regenerate noise for the new shape if retake_audio is provided, to avoid shape mismatch.
|
||||
audio_noise = pipe.generate_noise(input_latents_audio.shape, seed=seed, rand_device=rand_device)
|
||||
|
||||
b, c, t, f = input_latents_audio.shape
|
||||
denoise_mask_audio = torch.zeros((b, 1, t, 1), device=input_latents_audio.device, dtype=input_latents_audio.dtype)
|
||||
if retake_audio_regions is not None and len(retake_audio_regions) > 0:
|
||||
for start_time, end_time in retake_audio_regions:
|
||||
t_start, t_end = audio_positions[:, 0, :, 0], audio_positions[:, 0, :, 1]
|
||||
in_region = (t_end >= start_time) & (t_start <= end_time)
|
||||
in_region = pipe.audio_patchifier.unpatchify_audio(in_region.unsqueeze(-1), 1, 1)
|
||||
denoise_mask_audio = torch.where(in_region, torch.ones_like(denoise_mask_audio), denoise_mask_audio)
|
||||
|
||||
return {
|
||||
"input_latents_audio": input_latents_audio,
|
||||
"denoise_mask_audio": denoise_mask_audio,
|
||||
"audio_noise": audio_noise,
|
||||
"audio_positions": audio_positions,
|
||||
"audio_latent_shape": audio_latent_shape,
|
||||
}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "input_latents_video", "denoise_mask_video"),
|
||||
output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
@@ -406,18 +483,33 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
|
||||
return latents
|
||||
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, input_latents_video=None, denoise_mask_video=None):
|
||||
b, _, f, h, w = latents.shape
|
||||
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
|
||||
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
||||
input_latents_video = torch.zeros_like(latents) if input_latents_video is None else input_latents_video
|
||||
for idx, input_latent in zip(input_indexes, input_latents):
|
||||
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
|
||||
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
|
||||
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
|
||||
input_latents_video[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
|
||||
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
||||
return initial_latents, denoise_mask
|
||||
return input_latents_video, denoise_mask
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, frame_rate, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None):
|
||||
def process(
|
||||
self,
|
||||
pipe: LTX2AudioVideoPipeline,
|
||||
video_latents,
|
||||
input_images,
|
||||
height,
|
||||
width,
|
||||
frame_rate,
|
||||
tiled,
|
||||
tile_size_in_pixels,
|
||||
tile_overlap_in_pixels,
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
input_latents_video=None,
|
||||
denoise_mask_video=None,
|
||||
):
|
||||
if input_images is None or len(input_images) == 0:
|
||||
return {}
|
||||
else:
|
||||
@@ -429,7 +521,8 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
|
||||
# first_frame by replacing latents
|
||||
if index == 0:
|
||||
input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
|
||||
input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(
|
||||
video_latents, [latents], [0], input_images_strength, input_latents_video, denoise_mask_video)
|
||||
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
|
||||
# other frames by adding reference latents
|
||||
else:
|
||||
@@ -508,6 +601,7 @@ class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
|
||||
stage2_params = {}
|
||||
stage2_params.update({"height": stage_2_height, "width": stage_2_width})
|
||||
stage2_params.update({"in_context_video_latents": None, "in_context_video_positions": None})
|
||||
stage2_params.update({"input_latents_video": None, "denoise_mask_video": None})
|
||||
if clear_lora_before_state_two:
|
||||
pipe.clear_lora()
|
||||
if not use_distilled_pipeline:
|
||||
@@ -533,7 +627,7 @@ class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("video_latents",),
|
||||
output_params=("video_latents", "initial_latents"),
|
||||
output_params=("video_latents",),
|
||||
onload_model_names=("upsampler",),
|
||||
)
|
||||
|
||||
@@ -545,7 +639,7 @@ class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit):
|
||||
video_latents = pipe.video_vae_encoder.per_channel_statistics.un_normalize(video_latents)
|
||||
video_latents = pipe.upsampler(video_latents)
|
||||
video_latents = pipe.video_vae_encoder.per_channel_statistics.normalize(video_latents)
|
||||
return {"video_latents": video_latents, "initial_latents": video_latents}
|
||||
return {"video_latents": video_latents}
|
||||
|
||||
|
||||
def model_fn_ltx2(
|
||||
@@ -568,6 +662,9 @@ def model_fn_ltx2(
|
||||
# In-Context Conditioning
|
||||
in_context_video_latents=None,
|
||||
in_context_video_positions=None,
|
||||
# Audio Inputs
|
||||
input_latents_audio=None,
|
||||
denoise_mask_audio=None,
|
||||
# Gradient Checkpointing
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
@@ -585,8 +682,8 @@ def model_fn_ltx2(
|
||||
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
|
||||
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
|
||||
video_timesteps = denoise_mask_video * video_timesteps
|
||||
|
||||
# Conditioning by replacing the video latents
|
||||
|
||||
# Reference conditioning by appending the reference video or frame latents
|
||||
total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
|
||||
total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
|
||||
total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []
|
||||
@@ -605,6 +702,10 @@ def model_fn_ltx2(
|
||||
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
|
||||
else:
|
||||
audio_timesteps = None
|
||||
if input_latents_audio is not None:
|
||||
denoise_mask_audio = audio_patchifier.patchify(denoise_mask_audio)
|
||||
audio_latents = audio_latents * denoise_mask_audio + audio_patchifier.patchify(input_latents_audio) * (1.0 - denoise_mask_audio)
|
||||
audio_timesteps = denoise_mask_audio * audio_timesteps
|
||||
|
||||
vx, ax = dit(
|
||||
video_latents=video_latents,
|
||||
|
||||
Reference in New Issue
Block a user