mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
refactor patchify
This commit is contained in:
@@ -127,11 +127,9 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
self.load_models_to_device(self.in_iteration_models + ('upsampler',))
|
||||
latent = self.upsampler(latent)
|
||||
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
|
||||
latent = self.video_patchifier.patchify(latent)
|
||||
self.scheduler.set_timesteps(special_case="stage2")
|
||||
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
|
||||
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent
|
||||
inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"])
|
||||
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
if not inputs_shared["use_distilled_pipeline"]:
|
||||
@@ -147,8 +145,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
noise_pred=noise_pred_video, **inputs_shared)
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
||||
noise_pred=noise_pred_audio, **inputs_shared)
|
||||
inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
|
||||
inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
|
||||
return inputs_shared
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -187,7 +183,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
|
||||
special_case="ditilled_stage1" if use_distilled_pipeline else None)
|
||||
# self.load_lora(self.dit, self.stage2_lora_path)
|
||||
# Inputs
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
@@ -203,6 +198,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
||||
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
||||
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
|
||||
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -220,8 +216,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
noise_pred=noise_pred_video, **inputs_shared)
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
||||
noise_pred=noise_pred_audio, **inputs_shared)
|
||||
inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
|
||||
inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
|
||||
|
||||
# Denoise Stage 2
|
||||
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
|
||||
@@ -422,7 +416,6 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
||||
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
|
||||
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||
video_noise = pipe.video_patchifier.patchify(video_noise)
|
||||
|
||||
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
|
||||
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
|
||||
@@ -431,7 +424,6 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
|
||||
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
|
||||
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||
audio_noise = pipe.audio_patchifier.patchify(audio_noise)
|
||||
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
|
||||
return {
|
||||
"video_noise": video_noise,
|
||||
@@ -473,14 +465,21 @@ def model_fn_ltx2(
|
||||
video_latents=None,
|
||||
video_context=None,
|
||||
video_positions=None,
|
||||
video_patchifier=None,
|
||||
audio_latents=None,
|
||||
audio_context=None,
|
||||
audio_positions=None,
|
||||
audio_patchifier=None,
|
||||
timestep=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
# patchify
|
||||
b, c_v, f, h, w = video_latents.shape
|
||||
_, c_a, _, mel_bins = audio_latents.shape
|
||||
video_latents = video_patchifier.patchify(video_latents)
|
||||
audio_latents = audio_patchifier.patchify(audio_latents)
|
||||
#TODO: support gradient checkpointing
|
||||
timestep = timestep.float() / 1000.
|
||||
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||
@@ -495,4 +494,7 @@ def model_fn_ltx2(
|
||||
audio_context=audio_context,
|
||||
audio_timesteps=audio_timesteps,
|
||||
)
|
||||
# unpatchify
|
||||
vx = video_patchifier.unpatchify_video(vx, f, h, w)
|
||||
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
|
||||
return vx, ax
|
||||
|
||||
Reference in New Issue
Block a user