DiffSynth-Studio 2.0 major update

This commit is contained in:
root
2025-12-04 16:33:07 +08:00
parent afd101f345
commit 72af7122b3
758 changed files with 26462 additions and 2221398 deletions

View File

@@ -1,15 +0,0 @@
from .sd_image import SDImagePipeline
from .sd_video import SDVideoPipeline
from .sdxl_image import SDXLImagePipeline
from .sdxl_video import SDXLVideoPipeline
from .sd3_image import SD3ImagePipeline
from .hunyuan_image import HunyuanDiTImagePipeline
from .svd_video import SVDVideoPipeline
from .flux_image import FluxImagePipeline
from .cog_video import CogVideoPipeline
from .omnigen_image import OmnigenImagePipeline
from .pipeline_runner import SDVideoPipelineRunner
from .hunyuan_video import HunyuanVideoPipeline
from .step_video import StepVideoPipeline
from .wan_video import WanVideoPipeline
KolorsImagePipeline = SDXLImagePipeline

View File

@@ -1,127 +0,0 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur
class BasePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.cpu_offload = False
self.model_names = []
def check_resize_height_width(self, height, width):
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
return height, width
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def preprocess_images(self, images):
return [self.preprocess_image(image) for image in images]
def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def vae_output_to_video(self, vae_output):
video = vae_output.cpu().permute(1, 2, 0).numpy()
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
return video
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
if len(latents) > 0:
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
mask = blur(mask)
value += latent * mask * scale
weight += mask * scale
value /= weight
return value
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
if special_kwargs is None:
noise_pred_global = inference_callback(prompt_emb_global)
else:
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
if special_local_kwargs_list is None:
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
else:
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
local_prompts = local_prompts or []
masks = masks or []
mask_scales = mask_scales or []
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
return
# offload the unneeded models to cpu
for model_name in self.model_names:
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise

View File

@@ -1,135 +0,0 @@
from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder
from ..prompters import CogPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange
class CogVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
self.prompter = CogPrompter()
# models
self.text_encoder: FluxTextEncoder2 = None
self.dit: CogDiT = None
self.vae_encoder: CogVAEEncoder = None
self.vae_decoder: CogVAEDecoder = None
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("cog_dit")
self.vae_encoder = model_manager.fetch_model("cog_vae_encoder")
self.vae_decoder = model_manager.fetch_model("cog_vae_decoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
pipe = CogVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def encode_prompt(self, prompt, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
return {"prompt_emb": prompt_emb}
def prepare_extra_input(self, latents):
return {"image_rotary_emb": self.dit.prepare_rotary_positional_embeddings(latents.shape[3], latents.shape[4], latents.shape[2], device=self.device)}
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_video=None,
cfg_scale=7.0,
denoising_strength=1.0,
num_frames=49,
height=480,
width=720,
num_inference_steps=20,
tiled=False,
tile_size=(60, 90),
tile_stride=(30, 45),
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# Prepare latent tensors
noise = self.generate_noise((1, 16, num_frames // 4 + 1, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
if denoising_strength == 1.0:
latents = noise.clone()
else:
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.vae_encoder.encode_video(input_video, **tiler_kwargs, progress_bar=progress_bar_cmd).to(dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
if not tiled: latents = latents.to(self.device)
# Encode prompt
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Update progress bar
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
video = self.vae_decoder.decode_video(latents.to("cpu"), **tiler_kwargs, progress_bar=progress_bar_cmd)
video = self.tensor2video(video[0])
return video

View File

@@ -1,236 +0,0 @@
import torch
from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
from ..models.sd_unet import PushBlock, PopBlock
from ..controlnets import MultiControlNetManager
def lets_dance(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
ipadapter_kwargs_list = {},
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
tiled=False,
tile_size=64,
tile_stride=32,
device = "cuda",
vram_limit_level = 0,
):
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
# 1. ControlNet
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
# I leave it here because I intend to do something interesting on the ControlNets.
controlnet_insert_block_id = 30
if controlnet is not None and controlnet_frames is not None:
res_stacks = []
# process controlnet frames with batch
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
res_stack = controlnet(
sample[batch_id: batch_id_],
timestep,
encoder_hidden_states[batch_id: batch_id_],
controlnet_frames[:, batch_id: batch_id_],
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
if vram_limit_level >= 1:
res_stack = [res.cpu() for res in res_stack]
res_stacks.append(res_stack)
# concat the residual
additional_res_stack = []
for i in range(len(res_stacks[0])):
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
additional_res_stack.append(res)
else:
additional_res_stack = None
# 2. time
time_emb = unet.time_proj(timestep).to(sample.dtype)
time_emb = unet.time_embedding(time_emb)
# 3. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = unet.conv_in(sample)
text_emb = encoder_hidden_states
res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
# 4. blocks
for block_id, block in enumerate(unet.blocks):
# 4.1 UNet
if isinstance(block, PushBlock):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].cpu()
elif isinstance(block, PopBlock):
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].to(device)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
else:
hidden_states_input = hidden_states
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff
if motion_modules is not None:
if block_id in motion_modules.call_block_id:
motion_module_id = motion_modules.call_block_id[block_id]
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
hidden_states, time_emb, text_emb, res_stack,
batch_size=1
)
# 4.3 ControlNet
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
hidden_states += additional_res_stack.pop().to(device)
if vram_limit_level>=1:
res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
else:
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
# 5. output
hidden_states = unet.conv_norm_out(hidden_states)
hidden_states = unet.conv_act(hidden_states)
hidden_states = unet.conv_out(hidden_states)
return hidden_states
def lets_dance_xl(
unet: SDXLUNet,
motion_modules: SDXLMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
add_time_id = None,
add_text_embeds = None,
timestep = None,
encoder_hidden_states = None,
ipadapter_kwargs_list = {},
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
tiled=False,
tile_size=64,
tile_stride=32,
device = "cuda",
vram_limit_level = 0,
):
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
if add_text_embeds.shape[0] != sample.shape[0]:
add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
# 1. ControlNet
controlnet_insert_block_id = 22
if controlnet is not None and controlnet_frames is not None:
res_stacks = []
# process controlnet frames with batch
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
res_stack = controlnet(
sample[batch_id: batch_id_],
timestep,
encoder_hidden_states[batch_id: batch_id_],
controlnet_frames[:, batch_id: batch_id_],
add_time_id=add_time_id,
add_text_embeds=add_text_embeds,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
unet=unet, # for Kolors, some modules in ControlNets will be replaced.
)
if vram_limit_level >= 1:
res_stack = [res.cpu() for res in res_stack]
res_stacks.append(res_stack)
# concat the residual
additional_res_stack = []
for i in range(len(res_stacks[0])):
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
additional_res_stack.append(res)
else:
additional_res_stack = None
# 2. time
t_emb = unet.time_proj(timestep).to(sample.dtype)
t_emb = unet.time_embedding(t_emb)
time_embeds = unet.add_time_proj(add_time_id)
time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(sample.dtype)
add_embeds = unet.add_time_embedding(add_embeds)
time_emb = t_emb + add_embeds
# 3. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = unet.conv_in(sample)
text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
res_stack = [hidden_states]
# 4. blocks
for block_id, block in enumerate(unet.blocks):
# 4.1 UNet
if isinstance(block, PushBlock):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].cpu()
elif isinstance(block, PopBlock):
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].to(device)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
else:
hidden_states_input = hidden_states
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb[batch_id: batch_id_],
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff
if motion_modules is not None:
if block_id in motion_modules.call_block_id:
motion_module_id = motion_modules.call_block_id[block_id]
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
hidden_states, time_emb, text_emb, res_stack,
batch_size=1
)
# 4.3 ControlNet
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
hidden_states += additional_res_stack.pop().to(device)
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
# 5. output
hidden_states = unet.conv_norm_out(hidden_states)
hidden_states = unet.conv_act(hidden_states)
hidden_states = unet.conv_out(hidden_states)
return hidden_states

