mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -141,7 +141,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
|
||||
def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):
|
||||
if skip_stage:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
@@ -160,7 +159,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -231,7 +229,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
video = self.vae_output_to_video(video)
|
||||
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
|
||||
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
|
||||
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
||||
decoded_audio = self.audio_vocoder(decoded_audio)
|
||||
decoded_audio = self.output_audio_format_check(decoded_audio)
|
||||
return video, decoded_audio
|
||||
|
||||
|
||||
@@ -394,8 +393,8 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
||||
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
|
||||
output_params=("denoise_mask_video", "input_latents_video"),
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
|
||||
output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
@@ -418,20 +417,30 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
||||
return initial_latents, denoise_mask
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None):
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, frame_rate, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None):
|
||||
if input_images is None or len(input_images) == 0:
|
||||
return {}
|
||||
else:
|
||||
if len(input_images_indexes) != len(set(input_images_indexes)):
|
||||
raise ValueError("Input images must have unique indexes.")
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
frame_conditions = {}
|
||||
frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []}
|
||||
for img, index in zip(input_images, input_images_indexes):
|
||||
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
|
||||
# first_frame
|
||||
# first_frame by replacing latents
|
||||
if index == 0:
|
||||
input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
|
||||
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
|
||||
# other frames by adding reference latents
|
||||
else:
|
||||
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)
|
||||
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float()
|
||||
video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate
|
||||
video_positions = video_positions.to(pipe.torch_dtype)
|
||||
frame_conditions["ref_frames_latents"].append(latents)
|
||||
frame_conditions["ref_frames_positions"].append(video_positions)
|
||||
if len(frame_conditions["ref_frames_latents"]) == 0:
|
||||
frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None})
|
||||
return frame_conditions
|
||||
|
||||
|
||||
@@ -553,6 +562,9 @@ def model_fn_ltx2(
|
||||
# First Frame Conditioning
|
||||
input_latents_video=None,
|
||||
denoise_mask_video=None,
|
||||
# Other Frames Conditioning
|
||||
ref_frames_latents=None,
|
||||
ref_frames_positions=None,
|
||||
# In-Context Conditioning
|
||||
in_context_video_latents=None,
|
||||
in_context_video_positions=None,
|
||||
@@ -568,17 +580,24 @@ def model_fn_ltx2(
|
||||
video_latents = video_patchifier.patchify(video_latents)
|
||||
seq_len_video = video_latents.shape[1]
|
||||
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||
if denoise_mask_video is not None:
|
||||
# First frame conditioning by replacing the video latents
|
||||
if input_latents_video is not None:
|
||||
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
|
||||
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
|
||||
video_timesteps = denoise_mask_video * video_timesteps
|
||||
|
||||
if in_context_video_latents is not None:
|
||||
in_context_video_latents = video_patchifier.patchify(in_context_video_latents)
|
||||
in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0.
|
||||
video_latents = torch.cat([video_latents, in_context_video_latents], dim=1)
|
||||
video_positions = torch.cat([video_positions, in_context_video_positions], dim=2)
|
||||
video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1)
|
||||
|
||||
# Reference-frame and in-context conditioning by appending reference latents (with zero timesteps) to the video sequence
|
||||
total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
|
||||
total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
|
||||
total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []
|
||||
total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else []
|
||||
if len(total_ref_latents) > 0:
|
||||
for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions):
|
||||
ref_frames_latent = video_patchifier.patchify(ref_frames_latent)
|
||||
ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0.
|
||||
video_latents = torch.cat([video_latents, ref_frames_latent], dim=1)
|
||||
video_positions = torch.cat([video_positions, ref_frames_position], dim=2)
|
||||
video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1)
|
||||
|
||||
if audio_latents is not None:
|
||||
_, c_a, _, mel_bins = audio_latents.shape
|
||||
|
||||
Reference in New Issue
Block a user