DiffSynth-Studio (mirror of https://github.com/modelscope/DiffSynth-Studio.git)

Commit: refactor patchify
@@ -167,6 +167,20 @@ class AudioPatchifier(Patchifier):
         return audio_latents
 
+    def unpatchify_audio(
+        self,
+        audio_latents: torch.Tensor,
+        channels: int,
+        mel_bins: int
+    ) -> torch.Tensor:
+        audio_latents = einops.rearrange(
+            audio_latents,
+            "b t (c f) -> b c t f",
+            c=channels,
+            f=mel_bins,
+        )
+        return audio_latents
+
     def get_patch_grid_bounds(
         self,
         output_shape: AudioLatentShape | VideoLatentShape,
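Note: a minimal round-trip sketch of what the new unpatchify_audio undoes. Only the inverse pattern "b t (c f) -> b c t f" appears in this diff; the forward patchify pattern and the concrete sizes below are assumptions for illustration.

import torch
import einops

b, c, t, f = 1, 8, 16, 64                     # batch, latent channels, time steps, mel bins (illustrative)
audio = torch.randn(b, c, t, f)

tokens = einops.rearrange(audio, "b c t f -> b t (c f)")                 # assumed forward patchify
restored = einops.rearrange(tokens, "b t (c f) -> b c t f", c=c, f=f)    # pattern from unpatchify_audio

assert torch.equal(audio, restored)           # pure reshape/permute, so the round trip is exact

Because both directions are pure reshapes and permutes, the round trip recovers the original tensor exactly.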
@@ -68,6 +68,24 @@ class VideoLatentPatchifier(Patchifier):
         return latents
 
+    def unpatchify_video(
+        self,
+        latents: torch.Tensor,
+        frames: int,
+        height: int,
+        width: int,
+    ) -> torch.Tensor:
+        latents = einops.rearrange(
+            latents,
+            "b (f h w) (c p q) -> b c f (h p) (w q)",
+            f=frames,
+            h=height // self._patch_size[1],
+            w=width // self._patch_size[2],
+            p=self._patch_size[1],
+            q=self._patch_size[2],
+        )
+        return latents
+
     def get_patch_grid_bounds(
         self,
         output_shape: AudioLatentShape | VideoLatentShape,
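Note: the same round-trip sketch for the new unpatchify_video. Only the inverse pattern and the division by self._patch_size come from the diff; the forward pattern and the patch size of 2 below are assumptions.

import torch
import einops

b, c, frames, height, width = 1, 16, 4, 32, 32   # latent-space sizes (illustrative)
p = q = 2                                         # assumed values of self._patch_size[1], self._patch_size[2]
latents = torch.randn(b, c, frames, height, width)

tokens = einops.rearrange(                        # assumed forward patchify
    latents, "b c f (h p) (w q) -> b (f h w) (c p q)", p=p, q=q)
restored = einops.rearrange(                      # pattern from unpatchify_video
    tokens, "b (f h w) (c p q) -> b c f (h p) (w q)",
    f=frames, h=height // p, w=width // q, p=p, q=q)

assert torch.equal(latents, restored)
print(tokens.shape)                               # torch.Size([1, 1024, 64])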
@@ -127,11 +127,9 @@ class LTX2AudioVideoPipeline(BasePipeline):
         self.load_models_to_device(self.in_iteration_models + ('upsampler',))
         latent = self.upsampler(latent)
         latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
-        latent = self.video_patchifier.patchify(latent)
         self.scheduler.set_timesteps(special_case="stage2")
         inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
         inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent
-        inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"])
         inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
         self.load_models_to_device(self.in_iteration_models)
         if not inputs_shared["use_distilled_pipeline"]:
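Note: with the explicit patchify calls removed here, the stage-2 re-noising blend now operates on raw (unpatchified) latents; patchification happens later, inside model_fn_ltx2. A tiny sketch of the blend itself, with illustrative names and values:

import torch

sigma0 = torch.tensor(0.7)                # scheduler.sigmas[0]; the value here is illustrative
latent = torch.randn(1, 128, 4, 64, 64)   # upsampled, normalized video latent (illustrative shape)
noise = torch.randn_like(latent)          # stands in for inputs_shared["video_noise"]

video_latents = sigma0 * noise + (1 - sigma0) * latent   # blend formula from the diff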
@@ -147,8 +145,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
                                                        noise_pred=noise_pred_video, **inputs_shared)
             inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
                                                        noise_pred=noise_pred_audio, **inputs_shared)
-        inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
-        inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
         return inputs_shared
 
     @torch.no_grad()
@@ -187,7 +183,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
         # Scheduler
         self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
                                      special_case="ditilled_stage1" if use_distilled_pipeline else None)
-        # self.load_lora(self.dit, self.stage2_lora_path)
         # Inputs
         inputs_posi = {
             "prompt": prompt,
@@ -203,6 +198,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
             "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
             "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
             "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
+            "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
         }
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
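Note: the entry added here is how the patchifiers reach the model function: they ride along in inputs_shared and arrive as keyword arguments of model_fn_ltx2, which now patchifies its own inputs. A self-contained sketch of that hand-off, using a toy stand-in patchifier and an abbreviated signature (not the real DiffSynth-Studio classes):

import torch

class IdentityPatchifier:
    # Toy stand-in for AudioPatchifier / VideoLatentPatchifier; the real
    # classes rearrange tensors into token sequences.
    def patchify(self, x):
        return x

def model_fn_ltx2_sketch(video_latents=None, audio_latents=None,
                         video_patchifier=None, audio_patchifier=None, **kwargs):
    # With this commit, patchify runs inside the model function.
    video_latents = video_patchifier.patchify(video_latents)
    audio_latents = audio_patchifier.patchify(audio_latents)
    return video_latents, audio_latents

inputs_shared = {
    "video_latents": torch.zeros(1, 16, 4, 8, 8),
    "audio_latents": torch.zeros(1, 8, 16, 64),
    "video_patchifier": IdentityPatchifier(),   # in the pipeline this is self.video_patchifier
    "audio_patchifier": IdentityPatchifier(),   # and self.audio_patchifier
}
vx, ax = model_fn_ltx2_sketch(**inputs_shared)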
@@ -220,8 +216,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
                                                        noise_pred=noise_pred_video, **inputs_shared)
             inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
                                                        noise_pred=noise_pred_audio, **inputs_shared)
-        inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
-        inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
 
         # Denoise Stage 2
         inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
@@ -422,7 +416,6 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
         video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
         video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
         video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
-        video_noise = pipe.video_patchifier.patchify(video_noise)
 
         latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
         video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
@@ -431,7 +424,6 @@
 
         audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
         audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
-        audio_noise = pipe.audio_patchifier.patchify(audio_noise)
         audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
         return {
             "video_noise": video_noise,
@@ -473,14 +465,21 @@ def model_fn_ltx2(
     video_latents=None,
     video_context=None,
     video_positions=None,
+    video_patchifier=None,
     audio_latents=None,
     audio_context=None,
     audio_positions=None,
+    audio_patchifier=None,
     timestep=None,
     use_gradient_checkpointing=False,
     use_gradient_checkpointing_offload=False,
     **kwargs,
 ):
+    # patchify
+    b, c_v, f, h, w = video_latents.shape
+    _, c_a, _, mel_bins = audio_latents.shape
+    video_latents = video_patchifier.patchify(video_latents)
+    audio_latents = audio_patchifier.patchify(audio_latents)
     #TODO: support gradient checkpointing
     timestep = timestep.float() / 1000.
     video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
@@ -495,4 +494,7 @@
         audio_context=audio_context,
         audio_timesteps=audio_timesteps,
     )
+    # unpatchify
+    vx = video_patchifier.unpatchify_video(vx, f, h, w)
+    ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
     return vx, ax
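Note: taken together, the two model_fn_ltx2 hunks make patchify/unpatchify symmetric inside the model function: the original shapes (f, h, w, c_a, mel_bins) are read before patchify and reused to fold the token sequences back afterwards. A self-contained sketch of that flow under assumed sizes, an assumed forward pattern, and an assumed patch size of 2:

import torch
import einops

video = torch.randn(1, 16, 4, 32, 32)    # b, c_v, f, h, w (illustrative)
audio = torch.randn(1, 8, 16, 64)        # b, c_a, t, mel_bins (illustrative)

b, c_v, f, h, w = video.shape             # shapes captured before patchify, as in the diff
_, c_a, _, mel_bins = audio.shape
p = q = 2                                  # assumed spatial patch size

video_tokens = einops.rearrange(video, "b c f (h p) (w q) -> b (f h w) (c p q)", p=p, q=q)
audio_tokens = einops.rearrange(audio, "b c t f -> b t (c f)")

# ... the DiT would process video_tokens / audio_tokens here ...

vx = einops.rearrange(video_tokens, "b (f h w) (c p q) -> b c f (h p) (w q)",
                      f=f, h=h // p, w=w // q, p=p, q=q)
ax = einops.rearrange(audio_tokens, "b t (c f) -> b c t f", c=c_a, f=mel_bins)

assert vx.shape == video.shape and ax.shape == audio.shape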