View File

@@ -0,0 +1,370 @@
import torch, math
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from transformers import AutoProcessor
from ..models.flux2_text_encoder import Flux2TextEncoder
from ..models.flux2_dit import Flux2DiT
from ..models.flux2_vae import Flux2VAE
class Flux2ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: Flux2TextEncoder = None
self.dit: Flux2DiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
Flux2Unit_ShapeChecker(),
Flux2Unit_PromptEmbedder(),
Flux2Unit_NoiseInitializer(),
Flux2Unit_InputImageEmbedder(),
Flux2Unit_ImageIDs(),
]
self.model_fn = model_fn_flux2
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
):
# Initialize pipeline
pipe = Flux2ImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
pipe.dit = model_pool.fetch_model("flux2_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 4.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
latents = rearrange(inputs_shared["latents"], "B (H W) C -> B C H W", H=inputs_shared["height"]//16, W=inputs_shared["width"]//16)
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class Flux2Unit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: Flux2ImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class Flux2Unit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_emb", "prompt_emb_mask"),
onload_model_names=("text_encoder",)
)
self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
def format_text_input(self, prompts: List[str], system_message: str = None):
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
# when truncation is enabled. The processor counts [IMG] tokens and fails
# if the count changes after truncation.
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
def get_mistral_3_small_prompt_embeds(
self,
text_encoder,
tokenizer,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
# fmt: off
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
# fmt: on
hidden_states_layers: List[int] = (10, 20, 30),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
# Format input messages
messages_batch = self.format_text_input(prompts=prompt, system_message=system_message)
# Process all messages at once
inputs = tokenizer.apply_chat_template(
messages_batch,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
# Move to device
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
def prepare_text_ids(
self,
x: torch.Tensor, # (B, L, D) or (L, D)
t_coord: Optional[torch.Tensor] = None,
):
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, l)
out_ids.append(coords)
return torch.stack(out_ids)
def encode_prompt(
self,
text_encoder,
tokenizer,
prompt: Union[str, List[str]],
dtype = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
):
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = self.get_mistral_3_small_prompt_embeds(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
dtype=dtype,
device=device,
max_sequence_length=max_sequence_length,
system_message=self.system_message,
hidden_states_layers=text_encoder_out_layers,
)
batch_size, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = self.prepare_text_ids(prompt_embeds)
text_ids = text_ids.to(device)
return prompt_embeds, text_ids
def process(self, pipe: Flux2ImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, text_ids = self.encode_prompt(
pipe.text_encoder, pipe.tokenizer, prompt,
dtype=pipe.torch_dtype, device=pipe.device,
)
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
class Flux2Unit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1)
return {"noise": noise}
class Flux2Unit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: Flux2ImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
input_latents = rearrange(input_latents, "B C H W -> B (H W) C")
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
class Flux2Unit_ImageIDs(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("image_ids",),
)
def prepare_latent_ids(self, height, width):
t = torch.arange(1) # [0] - time dimension
h = torch.arange(height)
w = torch.arange(width)
l = torch.arange(1) # [0] - layer dimension
# Create position IDs: (H*W, 4)
latent_ids = torch.cartesian_prod(t, h, w, l)
# Expand to batch: (B, H*W, 4)
latent_ids = latent_ids.unsqueeze(0).expand(1, -1, -1)
return latent_ids
def process(self, pipe: Flux2ImagePipeline, height, width):
image_ids = self.prepare_latent_ids(height // 16, width // 16).to(pipe.device)
return {"image_ids": image_ids}
def model_fn_flux2(
dit: Flux2DiT,
latents=None,
timestep=None,
embedded_guidance=None,
prompt_embeds=None,
text_ids=None,
image_ids=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
model_output = dit(
hidden_states=latents,
timestep=timestep / 1000,
guidance=embedded_guidance,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=image_ids,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,288 +0,0 @@
from ..models.hunyuan_dit import HunyuanDiT
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models import ModelManager
from ..prompters import HunyuanDiTPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
import numpy as np
class ImageSizeManager:
def __init__(self):
pass
def _to_tuple(self, x):
if isinstance(x, int):
return x, x
else:
return x
def get_fill_resize_and_crop(self, src, tgt):
th, tw = self._to_tuple(tgt)
h, w = self._to_tuple(src)
tr = th / tw # base 分辨率
r = h / w # 目标分辨率
# resize
if r > tr:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h)) # 根据base分辨率将目标分辨率resize下来
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_meshgrid(self, start, *args):
if len(args) == 0:
# start is grid_size
num = self._to_tuple(start)
start = (0, 0)
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = self._to_tuple(start)
stop = self._to_tuple(args[0])
num = (stop[0] - start[0], stop[1] - start[1])
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = self._to_tuple(start) # 左上角 eg: 12,0
stop = self._to_tuple(args[0]) # 右下角 eg: 20,32
num = self._to_tuple(args[1]) # 目标大小 eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
return grid
def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
grid = self.get_meshgrid(start, *args) # [2, H, w]
grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致
pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
def calc_rope(self, height, width):
patch_size = 2
head_size = 88
th = height // 8 // patch_size
tw = width // 8 // patch_size
base_size = 512 // 8 // patch_size
start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
return rope
class HunyuanDiTImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
self.prompter = HunyuanDiTPrompter()
self.image_size_manager = ImageSizeManager()
# models
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
self.dit: HunyuanDiT = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
self.dit = model_manager.fetch_model("hunyuan_dit")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = HunyuanDiTImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=positive,
device=self.device
)
return {
"text_emb": text_emb,
"text_emb_mask": text_emb_mask,
"text_emb_t5": text_emb_t5,
"text_emb_mask_t5": text_emb_mask_t5
}
def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
if tiled:
height, width = tile_size * 16, tile_size * 16
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
freqs_cis_img = self.image_size_manager.calc_rope(height, width)
image_meta_size = torch.stack([image_meta_size] * batch_size)
return {
"size_emb": image_meta_size,
"freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride
}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=1,
input_image=None,
reference_strengths=[0.4],
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = noise.clone()
# Encode prompts
self.load_models_to_device(['text_encoder', 'text_encoder_t5'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
# Prepare positional id
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
# Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
# Positive side
inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **extra_input,
)
# Classifier-free guidance
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# Offload all models
self.load_models_to_device([])
return image

View File

@@ -1,395 +0,0 @@
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
import torchvision.transforms as transforms
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
class HunyuanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
self.prompter = HunyuanVideoPrompter()
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.dit: HunyuanVideoDiT = None
self.vae_decoder: HunyuanVideoVAEDecoder = None
self.vae_encoder: HunyuanVideoVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
self.vram_management = False
def enable_vram_management(self):
self.vram_management = True
self.enable_cpu_offload()
self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
def fetch_models(self, model_manager: ModelManager):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
self.dit = model_manager.fetch_model("hunyuan_video_dit")
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
self.vae_encoder = model_manager.fetch_model("hunyuan_video_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
if enable_vram_management:
pipe.enable_vram_management()
return pipe
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
num_patches = round((base_size / patch_size)**2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
return buckets[closest_ratio_id], float(closest_ratio)
def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
if i2v_resolution == "720p":
bucket_hw_base_size = 960
elif i2v_resolution == "540p":
bucket_hw_base_size = 720
elif i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
origin_size = semantic_images[0].size
crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
ref_image_transform = transforms.Compose([
transforms.Resize(closest_size),
transforms.CenterCrop(closest_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
target_height, target_width = closest_size
return semantic_image_pixel_values, target_height, target_width
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
def prepare_extra_input(self, latents=None, guidance=1.0):
freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8)
tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8)
latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_video=None,
input_images=None,
i2v_resolution="720p",
i2v_stability=True,
denoising_strength=1.0,
seed=None,
rand_device=None,
height=720,
width=1280,
num_frames=129,
embedded_guidance=6.0,
cfg_scale=1.0,
num_inference_steps=30,
tea_cache_l1_thresh=None,
tile_size=(17, 30, 30),
tile_stride=(12, 20, 20),
step_processor=None,
progress_bar_cmd=lambda x: x,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# encoder input images
if input_images is not None:
self.load_models_to_device(['vae_encoder'])
image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
image_latents = self.vae_encoder(image_pixel_values)
# Initialize noise
rand_device = self.device if rand_device is None else rand_device
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
if input_video is not None:
self.load_models_to_device(['vae_encoder'])
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
elif input_images is not None and i2v_stability:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
t = torch.tensor([0.999]).to(device=self.device)
latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
latents = latents.to(dtype=image_latents.dtype)
else:
latents = noise
# Encode prompts
# current mllm does not support vram_management
self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device([] if self.vram_management else ["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
forward_func = lets_dance_hunyuan_video
if input_images is not None:
latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
forward_func = lets_dance_hunyuan_video_i2v
# Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
if cfg_scale != 1.0:
noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# (Experimental feature, may be removed in the future)
if step_processor is not None:
self.load_models_to_device(['vae_decoder'])
rendered_frames = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents, to_final=True)
rendered_frames = self.vae_decoder.decode_video(rendered_frames, **tiler_kwargs)
rendered_frames = self.tensor2video(rendered_frames[0])
rendered_frames = step_processor(rendered_frames, original_frames=input_video)
self.load_models_to_device(['vae_encoder'])
rendered_frames = self.preprocess_images(rendered_frames)
rendered_frames = torch.stack(rendered_frames, dim=2)
target_latents = self.encode_video(rendered_frames).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.scheduler.return_to_timestep(self.scheduler.timesteps[progress_id], latents, target_latents)
self.load_models_to_device([] if self.vram_management else ["dit"])
# Scheduler
if input_images is not None:
latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
latents = torch.concat([image_latents, latents], dim=2)
else:
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode
self.load_models_to_device(['vae_decoder'])
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
self.load_models_to_device([])
frames = self.tensor2video(frames[0])
return frames
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
def check(self, dit: HunyuanVideoDiT, img, vec):
img_ = img.clone()
vec_ = vec.clone()
img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
normed_inp = dit.double_blocks[0].component_a.norm1(img_)
modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = img.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
def lets_dance_hunyuan_video(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
img = x[:, :-txt_len]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def lets_dance_hunyuan_video_i2v(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
# Uncomment below to keep same as official implementation
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
vec = dit.time_in(t, dtype=torch.bfloat16)
vec_2 = dit.vector_in(pooled_prompt_emb)
vec = vec + vec_2
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
tr_token = (H // 2) * (W // 2)
token_replace_vec = token_replace_vec + vec_2
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
img = x[:, :-txt_len]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img

View File

@@ -1,289 +0,0 @@
from ..models.omnigen import OmniGenTransformer
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.model_manager import ModelManager
from ..prompters.omnigen_prompter import OmniGenPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
from typing import Optional, Dict, Any, Tuple, List
from transformers.cache_utils import DynamicCache
import torch, os
from tqdm import tqdm
class OmniGenCache(DynamicCache):
def __init__(self,
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
if not torch.cuda.is_available():
print("No available GPU, offload_kv_cache will be set to False, which will result in large memory usage and time cost when input multiple images!!!")
offload_kv_cache = False
raise RuntimeError("OffloadedCache can only be used with a GPU")
super().__init__()
self.original_device = []
self.prefetch_stream = torch.cuda.Stream()
self.num_tokens_for_img = num_tokens_for_img
self.offload_kv_cache = offload_kv_cache
def prefetch_layer(self, layer_idx: int):
"Starts prefetching the next layer cache"
if layer_idx < len(self):
with torch.cuda.stream(self.prefetch_stream):
# Prefetch next layer tensors to GPU
device = self.original_device[layer_idx]
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
if len(self) > 2:
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
if layer_idx == 0:
prev_layer_idx = -1
else:
prev_layer_idx = (layer_idx - 1) % len(self)
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
if self.offload_kv_cache:
# Evict the previous layer if necessary
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# Load current layer cache to its original device if not already there
original_device = self.original_device[layer_idx]
# self.prefetch_stream.synchronize(original_device)
torch.cuda.synchronize(self.prefetch_stream)
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]
# Prefetch the next layer
self.prefetch_layer((layer_idx + 1) % len(self))
else:
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]
return (key_tensor, value_tensor)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the cache
if len(self.key_cache) < layer_idx:
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
# only cache the states for condition tokens
key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
self.key_cache.append(key_states)
self.value_cache.append(value_states)
self.original_device.append(key_states.device)
if self.offload_kv_cache:
self.evict_previous_layer(layer_idx)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
else:
# only cache the states for condition tokens
key_tensor, value_tensor = self[layer_idx]
k = torch.cat([key_tensor, key_states], dim=-2)
v = torch.cat([value_tensor, value_states], dim=-2)
return k, v
class OmnigenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(num_train_timesteps=1, shift=1, inverse_timesteps=True, sigma_min=0, sigma_max=1)
# models
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.transformer: OmniGenTransformer = None
self.prompter: OmniGenPrompter = None
self.model_names = ['transformer', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
return self.transformer
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
# Main models
self.transformer, model_path = model_manager.fetch_model("omnigen_transformer", require_model_path=True)
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter = OmniGenPrompter.from_pretrained(os.path.dirname(model_path))
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = OmnigenImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes=[])
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def encode_images(self, images, tiled=False, tile_size=64, tile_stride=32):
latents = [self.encode_image(image.to(device=self.device), tiled, tile_size, tile_stride).to(self.torch_dtype) for image in images]
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
return {"encoder_hidden_states": prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
if isinstance(position_ids, list):
for i in range(len(position_ids)):
position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
else:
position_ids = position_ids[:, -(num_tokens_for_img+1):]
return position_ids
def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
if isinstance(attention_mask, list):
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
return attention_mask[..., -(num_tokens_for_img+1):, :]
@torch.no_grad()
def __call__(
self,
prompt,
reference_images=[],
cfg_scale=2.0,
image_cfg_scale=2.0,
use_kv_cache=True,
offload_kv_cache=True,
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = latents.repeat(3, 1, 1, 1)
# Encode prompts
input_data = self.prompter(prompt, reference_images, height=height, width=width, use_img_cfg=True, separate_cfg_input=True, use_input_image_size_as_output=False)
# Encode images
reference_latents = [self.encode_images(images, **tiler_kwargs) for images in input_data['input_pixel_values']]
# Pack all parameters
model_kwargs = dict(input_ids=[input_ids.to(self.device) for input_ids in input_data['input_ids']],
input_img_latents=reference_latents,
input_image_sizes=input_data['input_image_sizes'],
attention_mask=[attention_mask.to(self.device) for attention_mask in input_data["attention_mask"]],
position_ids=[position_ids.to(self.device) for position_ids in input_data["position_ids"]],
cfg_scale=cfg_scale,
img_cfg_scale=image_cfg_scale,
use_img_cfg=True,
use_kv_cache=use_kv_cache,
offload_model=False,
)
# Denoise
self.load_models_to_device(['transformer'])
cache = [OmniGenCache(latents.size(-1)*latents.size(-2) // 4, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).repeat(latents.shape[0]).to(self.device)
# Forward
noise_pred, cache = self.transformer.forward_with_separate_cfg(latents, timestep, past_key_values=cache, **model_kwargs)
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Update KV cache
if progress_id == 0 and use_kv_cache:
num_tokens_for_img = latents.size(-1)*latents.size(-2) // 4
if isinstance(cache, list):
model_kwargs['input_ids'] = [None] * len(cache)
else:
model_kwargs['input_ids'] = None
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
del cache
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -1,105 +0,0 @@
import os, torch, json
from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
from ..processors.sequencial_processor import SequencialProcessor
from ..data import VideoData, save_frames, save_video
class SDVideoPipelineRunner:
def __init__(self, in_streamlit=False):
self.in_streamlit = in_streamlit
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
model_manager.load_models(model_list)
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id=unit["processor_id"],
model_path=unit["model_path"],
scale=unit["scale"]
) for unit in controlnet_units
]
)
textual_inversion_paths = []
for file_name in os.listdir(textual_inversion_folder):
if file_name.endswith(".pt") or file_name.endswith(".bin") or file_name.endswith(".pth") or file_name.endswith(".safetensors"):
textual_inversion_paths.append(os.path.join(textual_inversion_folder, file_name))
pipe.prompter.load_textual_inversions(textual_inversion_paths)
return model_manager, pipe
def load_smoother(self, model_manager, smoother_configs):
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
return smoother
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
torch.manual_seed(seed)
if self.in_streamlit:
import streamlit as st
progress_bar_st = st.progress(0.0)
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
progress_bar_st.progress(1.0)
else:
output_video = pipe(**pipeline_inputs, smoother=smoother)
model_manager.to("cpu")
return output_video
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
if start_frame_id is None:
start_frame_id = 0
if end_frame_id is None:
end_frame_id = len(video)
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
return frames
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
if len(data["controlnet_frames"]) > 0:
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
return pipeline_inputs
def save_output(self, video, output_folder, fps, config):
os.makedirs(output_folder, exist_ok=True)
save_frames(video, os.path.join(output_folder, "frames"))
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
with open(os.path.join(output_folder, "config.json"), 'w') as file:
json.dump(config, file, indent=4)
def run(self, config):
if self.in_streamlit:
import streamlit as st
if self.in_streamlit: st.markdown("Loading videos ...")
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Loading videos ... done!")
if self.in_streamlit: st.markdown("Loading models ...")
model_manager, pipe = self.load_pipeline(**config["models"])
if self.in_streamlit: st.markdown("Loading models ... done!")
if "smoother_configs" in config:
if self.in_streamlit: st.markdown("Loading smoother ...")
smoother = self.load_smoother(model_manager, config["smoother_configs"])
if self.in_streamlit: st.markdown("Loading smoother ... done!")
else:
smoother = None
if self.in_streamlit: st.markdown("Synthesizing videos ...")
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
if self.in_streamlit: st.markdown("Saving videos ...")
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
if self.in_streamlit: st.markdown("Saving videos ... done!")
if self.in_streamlit: st.markdown("Finished!")
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
if self.in_streamlit: st.video(video_file.read())

View File

@@ -1,48 +1,18 @@
import torch
import torch, math
from PIL import Image
from typing import Union
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import numpy as np
from ..models import ModelManager, load_state_dict
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
from ..schedulers import FlowMatchScheduler
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora import GeneralLoRALoader
from .flux_image_new import ControlNetInput
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageBlockWiseControlNet]):
super().__init__()
if not isinstance(models, list):
models = [models]
self.models = torch.nn.ModuleList(models)
def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
processed_conditionings = []
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
processed_conditionings.append(model_output)
return processed_conditionings
def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
res = 0
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
continue
model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
res = res + model_output * controlnet_input.scale
return res
class QwenImagePipeline(BasePipeline):
@@ -54,14 +24,13 @@ class QwenImagePipeline(BasePipeline):
)
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
self.scheduler = FlowMatchScheduler("Qwen-Image")
self.text_encoder: QwenImageTextEncoder = None
self.dit: QwenImageDiT = None
self.vae: QwenImageVAE = None
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
self.tokenizer: Qwen2Tokenizer = None
self.processor: Qwen2VLProcessor = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit", "blockwise_controlnet")
self.units = [
QwenImageUnit_ShapeChecker(),
@@ -75,245 +44,6 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_BlockwiseControlNet(),
]
self.model_fn = model_fn_qwen_image
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
state_dict=None,
):
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
if hotload:
for name, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
lora_a_name = f'{name}.lora_A.default.weight'
lora_b_name = f'{name}.lora_B.default.weight'
if lora_a_name in lora and lora_b_name in lora:
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
else:
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def clear_lora(self):
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
def enable_lora_magic(self):
if self.dit is not None:
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device=self.device,
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=None,
)
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
noise = torch.randn_like(inputs["input_latents"])
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss
def direct_distill_loss(self, **inputs):
self.scheduler.set_timesteps(inputs["num_inference_steps"])
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(self.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
inputs["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
return loss
def _enable_fp8_lora_training(self, dtype):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
from ..models.qwen_image_dit import RMSNorm
from ..models.qwen_image_vae import QwenImageRMS_norm
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.Embedding: AutoWrappedModule,
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
Qwen2RMSNorm: AutoWrappedModule,
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
QwenImageRMS_norm: AutoWrappedModule,
}
model_config = dict(
offload_dtype=dtype,
offload_device="cuda",
onload_dtype=dtype,
onload_device="cuda",
computation_dtype=self.torch_dtype,
computation_device="cuda",
)
if self.text_encoder is not None:
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
if self.dit is not None:
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
if self.vae is not None:
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, auto_offload=True, enable_dit_fp8_computation=False):
self.vram_management_enabled = True
if vram_limit is None and auto_offload:
vram_limit = self.get_vram()
if vram_limit is not None:
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
self.text_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
Qwen2RMSNorm: AutoWrappedModule,
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
from ..models.qwen_image_dit import RMSNorm
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
if not enable_dit_fp8_computation:
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
else:
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
from ..models.qwen_image_vae import QwenImageRMS_norm
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
QwenImageRMS_norm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.blockwise_controlnet is not None:
enable_vram_management(
self.blockwise_controlnet,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@staticmethod
@@ -323,24 +53,18 @@ class QwenImagePipeline(BasePipeline):
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,
vram_limit: float = None,
):
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary()
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype)
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
pipe.dit = model_manager.fetch_model("qwen_image_dit")
pipe.vae = model_manager.fetch_model("qwen_image_vae")
pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_manager.fetch_model("qwen_image_blockwise_controlnet", index="all"))
if tokenizer_config is not None and pipe.text_encoder is not None:
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("qwen_image_text_encoder")
pipe.dit = model_pool.fetch_model("qwen_image_dit")
pipe.vae = model_pool.fetch_model("qwen_image_vae")
pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_pool.fetch_model("qwen_image_blockwise_controlnet", index="all"))
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer
pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
@@ -348,6 +72,9 @@ class QwenImagePipeline(BasePipeline):
processor_config.download_if_necessary()
from transformers import Qwen2VLProcessor
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@@ -386,8 +113,6 @@ class QwenImagePipeline(BasePipeline):
edit_rope_interpolation: bool = False,
# In-context control
context_image: Image.Image = None,
# FP8
enable_fp8_attention: bool = False,
# Tile
tiled: bool = False,
tile_size: int = 128,
@@ -411,7 +136,6 @@ class QwenImagePipeline(BasePipeline):
"inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"num_inference_steps": num_inference_steps,
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
@@ -427,16 +151,11 @@ class QwenImagePipeline(BasePipeline):
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
if cfg_scale != 1.0:
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
@@ -448,10 +167,41 @@ class QwenImagePipeline(BasePipeline):
return image
class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageBlockWiseControlNet]):
super().__init__()
if not isinstance(models, list):
models = [models]
self.models = torch.nn.ModuleList(models)
for model in models:
if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"):
self.vram_management_enabled = True
def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
processed_conditionings = []
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
processed_conditionings.append(model_output)
return processed_conditionings
def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
res = 0
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
continue
model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
res = res + model_output * controlnet_input.scale
return res
class QwenImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width"))
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: QwenImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
@@ -461,7 +211,10 @@ class QwenImageUnit_ShapeChecker(PipelineUnit):
class QwenImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width", "seed", "rand_device"))
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
@@ -473,6 +226,7 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
@@ -494,6 +248,7 @@ class QwenImageUnit_Inpaint(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
output_params=("inpaint_mask",),
)
def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
@@ -515,6 +270,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
input_params=("edit_image",),
output_params=("prompt_emb", "prompt_emb_mask"),
onload_model_names=("text_encoder",)
)
@@ -526,7 +282,6 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
return split_result
def calculate_dimensions(self, target_area, ratio):
import math
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
@@ -573,6 +328,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
return split_hidden_states
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
pipe.load_models_to_device(self.onload_model_names)
if pipe.text_encoder is not None:
prompt = [prompt]
if edit_image is None:
@@ -595,6 +351,8 @@ class QwenImageUnit_EntityControl(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"),
output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"),
onload_model_names=("text_encoder",)
)
@@ -675,6 +433,7 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"),
output_params=("blockwise_controlnet_conditioning",),
onload_model_names=("vae",)
)
@@ -717,6 +476,7 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"),
output_params=("edit_latents", "edit_image"),
onload_model_names=("vae",)
)
@@ -738,7 +498,7 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
if edit_image is None:
return {}
pipe.load_models_to_device(['vae'])
pipe.load_models_to_device(self.onload_model_names)
if isinstance(edit_image, Image.Image):
resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
@@ -759,13 +519,14 @@ class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("context_latents",),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
if context_image is None:
return {}
pipe.load_models_to_device(['vae'])
pipe.load_models_to_device(self.onload_model_names)
context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"context_latents": context_latents}

