mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'main' into layercontrol_v2
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from .model_configs import MODEL_CONFIGS
|
||||
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS
|
||||
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
|
||||
|
||||
@@ -598,7 +598,14 @@ z_image_series = [
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
"""
|
||||
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
||||
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
||||
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
||||
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
||||
and avoid redundant memory usage when users only want to use part of the model.
|
||||
"""
|
||||
ltx2_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
@@ -607,6 +614,13 @@ ltx2_series = [
|
||||
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
|
||||
"model_hash": "c567aaa37d5ed7454c73aa6024458661",
|
||||
"model_name": "ltx2_dit",
|
||||
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
@@ -614,6 +628,13 @@ ltx2_series = [
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
|
||||
"model_hash": "7f7e904a53260ec0351b05f32153754b",
|
||||
"model_name": "ltx2_video_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
@@ -621,6 +642,13 @@ ltx2_series = [
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
|
||||
"model_hash": "dc6029ca2825147872b45e35a2dc3a97",
|
||||
"model_name": "ltx2_video_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
@@ -628,6 +656,13 @@ ltx2_series = [
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
|
||||
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
|
||||
"model_name": "ltx2_audio_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
@@ -635,16 +670,37 @@ ltx2_series = [
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||
},
|
||||
# { # not used currently
|
||||
# # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
# "model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
# "model_name": "ltx2_audio_vae_encoder",
|
||||
# "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||
# "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||
# },
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
|
||||
"model_hash": "f471360f6b24bef702ab73133d9f8bb9",
|
||||
"model_name": "ltx2_audio_vocoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_audio_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
|
||||
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
|
||||
"model_name": "ltx2_audio_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_text_encoder_post_modules",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
|
||||
"model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
|
||||
"model_name": "ltx2_text_encoder_post_modules",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||
@@ -663,4 +719,20 @@ ltx2_series = [
|
||||
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
||||
},
|
||||
]
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
|
||||
anima_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors")
|
||||
"model_hash": "a9995952c2d8e63cf82e115005eb61b9",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
"extra_kwargs": {"model_size": "0.6B"},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
|
||||
"model_hash": "417673936471e79e31ed4d186d7a3f4a",
|
||||
"model_name": "anima_dit",
|
||||
"model_class": "diffsynth.models.anima_dit.AnimaDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
|
||||
}
|
||||
]
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series
|
||||
|
||||
@@ -243,4 +243,24 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.anima_dit.AnimaDiT": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
def QwenImageTextEncoder_Module_Map_Updater():
|
||||
current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
|
||||
from packaging import version
|
||||
import transformers
|
||||
if version.parse(transformers.__version__) >= version.parse("5.2.0"):
|
||||
# The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
|
||||
current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
|
||||
current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
|
||||
return current
|
||||
|
||||
VERSION_CHECKER_MAPS = {
|
||||
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
|
||||
}
|
||||
@@ -218,3 +218,20 @@ class LoadAudio(DataProcessingOperator):
|
||||
import librosa
|
||||
input_audio, sample_rate = librosa.load(data, sr=self.sr)
|
||||
return input_audio
|
||||
|
||||
|
||||
class LoadAudioWithTorchaudio(DataProcessingOperator):
|
||||
def __init__(self, duration=5):
|
||||
self.duration = duration
|
||||
|
||||
def __call__(self, data: str):
|
||||
import torchaudio
|
||||
waveform, sample_rate = torchaudio.load(data)
|
||||
target_samples = int(self.duration * sample_rate)
|
||||
current_samples = waveform.shape[-1]
|
||||
if current_samples > target_samples:
|
||||
waveform = waveform[..., :target_samples]
|
||||
elif current_samples < target_samples:
|
||||
padding = target_samples - current_samples
|
||||
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
||||
return waveform, sample_rate
|
||||
|
||||
@@ -94,20 +94,23 @@ class BasePipeline(torch.nn.Module):
|
||||
return self
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width, num_frames=None):
|
||||
def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
|
||||
# Shape check
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if verbose > 0:
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if verbose > 0:
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if num_frames is None:
|
||||
return height, width
|
||||
else:
|
||||
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
if verbose > 0:
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
return height, width, num_frames
|
||||
|
||||
|
||||
|
||||
@@ -28,6 +28,36 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
return loss
|
||||
|
||||
|
||||
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
# video
|
||||
noise = torch.randn_like(inputs["input_latents"])
|
||||
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
|
||||
# audio
|
||||
if inputs.get("audio_input_latents") is not None:
|
||||
audio_noise = torch.randn_like(inputs["audio_input_latents"])
|
||||
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
|
||||
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
|
||||
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||
if inputs.get("audio_input_latents") is not None:
|
||||
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
|
||||
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
|
||||
loss = loss + loss_audio
|
||||
return loss
|
||||
|
||||
|
||||
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
||||
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||
pipe.scheduler.training = True
|
||||
@@ -91,7 +121,9 @@ class TrajectoryImitationLoss(torch.nn.Module):
|
||||
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
|
||||
latents_ = trajectory_teacher[progress_id_teacher]
|
||||
|
||||
target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
|
||||
denom = sigma_ - sigma
|
||||
denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
|
||||
target = (latents_ - inputs_shared["latents"]) / denom
|
||||
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
|
||||
1304
diffsynth/models/anima_dit.py
Normal file
1304
diffsynth/models/anima_dit.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,8 +5,65 @@ import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer
|
||||
|
||||
|
||||
class AudioProcessor(nn.Module):
|
||||
"""Converts audio waveforms to log-mel spectrograms with optional resampling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int = 16000,
|
||||
mel_bins: int = 64,
|
||||
mel_hop_length: int = 160,
|
||||
n_fft: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
win_length=n_fft,
|
||||
hop_length=mel_hop_length,
|
||||
f_min=0.0,
|
||||
f_max=sample_rate / 2.0,
|
||||
n_mels=mel_bins,
|
||||
window_fn=torch.hann_window,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
power=1.0,
|
||||
mel_scale="slaney",
|
||||
norm="slaney",
|
||||
)
|
||||
|
||||
def resample_waveform(
|
||||
self,
|
||||
waveform: torch.Tensor,
|
||||
source_rate: int,
|
||||
target_rate: int,
|
||||
) -> torch.Tensor:
|
||||
"""Resample waveform to target sample rate if needed."""
|
||||
if source_rate == target_rate:
|
||||
return waveform
|
||||
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
|
||||
return resampled.to(device=waveform.device, dtype=waveform.dtype)
|
||||
|
||||
def waveform_to_mel(
|
||||
self,
|
||||
waveform: torch.Tensor,
|
||||
waveform_sample_rate: int,
|
||||
) -> torch.Tensor:
|
||||
"""Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
|
||||
waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)
|
||||
|
||||
mel = self.mel_transform(waveform)
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
|
||||
mel = mel.to(device=waveform.device, dtype=waveform.dtype)
|
||||
return mel.permute(0, 1, 3, 2).contiguous()
|
||||
|
||||
|
||||
class AudioPatchifier(Patchifier):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
from einops import rearrange
|
||||
from .ltx2_common import rms_norm, Modality
|
||||
from ..core.attention.attention import attention_forward
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -1352,28 +1353,21 @@ class LTXModel(torch.nn.Module):
|
||||
video: TransformerArgs | None,
|
||||
audio: TransformerArgs | None,
|
||||
perturbations: BatchedPerturbationConfig,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
) -> tuple[TransformerArgs, TransformerArgs]:
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
|
||||
# Process transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if self._enable_gradient_checkpointing and self.training:
|
||||
# Use gradient checkpointing to save memory during training.
|
||||
# With use_reentrant=False, we can pass dataclasses directly -
|
||||
# PyTorch will track all tensor leaves in the computation graph.
|
||||
video, audio = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
video,
|
||||
audio,
|
||||
perturbations,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
video, audio = block(
|
||||
video=video,
|
||||
audio=audio,
|
||||
perturbations=perturbations,
|
||||
)
|
||||
video, audio = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
video=video,
|
||||
audio=audio,
|
||||
perturbations=perturbations,
|
||||
)
|
||||
|
||||
return video, audio
|
||||
|
||||
@@ -1398,7 +1392,12 @@ class LTXModel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
def _forward(
|
||||
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
|
||||
self,
|
||||
video: Modality | None,
|
||||
audio: Modality | None,
|
||||
perturbations: BatchedPerturbationConfig,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass for LTX models.
|
||||
@@ -1417,6 +1416,8 @@ class LTXModel(torch.nn.Module):
|
||||
video=video_args,
|
||||
audio=audio_args,
|
||||
perturbations=perturbations,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
# Process output
|
||||
@@ -1440,12 +1441,12 @@ class LTXModel(torch.nn.Module):
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
|
||||
cross_pe_max_pos = None
|
||||
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
|
||||
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
|
||||
self._init_preprocessors(cross_pe_max_pos)
|
||||
video = Modality(video_latents, video_timesteps, video_positions, video_context)
|
||||
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
|
||||
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
|
||||
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
|
||||
vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
|
||||
return vx, ax
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ..core.loader import load_model, hash_model_file
|
||||
from ..core.vram import AutoWrappedModule
|
||||
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
|
||||
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
|
||||
import importlib, json, torch
|
||||
|
||||
|
||||
@@ -22,7 +22,8 @@ class ModelPool:
|
||||
def fetch_module_map(self, model_class, vram_config):
|
||||
if self.need_to_enable_vram_management(vram_config):
|
||||
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
|
||||
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
|
||||
vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()
|
||||
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}
|
||||
else:
|
||||
module_map = {self.import_model_class(model_class): AutoWrappedModule}
|
||||
else:
|
||||
|
||||
@@ -469,7 +469,7 @@ class Down_ResidualBlock(nn.Module):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
x_copy = x.clone()
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = module(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
|
||||
|
||||
@@ -506,10 +506,10 @@ class Up_ResidualBlock(nn.Module):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
x_main = x.clone()
|
||||
for module in self.upsamples:
|
||||
x_main = module(x_main, feat_cache, feat_idx)
|
||||
x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx)
|
||||
if self.avg_shortcut is not None:
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
return x_main + x_shortcut
|
||||
return x_main + x_shortcut, feat_cache, feat_idx
|
||||
else:
|
||||
return x_main, feat_cache, feat_idx
|
||||
|
||||
|
||||
263
diffsynth/pipelines/anima_image.py
Normal file
263
diffsynth/pipelines/anima_image.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import torch, math
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from math import prod
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
from ..utils.lora.merge import merge_lora
|
||||
|
||||
from ..models.anima_dit import AnimaDiT
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
|
||||
|
||||
class AnimaImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("Z-Image")
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: AnimaDiT = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.tokenizer_t5xxl: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AnimaUnit_ShapeChecker(),
|
||||
AnimaUnit_NoiseInitializer(),
|
||||
AnimaUnit_InputImageEmbedder(),
|
||||
AnimaUnit_PromptEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_anima
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||
tokenizer_t5xxl_config: ModelConfig = ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = AnimaImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("anima_dit")
|
||||
pipe.vae = model_pool.fetch_model("wan_video_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
if tokenizer_t5xxl_config is not None:
|
||||
tokenizer_t5xxl_config.download_if_necessary()
|
||||
pipe.tokenizer_t5xxl = AutoTokenizer.from_pretrained(tokenizer_t5xxl_config.path)
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
sigma_shift: float = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"].unsqueeze(2), device=self.device).squeeze(2)
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class AnimaUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width"),
|
||||
output_params=("height", "width"),
|
||||
)
|
||||
|
||||
def process(self, pipe: AnimaImagePipeline, height, width):
|
||||
height, width = pipe.check_resize_height_width(height, width)
|
||||
return {"height": height, "width": width}
|
||||
|
||||
|
||||
|
||||
class AnimaUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: AnimaImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
|
||||
class AnimaUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: AnimaImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
if isinstance(input_image, list):
|
||||
input_latents = []
|
||||
for image in input_image:
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents.append(pipe.vae.encode(image))
|
||||
input_latents = torch.concat(input_latents, dim=0)
|
||||
else:
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
class AnimaUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_emb",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
pipe: AnimaImagePipeline,
|
||||
prompt,
|
||||
device = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
text_inputs = pipe.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = pipe.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-1]
|
||||
|
||||
t5xxl_text_inputs = pipe.tokenizer_t5xxl(
|
||||
prompt,
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
t5xxl_ids = t5xxl_text_inputs.input_ids.to(device)
|
||||
|
||||
return prompt_embeds.to(pipe.torch_dtype), t5xxl_ids
|
||||
|
||||
def process(self, pipe: AnimaImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds, t5xxl_ids = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
return {"prompt_emb": prompt_embeds, "t5xxl_ids": t5xxl_ids}
|
||||
|
||||
|
||||
def model_fn_anima(
|
||||
dit: AnimaDiT = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
t5xxl_ids=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs
|
||||
):
|
||||
latents = latents.unsqueeze(2)
|
||||
timestep = timestep / 1000
|
||||
model_output = dit(
|
||||
x=latents,
|
||||
timesteps=timestep,
|
||||
context=prompt_emb,
|
||||
t5xxl_ids=t5xxl_ids,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
model_output = model_output.squeeze(2)
|
||||
return model_output
|
||||
@@ -18,7 +18,7 @@ from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
||||
from ..models.ltx2_dit import LTXModel
|
||||
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
||||
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
|
||||
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor
|
||||
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
||||
@@ -50,6 +50,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
|
||||
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
|
||||
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
|
||||
self.audio_processor: AudioProcessor = AudioProcessor()
|
||||
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
@@ -57,8 +58,10 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
LTX2AudioVideoUnit_ShapeChecker(),
|
||||
LTX2AudioVideoUnit_PromptEmbedder(),
|
||||
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||
LTX2AudioVideoUnit_InputAudioEmbedder(),
|
||||
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
||||
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||
LTX2AudioVideoUnit_InContextVideoEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_ltx2
|
||||
|
||||
@@ -95,7 +98,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
stage2_lora_config.download_if_necessary()
|
||||
pipe.stage2_lora_path = stage2_lora_config.path
|
||||
# Optional, currently not used
|
||||
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
@@ -103,6 +106,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
|
||||
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
||||
if inputs_shared["use_two_stage_pipeline"]:
|
||||
if inputs_shared.get("clear_lora_before_state_two", False):
|
||||
self.clear_lora()
|
||||
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
||||
self.load_models_to_device('upsampler',)
|
||||
latent = self.upsampler(latent)
|
||||
@@ -110,11 +115,17 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
self.scheduler.set_timesteps(special_case="stage2")
|
||||
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
|
||||
denoise_mask_video = 1.0
|
||||
# input image
|
||||
if inputs_shared.get("input_images", None) is not None:
|
||||
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
|
||||
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
|
||||
inputs_shared["input_images_strength"], latent.clone())
|
||||
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
|
||||
# remove in-context video control in stage 2
|
||||
inputs_shared.pop("in_context_video_latents", None)
|
||||
inputs_shared.pop("in_context_video_positions", None)
|
||||
|
||||
# initialize latents for stage 2
|
||||
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
|
||||
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
|
||||
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
|
||||
@@ -143,11 +154,14 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = "",
|
||||
# Image-to-video
|
||||
denoising_strength: float = 1.0,
|
||||
# Image-to-video
|
||||
input_images: Optional[list[Image.Image]] = None,
|
||||
input_images_indexes: Optional[list[int]] = None,
|
||||
input_images_strength: Optional[float] = 1.0,
|
||||
# In-Context Video Control
|
||||
in_context_videos: Optional[list[list[Image.Image]]] = None,
|
||||
in_context_downsample_factor: Optional[int] = 2,
|
||||
# Randomness
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
@@ -155,9 +169,9 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 768,
|
||||
num_frames=121,
|
||||
frame_rate=24,
|
||||
# Classifier-free guidance
|
||||
cfg_scale: Optional[float] = 3.0,
|
||||
cfg_merge: Optional[bool] = False,
|
||||
# Scheduler
|
||||
num_inference_steps: Optional[int] = 40,
|
||||
# VAE tiling
|
||||
@@ -168,6 +182,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
tile_overlap_in_frames: Optional[int] = 24,
|
||||
# Special Pipelines
|
||||
use_two_stage_pipeline: Optional[bool] = False,
|
||||
clear_lora_before_state_two: Optional[bool] = False,
|
||||
use_distilled_pipeline: Optional[bool] = False,
|
||||
# progress_bar
|
||||
progress_bar_cmd=tqdm,
|
||||
@@ -184,12 +199,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
}
|
||||
inputs_shared = {
|
||||
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
||||
"in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"height": height, "width": width, "num_frames": num_frames,
|
||||
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
||||
"height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
|
||||
"cfg_scale": cfg_scale,
|
||||
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
||||
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
||||
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
|
||||
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two,
|
||||
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
|
||||
}
|
||||
for unit in self.units:
|
||||
@@ -416,13 +432,13 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
||||
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
|
||||
output_params=("video_noise", "audio_noise",),
|
||||
input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate", "use_two_stage_pipeline"),
|
||||
output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape")
|
||||
)
|
||||
|
||||
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
|
||||
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
||||
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
|
||||
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128)
|
||||
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||
|
||||
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
|
||||
@@ -455,23 +471,51 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("video_latents", "audio_latents"),
|
||||
input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
|
||||
output_params=("video_latents", "input_latents"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
|
||||
if input_video is None:
|
||||
return {"video_latents": video_noise, "audio_latents": audio_noise}
|
||||
return {"video_latents": video_noise}
|
||||
else:
|
||||
# TODO: implement video-to-video
|
||||
raise NotImplementedError("Video-to-video not implemented yet.")
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_video = pipe.preprocess_video(input_video)
|
||||
input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if pipe.scheduler.training:
|
||||
return {"video_latents": input_latents, "input_latents": input_latents}
|
||||
else:
|
||||
raise NotImplementedError("Video-to-video not implemented yet.")
|
||||
|
||||
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_audio", "audio_noise"),
|
||||
output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"),
|
||||
onload_model_names=("audio_vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
|
||||
if input_audio is None:
|
||||
return {"audio_latents": audio_noise}
|
||||
else:
|
||||
input_audio, sample_rate = input_audio
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
|
||||
audio_input_latents = pipe.audio_vae_encoder(input_audio)
|
||||
audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
|
||||
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
|
||||
if pipe.scheduler.training:
|
||||
return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
|
||||
else:
|
||||
raise NotImplementedError("Audio-to-video not supported.")
|
||||
|
||||
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
|
||||
output_params=("video_latents"),
|
||||
output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
@@ -506,6 +550,54 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
return output_dicts
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
|
||||
output_params=("in_context_video_latents", "in_context_video_positions"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline=True):
|
||||
if in_context_video is None or len(in_context_video) == 0:
|
||||
raise ValueError("In-context video is None or empty.")
|
||||
in_context_video = in_context_video[:num_frames]
|
||||
expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor
|
||||
expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor
|
||||
current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
|
||||
h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
|
||||
if current_h != h or current_w != w:
|
||||
in_context_video = [img.resize((w, h)) for img in in_context_video]
|
||||
if current_f != f:
|
||||
# pad black frames at the end
|
||||
in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
|
||||
return in_context_video
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=True):
|
||||
if in_context_videos is None or len(in_context_videos) == 0:
|
||||
return {}
|
||||
else:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
latents, positions = [], []
|
||||
for in_context_video in in_context_videos:
|
||||
in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline)
|
||||
in_context_video = pipe.preprocess_video(in_context_video)
|
||||
in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device)
|
||||
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
|
||||
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
|
||||
video_positions[:, 1, ...] *= in_context_downsample_factor # height axis
|
||||
video_positions[:, 2, ...] *= in_context_downsample_factor # width axis
|
||||
video_positions = video_positions.to(pipe.torch_dtype)
|
||||
|
||||
latents.append(in_context_latents)
|
||||
positions.append(video_positions)
|
||||
latents = torch.cat(latents, dim=1)
|
||||
positions = torch.cat(positions, dim=1)
|
||||
return {"in_context_video_latents": latents, "in_context_video_positions": positions}
|
||||
|
||||
|
||||
def model_fn_ltx2(
|
||||
dit: LTXModel,
|
||||
video_latents=None,
|
||||
@@ -518,6 +610,8 @@ def model_fn_ltx2(
|
||||
audio_patchifier=None,
|
||||
timestep=None,
|
||||
denoise_mask_video=None,
|
||||
in_context_video_latents=None,
|
||||
in_context_video_positions=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
@@ -527,13 +621,25 @@ def model_fn_ltx2(
|
||||
# patchify
|
||||
b, c_v, f, h, w = video_latents.shape
|
||||
video_latents = video_patchifier.patchify(video_latents)
|
||||
seq_len_video = video_latents.shape[1]
|
||||
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||
if denoise_mask_video is not None:
|
||||
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
|
||||
_, c_a, _, mel_bins = audio_latents.shape
|
||||
audio_latents = audio_patchifier.patchify(audio_latents)
|
||||
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
|
||||
#TODO: support gradient checkpointing in training
|
||||
|
||||
if in_context_video_latents is not None:
|
||||
in_context_video_latents = video_patchifier.patchify(in_context_video_latents)
|
||||
in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0.
|
||||
video_latents = torch.cat([video_latents, in_context_video_latents], dim=1)
|
||||
video_positions = torch.cat([video_positions, in_context_video_positions], dim=2)
|
||||
video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1)
|
||||
|
||||
if audio_latents is not None:
|
||||
_, c_a, _, mel_bins = audio_latents.shape
|
||||
audio_latents = audio_patchifier.patchify(audio_latents)
|
||||
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
|
||||
else:
|
||||
audio_timesteps = None
|
||||
|
||||
vx, ax = dit(
|
||||
video_latents=video_latents,
|
||||
video_positions=video_positions,
|
||||
@@ -543,8 +649,12 @@ def model_fn_ltx2(
|
||||
audio_positions=audio_positions,
|
||||
audio_context=audio_context,
|
||||
audio_timesteps=audio_timesteps,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
vx = vx[:, :seq_len_video, ...]
|
||||
# unpatchify
|
||||
vx = video_patchifier.unpatchify_video(vx, f, h, w)
|
||||
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
|
||||
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None
|
||||
return vx, ax
|
||||
|
||||
@@ -299,7 +299,7 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
def process(self, pipe: ZImagePipeline, prompt, edit_image):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
|
||||
if hasattr(pipe, "dit") and pipe.dit is not None and pipe.dit.siglip_embedder is not None:
|
||||
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
|
||||
# We determine which encoding method to use based on the model architecture.
|
||||
# If you are using two-stage split training,
|
||||
|
||||
@@ -116,7 +116,7 @@ class VideoData:
|
||||
if self.height is not None and self.width is not None:
|
||||
return self.height, self.width
|
||||
else:
|
||||
height, width, _ = self.__getitem__(0).shape
|
||||
width, height = self.__getitem__(0).size
|
||||
return height, width
|
||||
|
||||
def __getitem__(self, item):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch
|
||||
import torch, warnings
|
||||
|
||||
|
||||
class GeneralLoRALoader:
|
||||
@@ -26,7 +26,11 @@ class GeneralLoRALoader:
|
||||
keys.pop(0)
|
||||
keys.pop(-1)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key))
|
||||
# Alpha: Deprecated but retained for compatibility.
|
||||
key_alpha = key.replace(lora_B_key + ".weight", "alpha").replace(lora_B_key + ".default.weight", "alpha")
|
||||
if key_alpha == key or key_alpha not in lora_state_dict:
|
||||
key_alpha = None
|
||||
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key), key_alpha)
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
@@ -36,6 +40,10 @@ class GeneralLoRALoader:
|
||||
for name in name_dict:
|
||||
weight_up = state_dict[name_dict[name][0]]
|
||||
weight_down = state_dict[name_dict[name][1]]
|
||||
if name_dict[name][2] is not None:
|
||||
warnings.warn("Alpha detected in the LoRA file. This may be a LoRA model not trained by DiffSynth-Studio. To ensure compatibility, the LoRA weights will be converted to weight * alpha / rank.")
|
||||
alpha = state_dict[name_dict[name][2]] / weight_down.shape[0]
|
||||
weight_down = weight_down * alpha
|
||||
state_dict_[name + f".lora_B{suffix}"] = weight_up
|
||||
state_dict_[name + f".lora_A{suffix}"] = weight_down
|
||||
return state_dict_
|
||||
|
||||
6
diffsynth/utils/state_dict_converters/anima_dit.py
Normal file
6
diffsynth/utils/state_dict_converters/anima_dit.py
Normal file
@@ -0,0 +1,6 @@
|
||||
def AnimaDiTStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
value = state_dict[key]
|
||||
new_state_dict[key.replace("net.", "")] = value
|
||||
return new_state_dict
|
||||
Reference in New Issue
Block a user