mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
ltx2.3 bugfix & ic lora (#1336)
* ltx2.3 IC-LoRA inference & train * temp commit * fix first-frame train–inference consistency * minor fix
This commit is contained in:
@@ -1336,45 +1336,30 @@ class LTX2VideoEncoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
if encoder_version == "ltx-2":
|
||||
encoder_blocks = [['res_x', {
|
||||
'num_layers': 4
|
||||
}], ['compress_space_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_time_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}]]
|
||||
encoder_blocks = [
|
||||
['res_x', {'num_layers': 4}],
|
||||
['compress_space_res', {'multiplier': 2}],
|
||||
['res_x', {'num_layers': 6}],
|
||||
['compress_time_res', {'multiplier': 2}],
|
||||
['res_x', {'num_layers': 6}],
|
||||
['compress_all_res', {'multiplier': 2}],
|
||||
['res_x', {'num_layers': 2}],
|
||||
['compress_all_res', {'multiplier': 2}],
|
||||
['res_x', {'num_layers': 2}]
|
||||
]
|
||||
else:
|
||||
encoder_blocks = [["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_space_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 6
|
||||
}], ["compress_time_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_all_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}], ["compress_all_res", {
|
||||
"multiplier": 1
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}]]
|
||||
# LTX-2.3
|
||||
encoder_blocks = [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_space_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_time_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
["compress_all_res", {"multiplier": 1}],
|
||||
["res_x", {"num_layers": 2}]
|
||||
]
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
@@ -1816,48 +1801,28 @@ class LTX2VideoDecoder(nn.Module):
|
||||
# each spatial dimension (height and width). This parameter determines how
|
||||
# many video frames and pixels correspond to a single latent cell.
|
||||
if decoder_version == "ltx-2":
|
||||
decoder_blocks = [['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}]]
|
||||
decoder_blocks = [
|
||||
['res_x', {'num_layers': 5, 'inject_noise': False}],
|
||||
['compress_all', {'residual': True, 'multiplier': 2}],
|
||||
['res_x', {'num_layers': 5, 'inject_noise': False}],
|
||||
['compress_all', {'residual': True, 'multiplier': 2}],
|
||||
['res_x', {'num_layers': 5, 'inject_noise': False}],
|
||||
['compress_all', {'residual': True, 'multiplier': 2}],
|
||||
['res_x', {'num_layers': 5, 'inject_noise': False}]
|
||||
]
|
||||
else:
|
||||
decoder_blocks = [["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_space", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 6
|
||||
}], ["compress_time", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_all", {
|
||||
"multiplier": 1
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}], ["compress_all", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}]]
|
||||
# LTX-2.3
|
||||
decoder_blocks = [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_space", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_time", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_all", {"multiplier": 1}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
["compress_all", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}]
|
||||
]
|
||||
self.video_downscale_factors = SpatioTemporalScaleFactors(
|
||||
time=8,
|
||||
width=32,
|
||||
@@ -1877,15 +1842,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
self.decode_noise_scale = 0.025
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Compute initial feature_channels by going through blocks in reverse
|
||||
# This determines the channel width at the start of the decoder
|
||||
# feature_channels = in_channels
|
||||
# for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
# block_config = block_params if isinstance(block_params, dict) else {}
|
||||
# if block_name == "res_x_y":
|
||||
# feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
# if block_name == "compress_all":
|
||||
# feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
# LTX VAE decoder architecture uses 3 upsampler blocks with multiplier equals to 2.
|
||||
# Hence the total feature_channels is multiplied by 8 (2^3).
|
||||
feature_channels = base_channels * 8
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
|
||||
@@ -108,18 +108,16 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
if inputs_shared["use_two_stage_pipeline"]:
|
||||
if inputs_shared.get("clear_lora_before_state_two", False):
|
||||
self.clear_lora()
|
||||
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
||||
latents = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
||||
self.load_models_to_device('upsampler',)
|
||||
latent = self.upsampler(latent)
|
||||
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
|
||||
latents = self.upsampler(latents)
|
||||
latents = self.video_vae_encoder.per_channel_statistics.normalize(latents)
|
||||
self.scheduler.set_timesteps(special_case="stage2")
|
||||
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
|
||||
denoise_mask_video = 1.0
|
||||
# input image
|
||||
if inputs_shared.get("input_images", None) is not None:
|
||||
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
|
||||
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
|
||||
inputs_shared["input_images_strength"], latent.clone())
|
||||
initial_latents, denoise_mask_video = self.apply_input_images_to_latents(latents, initial_latents=latents, **inputs_shared.get("stage2_input_latents_apply_kwargs", {}))
|
||||
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
|
||||
# remove in-context video control in stage 2
|
||||
inputs_shared.pop("in_context_video_latents", None)
|
||||
@@ -127,7 +125,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
|
||||
# initialize latents for stage 2
|
||||
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
|
||||
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
|
||||
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latents
|
||||
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
|
||||
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
|
||||
|
||||
@@ -157,7 +155,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
denoising_strength: float = 1.0,
|
||||
# Image-to-video
|
||||
input_images: Optional[list[Image.Image]] = None,
|
||||
input_images_indexes: Optional[list[int]] = None,
|
||||
input_images_indexes: Optional[list[int]] = [0],
|
||||
input_images_strength: Optional[float] = 1.0,
|
||||
# In-Context Video Control
|
||||
in_context_videos: Optional[list[list[Image.Image]]] = None,
|
||||
@@ -238,17 +236,16 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
||||
return video, decoded_audio
|
||||
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
    """Write conditioning image latents into the video latent grid.

    Each entry of ``input_latents`` is copied into ``initial_latents`` at the
    latent-frame index derived from its pixel-frame index, and the matching
    region of the denoise mask is lowered to ``1 - input_strength`` so those
    frames are (partially) protected from denoising.

    Args:
        latents: Video latents of shape (B, C, F, H, W); used here for shape,
            dtype and device.
        input_latents: Per-image latents, each of shape (B, C, f, H, W).
        input_indexes: Pixel-frame index for each entry of ``input_latents``.
        input_strength: Conditioning strength; 1.0 fully pins the conditioned
            frames (their mask becomes 0).
        initial_latents: Optional tensor to write the conditions into;
            mutated in place when given. Defaults to ``zeros_like(latents)``.
        denoise_mask_video: Optional existing mask to update in place;
            defaults to all-ones (denoise everything).

    Returns:
        tuple: ``(initial_latents, denoise_mask)`` — the conditioned latents
        and the denoise mask of shape (B, 1, F, H, W).
    """
    b, _, f, h, w = latents.shape
    denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
    initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
    for idx, input_latent in zip(input_indexes, input_latents):
        # Map a pixel-frame index to its latent-frame index (temporal
        # downscale factor of 8, first latent frame holding pixel frame 0),
        # clamped to the valid latent-frame range.
        idx = min(max(1 + (idx - 1) // 8, 0), f - 1)
        input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
        initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
        denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
    # NOTE: a blended `latents = latents * denoise_mask + initial_latents *
    # (1.0 - denoise_mask)` statement previously sat here, but its result was
    # never returned (dead code) — removed.
    return initial_latents, denoise_mask
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
||||
@@ -414,7 +411,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
|
||||
output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"),
|
||||
output_params=("denoise_mask_video", "input_latents_video", "stage2_input_latents_apply_kwargs"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
@@ -423,29 +420,39 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
image = image / 127.5 - 1.0
|
||||
image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
|
||||
latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
|
||||
return latent
|
||||
latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
|
||||
return latents
|
||||
|
||||
def get_frame_conditions(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents=None, skip_apply=False):
    """Build first-frame conditioning outputs from the input images.

    Only an image at frame index 0 produces a condition. With ``skip_apply``
    the kwargs for a deferred ``apply_input_images_to_latents`` call are
    returned; otherwise the conditions are applied to ``video_latents``
    immediately and the resulting latents/mask are returned.
    """
    conditions = {}
    for image, frame_index in zip(input_images, input_images_indexes):
        latents = self.get_image_latent(pipe, image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
        # Only first-frame conditioning is handled here.
        if frame_index != 0:
            continue
        if skip_apply:
            conditions = {"input_latents": [latents], "input_indexes": [0], "input_strength": input_images_strength}
        else:
            applied_latents, denoise_mask = pipe.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength)
            conditions.update({"input_latents_video": applied_latents, "denoise_mask_video": denoise_mask})
    return conditions
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
    """Encode the conditioning images and assemble per-stage conditioning kwargs.

    Returns an empty dict when there is nothing to condition on; otherwise
    applies stage-1 conditions directly and, for the two-stage pipeline,
    packages deferred apply kwargs for stage 2.
    """
    if not input_images:
        return {}
    if len(set(input_images_indexes)) != len(input_images_indexes):
        raise ValueError("Input images must have unique indexes.")
    pipe.load_models_to_device(self.onload_model_names)
    # Stage 1 runs at half resolution when the two-stage pipeline is active.
    if use_two_stage_pipeline:
        stage1_height, stage1_width = height // 2, width // 2
    else:
        stage1_height, stage1_width = height, width
    outputs = {}
    outputs.update(
        self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength,
                                  stage1_height, stage1_width, tiled, tile_size_in_pixels,
                                  tile_overlap_in_pixels, video_latents))
    if use_two_stage_pipeline:
        # Stage 2 re-applies the conditions at full resolution; the apply is
        # deferred until after the latent upsampling step.
        outputs["stage2_input_latents_apply_kwargs"] = self.get_frame_conditions(
            pipe, input_images, input_images_indexes, input_images_strength, height, width,
            tiled, tile_size_in_pixels, tile_overlap_in_pixels, skip_apply=True)
    return outputs
|
||||
|
||||
|
||||
@@ -508,6 +515,7 @@ def model_fn_ltx2(
|
||||
audio_positions=None,
|
||||
audio_patchifier=None,
|
||||
timestep=None,
|
||||
input_latents_video=None,
|
||||
denoise_mask_video=None,
|
||||
in_context_video_latents=None,
|
||||
in_context_video_positions=None,
|
||||
@@ -523,7 +531,9 @@ def model_fn_ltx2(
|
||||
seq_len_video = video_latents.shape[1]
|
||||
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||
if denoise_mask_video is not None:
|
||||
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
|
||||
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
|
||||
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
|
||||
video_timesteps = denoise_mask_video * video_timesteps
|
||||
|
||||
if in_context_video_latents is not None:
|
||||
in_context_video_latents = video_patchifier.patchify(in_context_video_latents)
|
||||
|
||||
Reference in New Issue
Block a user