View File

@@ -1,147 +0,0 @@
from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
from ..prompters import SD3Prompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
class SD3ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = FlowMatchScheduler()
self.prompter = SD3Prompter()
# models
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: SD3TextEncoder2 = None
self.text_encoder_3: SD3TextEncoder3 = None
self.dit: SD3DiT = None
self.vae_decoder: SD3VAEDecoder = None
self.vae_encoder: SD3VAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
self.dit = model_manager.fetch_model("sd3_dit")
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = SD3ImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
t5_sequence_length=77,
tiled=False,
tile_size=128,
tile_stride=64,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
# Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -1,191 +0,0 @@
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
class SDImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDPrompter()
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None
self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
def denoising_model(self):
return self.unet
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
self.unet = model_manager.fetch_model("sd_unet")
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sd_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
pipe = SDImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[])
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
return {"encoder_hidden_states": prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
# IP-Adapter
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['controlnet'])
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Denoise
self.load_models_to_device(['controlnet', 'unet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device,
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_nega = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -1,269 +0,0 @@
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .sd_image import SDImagePipeline
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
def lets_dance_with_long_video(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
ipadapter_kwargs_list = {},
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
tiled=False,
tile_size=64,
tile_stride=32,
device="cuda",
animatediff_batch_size=16,
animatediff_stride=8,
):
num_frames = sample.shape[0]
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
for batch_id in range(0, num_frames, animatediff_stride):
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
# process this batch
hidden_states_batch = lets_dance(
unet, motion_modules, controlnet,
sample[batch_id: batch_id_].to(device),
timestep,
encoder_hidden_states,
ipadapter_kwargs_list=ipadapter_kwargs_list,
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
).cpu()
# update hidden_states
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
hidden_states, num = hidden_states_output[i]
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
hidden_states_output[i] = (hidden_states, num + bias)
if batch_id_ == num_frames:
break
# output
hidden_states = torch.stack([h for h, _ in hidden_states_output])
return hidden_states
class SDVideoPipeline(SDImagePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
self.prompter = SDPrompter()
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None
self.motion_modules: SDMotionModel = None
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
self.unet = model_manager.fetch_model("sd_unet")
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sd_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
# Motion Modules
self.motion_modules = model_manager.fetch_model("sd_motion_modules")
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
pipe = SDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents.append(latent.cpu())
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
tiled=False,
tile_size=64,
tile_stride=32,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters, batch size ...
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
other_kwargs = {
"animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
"unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
"cross_frame_attention": cross_frame_attention,
}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_video(input_frames, **tiler_kwargs)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
# IP-Adapter
if ipadapter_images is not None:
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_frames is not None:
if isinstance(controlnet_frames[0], list):
controlnet_frames_ = []
for processor_id in range(len(controlnet_frames)):
controlnet_frames_.append(
torch.stack([
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
], dim=1)
)
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
else:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
device=self.device,
)
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep,
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
rendered_frames = self.decode_video(rendered_frames)
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
target_latents = self.encode_video(rendered_frames)
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
output_frames = self.decode_video(latents, **tiler_kwargs)
# Post-process
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
output_frames = smoother(output_frames, original_frames=input_frames)
return output_frames

View File

@@ -1,226 +0,0 @@
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.kolors_text_encoder import ChatGLMModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDXLPrompter, KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
from .dancer import lets_dance_xl
from typing import List
import torch
from tqdm import tqdm
from einops import repeat
class SDXLImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDXLPrompter()
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.text_encoder_kolors: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
def denoising_model(self):
return self.unet
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
self.unet = model_manager.fetch_model("sdxl_unet")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sdxl_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
# Kolors
if self.text_encoder_kolors is not None:
print("Switch to Kolors. The prompter and scheduler will be replaced.")
self.prompter = KolorsPrompter()
self.prompter.fetch_models(self.text_encoder_kolors)
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
else:
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
pipe = SDXLImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=positive,
)
return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}
def prepare_extra_input(self, latents=None):
height, width = latents.shape[2] * 8, latents.shape[3] * 8
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
return {"add_time_id": add_time_id}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
controlnet_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['controlnet'])
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Prepare extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
self.load_models_to_device(['controlnet', 'unet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: lets_dance_xl(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **extra_input,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device,
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **extra_input,
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -1,226 +0,0 @@
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
from ..models.kolors_text_encoder import ChatGLMModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDXLPrompter, KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .sdxl_image import SDXLImagePipeline
from .dancer import lets_dance_xl
from typing import List
import torch
from tqdm import tqdm
class SDXLVideoPipeline(SDXLImagePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
self.prompter = SDXLPrompter()
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.text_encoder_kolors: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# self.controlnet: MultiControlNetManager = None (TODO)
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
self.motion_modules: SDXLMotionModel = None
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
self.unet = model_manager.fetch_model("sdxl_unet")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets (TODO)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
# Motion Modules
self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
# Kolors
if self.text_encoder_kolors is not None:
print("Switch to Kolors. The prompter will be replaced.")
self.prompter = KolorsPrompter()
self.prompter.fetch_models(self.text_encoder_kolors)
# The schedulers of AniamteDiff and Kolors are incompatible. We align it with AniamteDiff.
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
else:
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
pipe = SDXLVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents.append(latent.cpu())
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
tiled=False,
tile_size=64,
tile_stride=32,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters, batch size ...
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_video(input_frames, **tiler_kwargs)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
latents = latents.to(self.device) # will be deleted for supporting long videos
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_frames is not None:
if isinstance(controlnet_frames[0], list):
controlnet_frames_ = []
for processor_id in range(len(controlnet_frames)):
controlnet_frames_.append(
torch.stack([
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
], dim=1)
)
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
else:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Prepare extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
device=self.device,
)
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, timestep=timestep,
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
rendered_frames = self.decode_video(rendered_frames)
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
target_latents = self.encode_video(rendered_frames)
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
output_frames = self.decode_video(latents, **tiler_kwargs)
# Post-process
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
output_frames = smoother(output_frames, original_frames=input_frames)
return output_frames

View File

@@ -1,209 +0,0 @@
from ..models import ModelManager
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
from ..models.stepvideo_text_encoder import STEP1TextEncoder
from ..models.stepvideo_dit import StepVideoModel
from ..models.stepvideo_vae import StepVideoVAE
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import StepVideoPrompter
import torch
from einops import rearrange
import numpy as np
from PIL import Image
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from transformers.models.bert.modeling_bert import BertEmbeddings
from ..models.stepvideo_dit import RMSNorm
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
class StepVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(sigma_min=0.0, extra_one_step=True, shift=13.0, reverse_sigmas=True, num_train_timesteps=1)
self.prompter = StepVideoPrompter()
self.text_encoder_1: HunyuanDiTCLIPTextEncoder = None
self.text_encoder_2: STEP1TextEncoder = None
self.dit: StepVideoModel = None
self.vae: StepVideoVAE = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae']
def enable_vram_management(self, num_persistent_param_in_dit=None):
dtype = next(iter(self.text_encoder_1.parameters())).dtype
enable_vram_management(
self.text_encoder_1,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
BertEmbeddings: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=torch.float32,
computation_device=self.device,
),
)
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
RMSNorm: AutoWrappedModule,
torch.nn.Embedding: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
CausalConv: AutoWrappedModule,
CausalConvAfterNorm: AutoWrappedModule,
Upsample2D: AutoWrappedModule,
BaseGroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.text_encoder_1 = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
self.text_encoder_2 = model_manager.fetch_model("stepvideo_text_encoder_2")
self.dit = model_manager.fetch_model("stepvideo_dit")
self.vae = model_manager.fetch_model("stepvideo_vae")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = StepVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
return pipe
def encode_prompt(self, prompt, positive=True):
clip_embeds, llm_embeds, llm_mask = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
clip_embeds = clip_embeds.to(dtype=self.torch_dtype, device=self.device)
llm_embeds = llm_embeds.to(dtype=self.torch_dtype, device=self.device)
llm_mask = llm_mask.to(dtype=self.torch_dtype, device=self.device)
return {"encoder_hidden_states_2": clip_embeds, "encoder_hidden_states": llm_embeds, "encoder_attention_mask": llm_mask}
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_video=None,
denoising_strength=1.0,
seed=None,
rand_device="cpu",
height=544,
width=992,
num_frames=204,
cfg_scale=9.0,
num_inference_steps=30,
tiled=True,
tile_size=(34, 34),
tile_stride=(16, 16),
smooth_scale=0.6,
progress_bar_cmd=lambda x: x,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Initialize noise
latents = self.generate_noise((1, max(num_frames//17*3, 1), 64, height//16, width//16), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
# Encode prompts
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Denoise
self.load_models_to_device(["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
# Inference
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode
self.load_models_to_device(['vae'])
frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
self.load_models_to_device([])
frames = self.tensor2video(frames[0])
return frames

View File

@@ -1,300 +0,0 @@
from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
from ..schedulers import ContinuousODEScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange, repeat
class SVDVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = ContinuousODEScheduler()
# models
self.image_encoder: SVDImageEncoder = None
self.unet: SVDUNet = None
self.vae_encoder: SVDVAEEncoder = None
self.vae_decoder: SVDVAEDecoder = None
def fetch_models(self, model_manager: ModelManager):
self.image_encoder = model_manager.fetch_model("svd_image_encoder")
self.unet = model_manager.fetch_model("svd_unet")
self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
@staticmethod
def from_model_manager(model_manager: ModelManager, **kwargs):
pipe = SVDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype
)
pipe.fetch_models(model_manager)
return pipe
def encode_image_with_clip(self, image):
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
image = (image + 1.0) / 2.0
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
image = (image - mean) / std
image_emb = self.image_encoder(image)
return image_emb
def encode_image_with_vae(self, image, noise_aug_strength, seed=None):
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
noise = self.generate_noise(image.shape, seed=seed, device=self.device, dtype=self.torch_dtype)
image = image + noise_aug_strength * noise
image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
return image_emb
def encode_video_with_vae(self, video):
video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0)
video = rearrange(video, "T C H W -> 1 C T H W")
video = video.to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder.encode_video(video)
latents = rearrange(latents[0], "C T H W -> T C H W")
return latents
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def calculate_noise_pred(
self,
latents,
timestep,
add_time_id,
cfg_scales,
image_emb_vae_posi, image_emb_clip_posi,
image_emb_vae_nega, image_emb_clip_nega
):
# Positive side
noise_pred_posi = self.unet(
torch.cat([latents, image_emb_vae_posi], dim=1),
timestep, image_emb_clip_posi, add_time_id
)
# Negative side
noise_pred_nega = self.unet(
torch.cat([latents, image_emb_vae_nega], dim=1),
timestep, image_emb_clip_nega, add_time_id
)
# Classifier-free guidance
noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
return noise_pred
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
if post_normalize:
mean, std = latents.mean(), latents.std()
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
latents = latents * contrast_enhance_scale
return latents
@torch.no_grad()
def __call__(
self,
input_image=None,
input_video=None,
mask_frames=[],
mask_frame_ids=[],
min_cfg_scale=1.0,
max_cfg_scale=3.0,
denoising_strength=1.0,
num_frames=25,
height=576,
width=1024,
fps=7,
motion_bucket_id=127,
noise_aug_strength=0.02,
num_inference_steps=20,
post_normalize=True,
contrast_enhance_scale=1.2,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# Prepare latent tensors
noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
if denoising_strength == 1.0:
latents = noise.clone()
else:
latents = self.encode_video_with_vae(input_video)
latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
# Prepare mask frames
if len(mask_frames) > 0:
mask_latents = self.encode_video_with_vae(mask_frames)
# Encode image
image_emb_clip_posi = self.encode_image_with_clip(input_image)
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength, seed=seed), "B C H W -> (B T) C H W", T=num_frames)
image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
# Prepare classifier-free guidance
cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames)
cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype)
# Prepare positional id
add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# Mask frames
for frame_id, mask_frame_id in enumerate(mask_frame_ids):
latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
# Fetch model output
noise_pred = self.calculate_noise_pred(
latents, timestep, add_time_id, cfg_scales,
image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega
)
# Forward Euler
latents = self.scheduler.step(noise_pred, timestep, latents)
# Update progress bar
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
video = self.tensor2video(video)
return video
class SVDCLIPImageProcessor:
def __init__(self):
pass
def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True):
h, w = input.shape[-2:]
factors = (h / size[0], w / size[1])
# First, we have to determine sigma
# Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
sigmas = (
max((factors[0] - 1.0) / 2.0, 0.001),
max((factors[1] - 1.0) / 2.0, 0.001),
)
# Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
# https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
# But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
# Make sure it is odd
if (ks[0] % 2) == 0:
ks = ks[0] + 1, ks[1]
if (ks[1] % 2) == 0:
ks = ks[0], ks[1] + 1
input = self._gaussian_blur2d(input, ks, sigmas)
output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
return output
def _compute_padding(self, kernel_size):
"""Compute padding tuple."""
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
if len(kernel_size) < 2:
raise AssertionError(kernel_size)
computed = [k - 1 for k in kernel_size]
# for even kernels we need to do asymmetric padding :(
out_padding = 2 * len(kernel_size) * [0]
for i in range(len(kernel_size)):
computed_tmp = computed[-(i + 1)]
pad_front = computed_tmp // 2
pad_rear = computed_tmp - pad_front
out_padding[2 * i + 0] = pad_front
out_padding[2 * i + 1] = pad_rear
return out_padding
def _filter2d(self, input, kernel):
# prepare kernel
b, c, h, w = input.shape
tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
height, width = tmp_kernel.shape[-2:]
padding_shape: list[int] = self._compute_padding([height, width])
input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
# kernel and input tensor reshape to align element-wise or batch-wise params
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
# convolve the tensor with the kernel.
output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
out = output.view(b, c, h, w)
return out
def _gaussian(self, window_size: int, sigma):
if isinstance(sigma, float):
sigma = torch.tensor([[sigma]])
batch_size = sigma.shape[0]
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
if window_size % 2 == 0:
x = x + 0.5
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
return gauss / gauss.sum(-1, keepdim=True)
def _gaussian_blur2d(self, input, kernel_size, sigma):
if isinstance(sigma, tuple):
sigma = torch.tensor([sigma], dtype=input.dtype)
else:
sigma = sigma.to(dtype=input.dtype)
ky, kx = int(kernel_size[0]), int(kernel_size[1])
bs = sigma.shape[0]
kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1))
kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1))
out_x = self._filter2d(input, kernel_x[..., None, :])
out = self._filter2d(out_x, kernel_y[..., None])
return out

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,257 @@
import torch, math
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from transformers import AutoTokenizer
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
class ZImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("Z-Image")
self.text_encoder: ZImageTextEncoder = None
self.dit: ZImageDiT = None
self.vae_encoder: FluxVAEEncoder = None
self.vae_decoder: FluxVAEDecoder = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
ZImageUnit_ShapeChecker(),
ZImageUnit_PromptEmbedder(),
ZImageUnit_NoiseInitializer(),
ZImageUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_z_image
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
):
# Initialize pipeline
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("z_image_dit")
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
# Progress bar
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# Parameters
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae_decoder'])
image = self.vae_decoder(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class ZImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: ZImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class ZImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
if isinstance(prompt, str):
prompt = [prompt]
for i, prompt_item in enumerate(prompt):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = pipe.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
prompt[i] = prompt_item
text_inputs = pipe.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = pipe.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
for i in range(len(prompt_embeds)):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
def process(self, pipe: ZImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds}
class ZImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: ZImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class ZImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae_encoder(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_z_image(
dit: ZImageDiT,
latents=None,
timestep=None,
prompt_embeds=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
latents = [rearrange(latents, "B C H W -> C B H W")]
timestep = (1000 - timestep) / 1000
model_output = dit(
latents,
timestep,
prompt_embeds,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)[0][0]
model_output = -model_output
model_output = rearrange(model_output, "C B H W -> B C H W")
return model_output