mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Support LTX2 two-stage pipeline & VRAM management
@@ -60,10 +60,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
            LTX2AudioVideoUnit_NoiseInitializer(),
            LTX2AudioVideoUnit_InputVideoEmbedder(),
        ]
        self.post_units = [
            LTX2AudioVideoPostUnit_UnPatchifier(),
        ]
        self.model_fn = model_fn_ltx2
        # self.lora_loader = LTX2LoRALoader

    @staticmethod
    def from_pretrained(
@@ -71,6 +69,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
        stage2_lora_config: Optional[ModelConfig] = None,
        vram_limit: Optional[float] = None,
    ):
        # Initialize pipeline
@@ -90,14 +89,67 @@ class LTX2AudioVideoPipeline(BasePipeline):
        pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
        pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
        pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
-        pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
-        # Optional

+        # Stage 2
+        if stage2_lora_config is not None:
+            stage2_lora_config.download_if_necessary()
+            pipe.stage2_lora_path = stage2_lora_config.path
+        pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
+        # Optional, currently not used
+        # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")

+        # VRAM Management
+        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

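from_pretrained now records whether VRAM management is active. Pipelines of this kind typically honor a vram_limit by parking models on CPU and moving only the ones needed for the current phase onto the GPU (see the load_models_to_device calls below). A rough sketch of that pattern — the helper name and dict layout are illustrative, not DiffSynth's actual API:

import torch

def load_models_to_device_sketch(models: dict, active: tuple, device: str = "cuda"):
    # Keep only the models named in `active` on the GPU; offload the rest.
    for name, model in models.items():
        model.to(device if name in active else "cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release the freed blocks back to the driver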
+    def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
+        if inputs_shared.get("positive_only_lora", None) is not None:
+            self.clear_lora(verbose=0)
+            self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
+        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
+        if cfg_scale != 1.0:
+            if inputs_shared.get("positive_only_lora", None) is not None:
+                self.clear_lora(verbose=0)
+            noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
+            if isinstance(noise_pred_posi, tuple):
+                noise_pred = tuple(
+                    n_nega + cfg_scale * (n_posi - n_nega)
+                    for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
+                )
+            else:
+                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+        else:
+            noise_pred = noise_pred_posi
+        return noise_pred

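The method above is standard classifier-free guidance: the unconditional (negative-prompt) prediction is extrapolated toward the conditional one, and cfg_scale = 1.0 skips the second forward pass entirely. A minimal sketch of the same update on dummy tensors (shapes illustrative):

import torch

cfg_scale = 5.0
noise_pred_posi = torch.randn(1, 4, 8, 8)  # conditional prediction
noise_pred_nega = torch.randn(1, 4, 8, 8)  # unconditional prediction
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)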
+    def stage2_denoise(self, cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline=True, progress_bar_cmd=tqdm):
+        if use_two_stage_pipeline:
+            latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
+            self.load_models_to_device(self.in_iteration_models + ('upsampler',))
+            latent = self.upsampler(latent)
+            latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
+            latent = self.video_patchifier.patchify(latent)
+            self.scheduler.set_timesteps(special_case="stage2")
+            inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
+            inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent
+            inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"])
+            inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
+            self.load_models_to_device(self.in_iteration_models)
+            self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
+            models = {name: getattr(self, name) for name in self.in_iteration_models}
+            for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+                timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+                noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
+                    self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
+                    **models, timestep=timestep, progress_id=progress_id
+                )
+                inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
+                                                           noise_pred=noise_pred_video, **inputs_shared)
+                inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
+                                                           noise_pred=noise_pred_audio, **inputs_shared)
+            inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
+            inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
+        return inputs_shared

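stage2_denoise re-enters the diffusion trajectory partway: the stage-1 latent is upsampled, then blended with fresh noise at the first sigma of the stage-2 schedule, so pure noise corresponds to sigma = 1 and the clean latent to sigma = 0. A sketch of that convex blend (the sigma value and shapes are illustrative, not taken from the scheduler):

import torch

sigma_0 = 0.65                                     # first stage-2 sigma (illustrative)
upsampled_latent = torch.randn(1, 128, 8, 64, 64)  # assumed (B, C, T, H, W) layout
noise = torch.randn_like(upsampled_latent)

# Same combination as inputs_shared["video_latents"] above.
stage2_start = sigma_0 * noise + (1 - sigma_0) * upsampled_latent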
    @torch.no_grad()
    def __call__(
@@ -126,14 +178,15 @@ class LTX2AudioVideoPipeline(BasePipeline):
        tile_overlap_in_pixels: Optional[int] = 128,
        tile_size_in_frames: Optional[int] = 128,
        tile_overlap_in_frames: Optional[int] = 24,
-        # Two-Stage Pipeline
-        use_two_stage: Optional[bool] = True,
+        # Special Pipelines
+        use_two_stage_pipeline: Optional[bool] = False,
+        use_distilled_pipeline: Optional[bool] = False,
        # progress_bar
        progress_bar_cmd=tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)

+        # self.load_lora(self.dit, self.stage2_lora_path)
        # Inputs
        inputs_posi = {
            "prompt": prompt,
@@ -148,14 +201,12 @@ class LTX2AudioVideoPipeline(BasePipeline):
            "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
            "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
            "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
-            "use_two_stage": True
+            "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

-        # inputs_posi.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/text_encodings.pt"))
-        # inputs_nega.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/negative_text_encodings.pt"))
-        # Denoise
+        # Denoise Stage 1
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -164,12 +215,16 @@ class LTX2AudioVideoPipeline(BasePipeline):
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
-            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
-            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
+            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
+                                                       noise_pred=noise_pred_video, **inputs_shared)
+            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
+                                                       noise_pred=noise_pred_audio, **inputs_shared)
        inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
        inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])

+        # Denoise Stage 2
+        inputs_shared = self.stage2_denoise(cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline, progress_bar_cmd)

        # post-denoising, pre-decoding processing logic
        for unit in self.post_units:
            inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
        # Decode
        self.load_models_to_device(['video_vae_decoder'])
        video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
@@ -181,39 +236,21 @@ class LTX2AudioVideoPipeline(BasePipeline):
        return video, decoded_audio

-    def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
-        if inputs_shared.get("positive_only_lora", None) is not None:
-            self.clear_lora(verbose=0)
-            self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
-        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
-        if cfg_scale != 1.0:
-            if inputs_shared.get("positive_only_lora", None) is not None:
-                self.clear_lora(verbose=0)
-            noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
-            if isinstance(noise_pred_posi, tuple):
-                noise_pred = tuple(
-                    n_nega + cfg_scale * (n_posi - n_nega)
-                    for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
-                )
-            else:
-                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
-        else:
-            noise_pred = noise_pred_posi
-        return noise_pred

class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
    def __init__(self):
        super().__init__(take_over=True)

    def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if inputs_shared.get("use_two_stage_pipeline", False):
            if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
                raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
            if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
                raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
        return inputs_shared, inputs_posi, inputs_nega

class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
    """
    TODO: Adjust with the two-stage pipeline.
    For two-stage pipelines, the resolution must be divisible by 64.
    For one-stage pipelines, the resolution must be divisible by 32.
    """
@@ -223,8 +260,14 @@ class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
        output_params=("height", "width", "num_frames"),
    )

-    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames):
+    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
+        if use_two_stage_pipeline:
+            self.width_division_factor = 64
+            self.height_division_factor = 64
        height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
+        if use_two_stage_pipeline:
+            self.width_division_factor = 32
+            self.height_division_factor = 32
        return {"height": height, "width": width, "num_frames": num_frames}

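check_resize_height_width presumably snaps the requested resolution to the active division factor; stage 1 of the two-stage path runs at half resolution, so the full resolution must divide by 2 × 32 = 64. A sketch of that snapping with a hypothetical helper (not the pipeline's actual implementation):

def round_to_factor(value: int, factor: int) -> int:
    # Round to the nearest multiple of `factor`, but never below one factor.
    return max(factor, round(value / factor) * factor)

print(round_to_factor(540, 64))  # 512 (two-stage)
print(round_to_factor(540, 32))  # 544 (one-stage)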
@@ -327,10 +370,12 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
        return pipe.text_encoder_post_modules.feature_extractor_linear(
            normed_concated_encoded_text_features.to(encoded_text_features_dtype))

-    def _preprocess_text(self,
-                         pipe,
-                         text: str,
-                         padding_side: str = "left") -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+    def _preprocess_text(
+        self,
+        pipe,
+        text: str,
+        padding_side: str = "left",
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
        Encode a given string into feature tensors suitable for downstream tasks.
        Args:
@@ -339,8 +384,8 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
            tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
        """
        token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
-        input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.text_encoder.device)
-        attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.text_encoder.device)
+        input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
+        attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
        outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        projected = self._run_feature_extractor(pipe,
                                                hidden_states=outputs.hidden_states,
@@ -362,11 +407,11 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
-            input_params=("height", "width", "num_frames", "seed", "rand_device",),
+            input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
            output_params=("video_noise", "audio_noise",),
        )

-    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
+    def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
        video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
        video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
@@ -390,6 +435,15 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
            "audio_latent_shape": audio_latent_shape
        }

+    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
+        if use_two_stage_pipeline:
+            stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
+            stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
+            initial_dict = stage1_dict
+            initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
+            return initial_dict
+        else:
+            return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)

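In the two-stage path, stage 1 is initialized at half the spatial resolution while the full-resolution noise and shapes are stashed under stage2_-prefixed keys; stage2_denoise later promotes them back via the k.replace("stage2_", "") update above. A small sketch of the key scheme with placeholder values:

stage1 = {"video_noise": "half-res noise", "video_latent_shape": "half-res shape"}
stage2 = {"video_noise": "full-res noise", "video_latent_shape": "full-res shape"}

inputs = dict(stage1)
inputs.update({"stage2_" + k: v for k, v in stage2.items()})
# Stage 2 later promotes the prefixed entries over the stage-1 ones:
inputs.update({k.replace("stage2_", ""): v for k, v in inputs.items() if k.startswith("stage2_")})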
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    def __init__(self):
@@ -407,20 +461,6 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
        raise NotImplementedError("Video-to-video not implemented yet.")


class LTX2AudioVideoPostUnit_UnPatchifier(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("video_latent_shape", "audio_latent_shape", "video_latents", "audio_latents"),
            output_params=("video_latents", "audio_latents"),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, video_latent_shape, audio_latent_shape, video_latents, audio_latents):
        video_latents = pipe.video_patchifier.unpatchify(video_latents, output_shape=video_latent_shape)
        audio_latents = pipe.audio_patchifier.unpatchify(audio_latents, output_shape=audio_latent_shape)
        return {"video_latents": video_latents, "audio_latents": audio_latents}

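The post unit folds the flat token sequence back into the latent grid. As a hedged illustration of what a patchify/unpatchify pair does in general — this is not the pipeline's actual patchifier, just a 2D sketch of the idea:

import torch

def patchify(x, p=2):
    # (B, C, H, W) -> (B, H//p * W//p, C*p*p): each p x p patch becomes one token.
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // p, p, w // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // p) * (w // p), c * p * p)

def unpatchify(tokens, shape, p=2):
    # Inverse of patchify, given the original (B, C, H, W) output shape.
    b, c, h, w = shape
    x = tokens.reshape(b, h // p, w // p, c, p, p)
    return x.permute(0, 3, 1, 4, 2, 5).reshape(b, c, h, w)

x = torch.randn(1, 4, 8, 8)
assert torch.equal(unpatchify(patchify(x), x.shape), x)  # lossless round trip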
def model_fn_ltx2(
    dit: LTXModel,
    video_latents=None,