Merge branch 'main' into wanvideo_seq_usp

This commit is contained in:
mi804
2025-07-30 16:44:44 +08:00
106 changed files with 6696 additions and 697 deletions

View File

@@ -18,12 +18,15 @@ from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEnc
from ..models.step1x_connector import Qwen2Connector
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_value_control import MultiValueEncoder
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
from ..models.nexus_gen import NexusGenAutoregressiveModel
from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
from ..models.flux_dit import RMSNorm
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
@@ -93,9 +96,14 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter_image_encoder = None
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.nexus_gen: NexusGenAutoregressiveModel = None
self.nexus_gen_generation_adapter: NexusGenAdapter = None
self.nexus_gen_editing_adapter: NexusGenImageEmbeddingMerger = None
self.value_controller: MultiValueEncoder = None
self.infinityou_processor: InfinitYou = None
self.image_proj_model: InfiniteYouImageProjector = None
self.lora_patcher: FluxLoraPatcher = None
self.lora_encoder: FluxLoRAEncoder = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
self.units = [
@@ -110,9 +118,12 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_ControlNet(),
FluxImageUnit_IPAdapter(),
FluxImageUnit_EntityControl(),
FluxImageUnit_NexusGen(),
FluxImageUnit_TeaCache(),
FluxImageUnit_Flex(),
FluxImageUnit_Step1x(),
FluxImageUnit_ValueControl(),
FluxImageUnit_LoRAEncode(),
]
self.model_fn = model_fn_flux_image
@@ -120,18 +131,20 @@ class FluxImagePipeline(BasePipeline):
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str],
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
local_model_path="./models",
skip_download=False
state_dict=None,
):
if isinstance(lora_config, str):
lora_config = ModelConfig(path=lora_config)
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
lora = state_dict
loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
lora = loader.convert_state_dict(lora)
if hotload:
for name, module in module.named_modules():
@@ -145,19 +158,21 @@ class FluxImagePipeline(BasePipeline):
loader.load(module, lora, alpha=alpha)
def enable_lora_patcher(self):
if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled):
print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
return
if self.lora_patcher is None:
print("Please load lora patcher models before `enable_lora_patcher()`.")
return
for name, module in self.dit.named_modules():
if isinstance(module, AutoWrappedLinear):
merger_name = name.replace(".", "___")
if merger_name in self.lora_patcher.model_dict:
module.lora_merger = self.lora_patcher.model_dict[merger_name]
def load_loras(
self,
module: torch.nn.Module,
lora_configs: list[Union[ModelConfig, str]],
alpha=1,
hotload=False,
extra_fused_lora=False,
):
for lora_config in lora_configs:
self.load_lora(module, lora_config, hotload=hotload, alpha=alpha)
if extra_fused_lora:
lora_fuser = FluxLoRAFuser(device="cuda", torch_dtype=torch.bfloat16)
fused_lora = lora_fuser(lora_configs)
self.load_lora(module, state_dict=fused_lora, hotload=hotload, alpha=alpha)
def clear_lora(self):
for name, module in self.named_modules():
@@ -182,22 +197,19 @@ class FluxImagePipeline(BasePipeline):
return loss
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder_1 is not None:
dtype = next(iter(self.text_encoder_1.parameters())).dtype
def _enable_vram_management_with_default_config(self, model, vram_limit):
if model is not None:
dtype = next(iter(model.parameters())).dtype
enable_vram_management(
self.text_encoder_1,
model,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
LoRALayerBlock: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -209,7 +221,52 @@ class FluxImagePipeline(BasePipeline):
),
vram_limit=vram_limit,
)
def enable_lora_magic(self):
if self.dit is not None:
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device=self.device,
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=None,
)
if self.lora_patcher is not None:
for name, module in self.dit.named_modules():
if isinstance(module, AutoWrappedLinear):
merger_name = name.replace(".", "___")
if merger_name in self.lora_patcher.model_dict:
module.lora_merger = self.lora_patcher.model_dict[merger_name]
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
# Default config
default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector", "lora_encoder"]
for model_name in default_vram_management_models:
self._enable_vram_management_with_default_config(getattr(self, model_name), vram_limit)
# Special config
if self.text_encoder_2 is not None:
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
@@ -258,14 +315,18 @@ class FluxImagePipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.vae_decoder is not None:
dtype = next(iter(self.vae_decoder.parameters())).dtype
if self.ipadapter_image_encoder is not None:
from transformers.models.siglip.modeling_siglip import SiglipVisionEmbeddings, SiglipEncoder, SiglipMultiheadAttentionPoolingHead
dtype = next(iter(self.ipadapter_image_encoder.parameters())).dtype
enable_vram_management(
self.vae_decoder,
self.ipadapter_image_encoder,
module_map = {
SiglipVisionEmbeddings: AutoWrappedModule,
SiglipEncoder: AutoWrappedModule,
SiglipMultiheadAttentionPoolingHead: AutoWrappedModule,
torch.nn.MultiheadAttention: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -277,14 +338,25 @@ class FluxImagePipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.vae_encoder is not None:
dtype = next(iter(self.vae_encoder.parameters())).dtype
if self.qwenvl is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionPatchEmbed, Qwen2_5_VLVisionBlock, Qwen2_5_VLPatchMerger,
Qwen2_5_VLDecoderLayer, Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
)
dtype = next(iter(self.qwenvl.parameters())).dtype
enable_vram_management(
self.vae_encoder,
self.qwenvl,
module_map = {
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
Qwen2_5_VLVisionBlock: AutoWrappedModule,
Qwen2_5_VLPatchMerger: AutoWrappedModule,
Qwen2_5_VLDecoderLayer: AutoWrappedModule,
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
Qwen2RMSNorm: AutoWrappedModule,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -303,16 +375,12 @@ class FluxImagePipeline(BasePipeline):
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
local_model_path: str = "./models",
skip_download: bool = False,
redirect_common_files: bool = True,
use_usp=False,
nexus_gen_processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
):
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary(local_model_path, skip_download=skip_download)
model_config.download_if_necessary()
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
@@ -335,13 +403,29 @@ class FluxImagePipeline(BasePipeline):
if pipe.image_proj_model is not None:
pipe.infinityou_processor = InfinitYou(device=device)
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder")
pipe.nexus_gen = model_manager.fetch_model("nexus_gen_llm")
pipe.nexus_gen_generation_adapter = model_manager.fetch_model("nexus_gen_generation_adapter")
pipe.nexus_gen_editing_adapter = model_manager.fetch_model("nexus_gen_editing_adapter")
if nexus_gen_processor_config is not None and pipe.nexus_gen is not None:
nexus_gen_processor_config.download_if_necessary()
pipe.nexus_gen.load_processor(nexus_gen_processor_config.path)
# ControlNet
controlnets = []
for model_name, model in zip(model_manager.model_name, model_manager.model):
if model_name == "flux_controlnet":
controlnets.append(model)
pipe.controlnet = MultiControlNet(controlnets)
if len(controlnets) > 0:
pipe.controlnet = MultiControlNet(controlnets)
# Value Controller
value_controllers = []
for model_name, model in zip(model_manager.model_name, model_manager.model):
if model_name == "flux_value_controller":
value_controllers.append(model)
if len(value_controllers) > 0:
pipe.value_controller = MultiValueEncoder(value_controllers)
return pipe
@@ -393,8 +477,15 @@ class FluxImagePipeline(BasePipeline):
flex_control_image: Image.Image = None,
flex_control_strength: float = 0.5,
flex_control_stop: float = 0.5,
# Value Controller
value_controller_inputs: Union[list[float], float] = None,
# Step1x
step1x_reference_image: Image.Image = None,
# NexusGen
nexus_gen_reference_image: Image.Image = None,
# LoRA Encoder
lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,
lora_encoder_scale: float = 1.0,
# TeaCache
tea_cache_l1_thresh: float = None,
# Tile
@@ -426,7 +517,10 @@ class FluxImagePipeline(BasePipeline):
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
"infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance,
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
"value_controller_inputs": value_controller_inputs,
"step1x_reference_image": step1x_reference_image,
"nexus_gen_reference_image": nexus_gen_reference_image,
"lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale,
"tea_cache_l1_thresh": tea_cache_l1_thresh,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"progress_bar_cmd": progress_bar_cmd,
@@ -677,15 +771,70 @@ class FluxImageUnit_EntityControl(PipelineUnit):
if eligen_entity_prompts is None or eligen_entity_masks is None:
return inputs_shared, inputs_posi, inputs_nega
pipe.load_models_to_device(self.onload_model_names)
eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
inputs_shared["t5_sequence_length"], inputs_shared["eligen_enable_on_negative"], inputs_shared["cfg_scale"])
inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"])
inputs_posi.update(eligen_kwargs_posi)
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
inputs_nega.update(eligen_kwargs_nega)
return inputs_shared, inputs_posi, inputs_nega
class FluxImageUnit_NexusGen(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"),
)
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
if pipe.nexus_gen is None:
return inputs_shared, inputs_posi, inputs_nega
pipe.load_models_to_device(self.onload_model_names)
if inputs_shared.get("nexus_gen_reference_image", None) is None:
assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set."
embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0)
inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed)
inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype)
else:
assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set."
embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"])
embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long)
ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long)
inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid)
inputs_posi["text_ids"] = self.get_editing_text_ids(
inputs_shared["latents"],
embeds_grid[0][1].item(), embeds_grid[0][2].item(),
ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(),
)
return inputs_shared, inputs_posi, inputs_nega
def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width):
# prepare text ids for target and reference embeddings
batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width
embed_ids = torch.zeros(height // 2, width // 2, 3)
scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype)
batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width
ref_embed_ids = torch.zeros(height // 2, width // 2, 3)
scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0
ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype)
text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1)
return text_ids
class FluxImageUnit_Step1x(PipelineUnit):
def __init__(self):
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder"))
@@ -704,7 +853,8 @@ class FluxImageUnit_Step1x(PipelineUnit):
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
image = pipe.vae_encoder(image)
inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image})
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
if inputs_shared.get("cfg_scale", 1) != 1:
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
return inputs_shared, inputs_posi, inputs_nega
@@ -723,10 +873,12 @@ class FluxImageUnit_Flex(PipelineUnit):
super().__init__(
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae_encoder",)
)
)
def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
if pipe.dit.input_dim == 196:
if flex_control_stop is None:
flex_control_stop = 1
pipe.load_models_to_device(self.onload_model_names)
if flex_inpaint_image is None:
flex_inpaint_image = torch.zeros_like(latents)
@@ -756,18 +908,53 @@ class FluxImageUnit_Flex(PipelineUnit):
class FluxImageUnit_InfiniteYou(PipelineUnit):
def __init__(self):
super().__init__(input_params=("infinityou_id_image", "infinityou_guidance"))
super().__init__(
input_params=("infinityou_id_image", "infinityou_guidance"),
onload_model_names=("infinityou_processor",)
)
def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):
pipe.load_models_to_device("infinityou_processor")
if infinityou_id_image is not None:
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance)
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)
else:
return {}
class InfinitYou:
class FluxImageUnit_ValueControl(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params=("value_controller_inputs",),
onload_model_names=("value_controller",)
)
def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
return prompt_emb, text_ids
def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
if value_controller_inputs is None:
return {}
if not isinstance(value_controller_inputs, list):
value_controller_inputs = [value_controller_inputs]
value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
pipe.load_models_to_device(["value_controller"])
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
value_emb = value_emb.unsqueeze(0)
prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
return {"prompt_emb": prompt_emb, "text_ids": text_ids}
class InfinitYou(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
self.device = device
@@ -779,7 +966,7 @@ class InfinitYou:
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
self.arcface_model = init_recognition_model('arcface', device=self.device)
self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype)
def _detect_face(self, id_image_cv2):
face_info = self.app_640.get(id_image_cv2)
@@ -791,16 +978,16 @@ class InfinitYou:
face_info = self.app_160.get(id_image_cv2)
return face_info
def extract_arcface_bgr_embedding(self, in_image, landmark):
def extract_arcface_bgr_embedding(self, in_image, landmark, device):
from insightface.utils import face_align
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.contiguous().to(self.device)
arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype)
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
return face_emb
def prepare_infinite_you(self, model, id_image, infinityou_guidance):
def prepare_infinite_you(self, model, id_image, infinityou_guidance, device):
import cv2
if id_image is None:
return {'id_emb': None}
@@ -809,12 +996,72 @@ class InfinitYou:
if len(face_info) == 0:
raise ValueError('No face detected in the input ID image')
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device)
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype)
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}
class FluxImageUnit_LoRAEncode(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("lora_encoder",)
)
def parse_lora_encoder_inputs(self, lora_encoder_inputs):
if not isinstance(lora_encoder_inputs, list):
lora_encoder_inputs = [lora_encoder_inputs]
lora_configs = []
for lora_encoder_input in lora_encoder_inputs:
if isinstance(lora_encoder_input, str):
lora_encoder_input = ModelConfig(path=lora_encoder_input)
lora_encoder_input.download_if_necessary()
lora_configs.append(lora_encoder_input)
return lora_configs
def load_lora(self, lora_config, dtype, device):
loader = FluxLoRALoader(torch_dtype=dtype, device=device)
lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device)
lora = loader.convert_state_dict(lora)
return lora
def lora_embedding(self, pipe, lora_encoder_inputs):
lora_emb = []
for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs):
lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device)
lora_emb.append(pipe.lora_encoder(lora))
lora_emb = torch.concat(lora_emb, dim=1)
return lora_emb
def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb):
prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)
extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
return prompt_emb, text_ids
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("lora_encoder_inputs", None) is None:
return inputs_shared, inputs_posi, inputs_nega
# Encode
pipe.load_models_to_device(["lora_encoder"])
lora_encoder_inputs = inputs_shared["lora_encoder_inputs"]
lora_emb = self.lora_embedding(pipe, lora_encoder_inputs)
# Scale
lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None)
if lora_encoder_scale is not None:
lora_emb = lora_emb * lora_encoder_scale
# Add to prompt embedding
inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding(
inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb)
return inputs_shared, inputs_posi, inputs_nega
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps
@@ -984,6 +1231,7 @@ def model_fn_flux_image(
hidden_states = dit.x_embedder(hidden_states)
# EliGen
if entity_prompt_emb is not None and entity_masks is not None:
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:

View File

@@ -12,6 +12,7 @@ from tqdm import tqdm
from typing import Optional
from typing_extensions import Literal
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
from ..models import ModelManager, load_state_dict
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
@@ -26,194 +27,6 @@ from ..lora import GeneralLoRALoader
class BasePipeline(torch.nn.Module):
def __init__(
self,
device="cuda", torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
super().__init__()
# The device and torch_dtype is used for the storage of intermediate variables, not models.
self.device = device
self.torch_dtype = torch_dtype
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.vram_management_enabled = False
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None):
# Shape check
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if num_frames is None:
return height, width
else:
if num_frames % self.time_division_factor != self.time_division_remainder:
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
return height, width, num_frames
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
# Transform a PIL.Image to torch.Tensor
image = torch.Tensor(np.array(image, dtype=np.float32))
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
image = image * ((max_value - min_value) / 255) + min_value
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
return image
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a list of PIL.Image to torch.Tensor
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
video = torch.stack(video, dim=pattern.index("T") // 2)
return video
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to PIL.Image
if pattern != "H W C":
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
image = image.to(device="cpu", dtype=torch.uint8)
image = Image.fromarray(image.numpy())
return image
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to list of PIL.Image
if pattern != "T H W C":
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
return video
def load_models_to_device(self, model_names=[]):
if self.vram_management_enabled:
# offload models
for name, model in self.named_children():
if name not in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
torch.cuda.empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
# Initialize Gaussian noise
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
return noise
def enable_cpu_offload(self):
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
self.vram_management_enabled = True
def get_vram(self):
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
def freeze_except(self, model_names):
for name, model in self.named_children():
if name in model_names:
model.train()
model.requires_grad_(True)
else:
model.eval()
model.requires_grad_(False)
@dataclass
class ModelConfig:
path: Union[str, list[str]] = None
model_id: str = None
origin_file_pattern: Union[str, list[str]] = None
download_resource: str = "ModelScope"
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
if self.path is None:
# Check model_id and origin_file_pattern
if self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
# Skip if not in rank 0
if use_usp:
import torch.distributed as dist
skip_download = dist.get_rank() != 0
# Check whether the origin path is a folder
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.origin_file_pattern = ""
allow_file_pattern = None
is_folder = True
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
allow_file_pattern = self.origin_file_pattern + "*"
is_folder = True
else:
allow_file_pattern = self.origin_file_pattern
is_folder = False
# Download
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(local_model_path, self.model_id),
allow_file_pattern=allow_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
# Let rank 1, 2, ... wait for rank 0
if use_usp:
import torch.distributed as dist
dist.barrier(device_ids=[dist.get_rank()])
# Return downloaded files
if is_folder:
self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern)
else:
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
class WanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
@@ -226,17 +39,21 @@ class WanVideoPipeline(BasePipeline):
self.text_encoder: WanTextEncoder = None
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.dit2: WanModel = None
self.vae: WanVideoVAE = None
self.motion_controller: WanMotionControllerModel = None
self.vace: VaceWanModel = None
self.in_iteration_models = ("dit", "motion_controller", "vace")
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
self.unit_runner = PipelineUnitRunner()
self.units = [
WanVideoUnit_ShapeChecker(),
WanVideoUnit_NoiseInitializer(),
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
WanVideoUnit_ImageEmbedderVAE(),
WanVideoUnit_ImageEmbedderCLIP(),
WanVideoUnit_ImageEmbedderFused(),
WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(),
WanVideoUnit_FunCameraControl(),
@@ -256,7 +73,9 @@ class WanVideoPipeline(BasePipeline):
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
@@ -328,6 +147,37 @@ class WanVideoPipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.dit2 is not None:
dtype = next(iter(self.dit2.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: WanAutoCastLayerNorm,
RMSNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
@@ -426,6 +276,10 @@ class WanVideoPipeline(BasePipeline):
for block in self.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
if self.dit2 is not None:
for block in self.dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@@ -436,8 +290,6 @@ class WanVideoPipeline(BasePipeline):
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
local_model_path: str = "./models",
skip_download: bool = False,
redirect_common_files: bool = True,
use_usp=False,
):
@@ -462,7 +314,7 @@ class WanVideoPipeline(BasePipeline):
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp)
model_config.download_if_necessary(use_usp=use_usp)
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
@@ -471,14 +323,23 @@ class WanVideoPipeline(BasePipeline):
# Load models
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
pipe.dit = model_manager.fetch_model("wan_video_dit")
dit = model_manager.fetch_model("wan_video_dit", index=2)
if isinstance(dit, list):
pipe.dit, pipe.dit2 = dit
else:
pipe.dit = dit
pipe.vae = model_manager.fetch_model("wan_video_vae")
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
pipe.vace = model_manager.fetch_model("wan_video_vace")
# Size division factor
if pipe.vae is not None:
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
pipe.width_division_factor = pipe.vae.upsampling_factor * 2
# Initialize tokenizer
tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
tokenizer_config.download_if_necessary(use_usp=use_usp)
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
@@ -522,6 +383,8 @@ class WanVideoPipeline(BasePipeline):
# Classifier-free guidance
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Boundary
switch_DiT_boundary: Optional[float] = 0.875,
# Scheduler
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
@@ -574,8 +437,14 @@ class WanVideoPipeline(BasePipeline):
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# Switch DiT if necessary
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
self.load_models_to_device(self.in_iteration_models_2)
models["dit"] = self.dit2
# Timestep
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
if cfg_scale != 1.0:
@@ -589,6 +458,8 @@ class WanVideoPipeline(BasePipeline):
# Scheduler
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
if "first_frame_latents" in inputs_shared:
inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
# VACE (TODO: remove it)
if vace_reference_image is not None:
@@ -604,63 +475,6 @@ class WanVideoPipeline(BasePipeline):
class PipelineUnit:
def __init__(
self,
seperate_cfg: bool = False,
take_over: bool = False,
input_params: tuple[str] = None,
input_params_posi: dict[str, str] = None,
input_params_nega: dict[str, str] = None,
onload_model_names: tuple[str] = None
):
self.seperate_cfg = seperate_cfg
self.take_over = take_over
self.input_params = input_params
self.input_params_posi = input_params_posi
self.input_params_nega = input_params_nega
self.onload_model_names = onload_model_names
def process(self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs) -> dict:
raise NotImplementedError("`process` is not implemented.")
class PipelineUnitRunner:
def __init__(self):
pass
def __call__(self, unit: PipelineUnit, pipe: WanVideoPipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
if unit.take_over:
# Let the pipeline unit take over this function.
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
elif unit.seperate_cfg:
# Positive side
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:
inputs_nega.update(processor_outputs)
else:
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_shared.update(processor_outputs)
return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width", "num_frames"))
@@ -679,7 +493,8 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
length = (num_frames - 1) // 4 + 1
if vace_reference_image is not None:
length += 1
noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device)
shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
if vace_reference_image is not None:
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
return {"noise": noise}
@@ -728,6 +543,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit):
class WanVideoUnit_ImageEmbedder(PipelineUnit):
"""
Deprecated
"""
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
@@ -735,7 +553,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None:
if input_image is None or pipe.image_encoder is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
@@ -763,13 +581,90 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context, "y": y}
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "height", "width"),
onload_model_names=("image_encoder",)
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
clip_context = pipe.image_encoder.encode_image([image])
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
if pipe.dit.has_image_pos_emb:
clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context}
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.dit.require_vae_embedding:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"y": y}
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
"""
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
"""
def __init__(self):
super().__init__(
input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)
z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents[:, :, 0: 1] = z
return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z}
class WanVideoUnit_FunControl(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
onload_model_names=("vae")
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
@@ -793,7 +688,7 @@ class WanVideoUnit_FunReference(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("reference_image", "height", "width", "reference_image"),
onload_model_names=("vae")
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
@@ -812,7 +707,8 @@ class WanVideoUnit_FunReference(PipelineUnit):
class WanVideoUnit_FunCameraControl(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image")
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
@@ -835,6 +731,7 @@ class WanVideoUnit_FunCameraControl(PipelineUnit):
input_image = input_image.resize((width, height))
input_latents = pipe.preprocess_video([input_image])
pipe.load_models_to_device(self.onload_model_names)
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
y = torch.zeros_like(latents).to(pipe.device)
y[:, :, :1] = input_latents
@@ -1014,10 +911,14 @@ class TemporalTiler_BCTHW:
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if border_width == 0:
return x
shift = 0.5
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
x[:border_width] = (torch.arange(border_width) + shift) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
@@ -1078,6 +979,7 @@ def model_fn_wan_video(
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
control_camera_latents_input = None,
fuse_vae_embedding_in_latents: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1111,9 +1013,20 @@ def model_fn_wan_video(
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
# Timestep
if dit.seperated_timestep and fuse_vae_embedding_in_latents:
timestep = torch.concat([
torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
]).flatten()
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
else:
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
# Motion Controller
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
@@ -1124,15 +1037,16 @@ def model_fn_wan_video(
x = torch.concat([x] * context.shape[0], dim=0)
if timestep.shape[0] != context.shape[0]:
timestep = torch.concat([timestep] * context.shape[0], dim=0)
if dit.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
# Image Embedding
if y is not None and dit.require_vae_embedding:
x = torch.cat([x, y], dim=1)
if clip_feature is not None and dit.require_clip_embedding:
clip_embdding = dit.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
# Add camera control
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
# Reference image
if reference_latents is not None: