mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
1115 lines
50 KiB
Python
1115 lines
50 KiB
Python
import torch, warnings, glob, os
|
|
import numpy as np
|
|
from PIL import Image
|
|
from einops import repeat, reduce
|
|
from typing import Optional, Union
|
|
from dataclasses import dataclass
|
|
from modelscope import snapshot_download
|
|
from einops import rearrange
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
from typing import Optional
|
|
|
|
from ..models import ModelManager, load_state_dict
|
|
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
|
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
|
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
|
from ..models.wan_video_image_encoder import WanImageEncoder
|
|
from ..models.wan_video_vace import VaceWanModel
|
|
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
|
from ..schedulers.flow_match import FlowMatchScheduler
|
|
from ..prompters import WanPrompter
|
|
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
|
|
from ..lora import GeneralLoRALoader
|
|
|
|
|
|
class BasePipeline(torch.nn.Module):
    """Common plumbing shared by pipelines.

    Provides device/dtype bookkeeping for intermediate tensors, shape
    validation, PIL <-> tensor conversion helpers, seeded noise sampling and
    on/offloading of the child models registered on the pipeline.
    """

    def __init__(
        self,
        device="cuda", torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        """Store the working device/dtype and the shape-check parameters.

        Args:
            device: Device used for intermediate variables (not for models).
            torch_dtype: Dtype used for intermediate variables (not for models).
            height_division_factor: Heights are rounded up to a multiple of this.
            width_division_factor: Widths are rounded up to a multiple of this.
            time_division_factor: Frame counts must satisfy
                ``num_frames % time_division_factor == time_division_remainder``.
                ``None`` disables the frame-count check.
            time_division_remainder: See ``time_division_factor``. ``None`` is
                treated as ``0``.
        """
        super().__init__()
        # The device and torch_dtype are used for the storage of intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        # The following parameters are used for the shape check.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.vram_management_enabled = False

    def to(self, *args, **kwargs):
        """Move the pipeline like ``torch.nn.Module.to`` and remember the target.

        The parsed device and dtype (when given) are also recorded in
        ``self.device`` / ``self.torch_dtype``. Returns ``self``.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        """Round a requested shape up to the nearest supported shape.

        Args:
            height: Requested height.
            width: Requested width.
            num_frames: Optional requested frame count.

        Returns:
            ``(height, width)`` when ``num_frames`` is ``None``, otherwise
            ``(height, width, num_frames)``. A message is printed for every
            value that had to be adjusted.
        """
        # Shape check
        height_factor = self.height_division_factor
        if height % height_factor != 0:
            height = (height + height_factor - 1) // height_factor * height_factor
            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        width_factor = self.width_division_factor
        if width % width_factor != 0:
            width = (width + width_factor - 1) // width_factor * width_factor
            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width

        time_factor = self.time_division_factor
        if time_factor is not None:
            remainder = self.time_division_remainder or 0
            if num_frames % time_factor != remainder:
                # Smallest value >= num_frames that is congruent to `remainder`.
                # Subtracting the remainder before the ceil-division keeps the
                # result minimal for every remainder, not only for 0 and 1.
                num_frames = (num_frames - remainder + time_factor - 1) // time_factor * time_factor + remainder
                print(f"num_frames % {time_factor} != {remainder}. We round it up to {num_frames}.")
        return height, width, num_frames

    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        """Transform a PIL.Image into a torch.Tensor.

        Pixel values in ``[0, 255]`` are mapped linearly to
        ``[min_value, max_value]`` and the ``H W C`` array is rearranged to
        ``pattern``. ``torch_dtype`` / ``device`` default to the pipeline's.
        """
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        # A batch axis of size 1 is only introduced when the pattern asks for it.
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image

    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        """Transform a list of PIL.Image into a single torch.Tensor.

        Every frame goes through ``preprocess_image`` (with its default
        pattern) and the frames are stacked along the axis where ``T`` sits
        in ``pattern``.
        """
        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        # Axis names are single letters separated by spaces, so the character
        # offset of "T" divided by two is its axis index.
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        """Transform a torch.Tensor into a PIL.Image.

        Axes of ``pattern`` other than ``H W C`` are averaged away, values are
        mapped from ``[min_value, max_value]`` to ``[0, 255]`` and clipped.
        """
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        """Transform a torch.Tensor into a list of PIL.Image, one per frame."""
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

    def load_models_to_device(self, model_names=None):
        """Keep only the named child models on the compute device.

        Does nothing unless ``self.vram_management_enabled`` is set. Children
        not listed are offloaded (via their modules' ``offload()`` hooks when
        the child has VRAM management enabled, otherwise with ``.cpu()``);
        listed children are onloaded symmetrically.

        Args:
            model_names: Iterable of child names to keep loaded. ``None``
                means "none". A bare string is treated as a single name so
                that ``("vae")`` (a string, not a tuple) does not degrade
                into substring matching.
        """
        if model_names is None:
            model_names = ()
        elif isinstance(model_names, str):
            model_names = (model_names,)
        if not self.vram_management_enabled:
            return

        def is_managed(model):
            return hasattr(model, "vram_management_enabled") and model.vram_management_enabled

        # offload models
        for name, model in self.named_children():
            if name in model_names:
                continue
            if is_managed(model):
                for module in model.modules():
                    if hasattr(module, "offload"):
                        module.offload()
            else:
                model.cpu()
        torch.cuda.empty_cache()
        # onload models
        for name, model in self.named_children():
            if name not in model_names:
                continue
            if is_managed(model):
                for module in model.modules():
                    if hasattr(module, "onload"):
                        module.onload()
            else:
                model.to(self.device)

    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        """Sample Gaussian noise of ``shape``.

        The noise is drawn on ``rand_device`` in ``rand_torch_dtype`` (seeded
        when ``seed`` is not ``None``) and then moved to ``device`` /
        ``torch_dtype``, which default to the pipeline's.
        """
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise

    def enable_cpu_offload(self):
        """Deprecated switch that only sets ``vram_management_enabled``."""
        warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
        self.vram_management_enabled = True

    def get_free_vram(self):
        """Return total minus used memory of ``self.device`` in GiB."""
        total_memory = torch.cuda.get_device_properties(self.device).total_memory
        allocated_memory = torch.cuda.device_memory_used(self.device)
        return (total_memory - allocated_memory) / (1024 ** 3)

    def freeze_except(self, model_names):
        """Put the named children in train mode with gradients; freeze the rest."""
        for name, model in self.named_children():
            if name in model_names:
                model.train()
                model.requires_grad_(True)
            else:
                model.eval()
                model.requires_grad_(False)
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Where to find (or download) the files of one model.

    Either ``path`` is given directly, or ``model_id`` together with
    ``origin_file_pattern`` identifies files inside a remote repository that
    are mirrored under ``<local_model_path>/<model_id>``.
    """

    # Local file path(s); filled in by `download_if_necessary` when left as None.
    path: Union[str, list[str]] = None
    # Remote repository id, also used as the sub-directory of the local model folder.
    model_id: str = None
    # Glob pattern(s), relative to the repository root, selecting the files to use.
    origin_file_pattern: Union[str, list[str]] = None
    download_resource: str = "ModelScope"
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None

    def download_if_necessary(self, local_model_path="./models", skip_download=False):
        """Resolve ``self.path``, downloading the matching files first if allowed.

        Nothing happens when ``path`` is already set. Otherwise the files
        matching ``origin_file_pattern`` are fetched with
        ``snapshot_download`` (unless ``skip_download``), skipping files that
        already exist locally, and ``path`` is set to the list of local
        matches, or to the single match itself when there is exactly one.

        Args:
            local_model_path: Root folder that holds downloaded repositories.
            skip_download: Only glob the local folder; never contact the hub.

        Raises:
            ValueError: If neither ``path`` nor both ``model_id`` and
                ``origin_file_pattern`` are provided.
        """
        if self.path is not None:
            return
        if self.model_id is None or self.origin_file_pattern is None:
            raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")

        model_dir = os.path.join(local_model_path, self.model_id)
        # The field is annotated as `str | list[str]`; glob only understands a
        # single pattern, so normalise to a list and glob each entry.
        if isinstance(self.origin_file_pattern, str):
            patterns = [self.origin_file_pattern]
        else:
            patterns = list(self.origin_file_pattern)

        if not skip_download:
            downloaded_files = []
            for pattern in patterns:
                downloaded_files.extend(glob.glob(pattern, root_dir=model_dir))
            snapshot_download(
                self.model_id,
                local_dir=model_dir,
                allow_file_pattern=self.origin_file_pattern,
                ignore_file_pattern=downloaded_files,
                local_files_only=False
            )

        matched = []
        for pattern in patterns:
            matched.extend(glob.glob(os.path.join(model_dir, pattern)))
        # Overlapping patterns may match the same file twice; keep first occurrence.
        matched = list(dict.fromkeys(matched))
        self.path = matched[0] if len(matched) == 1 else matched
|
|
|
|
|
|
class WanVideoPipeline(BasePipeline):
    """Wan video generation pipeline.

    Holds the optional sub-models (text encoder, image encoder, DiT, VAE,
    motion controller, VACE), a flow-matching scheduler and an ordered list
    of preprocessing units that turn user arguments into model inputs.
    """

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
        """Create an empty pipeline; models are attached afterwards.

        Args:
            device: Device for intermediate tensors.
            torch_dtype: Dtype for intermediate tensors.
            tokenizer_path: Forwarded to ``WanPrompter``.
        """
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
        )
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        # Models that must stay loaded during the denoising loop.
        self.in_iteration_models = ("dit", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()
        # Units run in this order; later units may read outputs of earlier ones.
        self.units = [
            WanVideoUnit_ShapeChecker(),
            WanVideoUnit_NoiseInitializer(),
            WanVideoUnit_InputVideoEmbedder(),
            WanVideoUnit_PromptEmbedder(),
            WanVideoUnit_ImageEmbedder(),
            WanVideoUnit_FunCamera(),
            WanVideoUnit_FunControl(),
            WanVideoUnit_FunReference(),
            WanVideoUnit_SpeedControl(),
            WanVideoUnit_VACE(),
            WanVideoUnit_TeaCache(),
            WanVideoUnit_CfgMerger(),
        ]
        self.model_fn = model_fn_wan_video

    def load_lora(self, module, path, alpha=1):
        """Load the LoRA state dict at ``path`` into ``module`` with weight ``alpha``."""
        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
        loader.load(module, lora, alpha=alpha)

    def training_loss(self, **inputs):
        """Compute the weighted MSE training loss at one random timestep.

        ``inputs`` must contain ``input_latents`` and ``noise`` plus whatever
        ``self.model_fn`` needs; the noised latents are written back into it
        under ``latents`` before the model is called.
        """
        timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
        timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)

        inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
        training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)

        noise_pred = self.model_fn(**inputs, timestep=timestep)

        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.scheduler.training_weight(timestep)
        return loss

    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
        """Wrap the layers of every attached model for on-demand on/offloading.

        Args:
            num_persistent_param_in_dit: When given, passed as ``max_num_param``
                for the DiT and the VRAM limit is disabled.
            vram_limit: VRAM budget in GiB; defaults to the currently free
                VRAM. Ignored when ``num_persistent_param_in_dit`` is given.
            vram_buffer: GiB subtracted from the budget as head-room.
        """
        self.vram_management_enabled = True
        if num_persistent_param_in_dit is not None:
            vram_limit = None
        else:
            if vram_limit is None:
                vram_limit = self.get_free_vram()
            vram_limit = vram_limit - vram_buffer

        def wrap_config(param_dtype, onload_device, computation_dtype):
            # Every wrapped module offloads to CPU in its stored dtype and
            # computes on the pipeline device.
            return dict(
                offload_dtype=param_dtype,
                offload_device="cpu",
                onload_dtype=param_dtype,
                onload_device=onload_device,
                computation_dtype=computation_dtype,
                computation_device=self.device,
            )

        # With a VRAM limit, weights are kept on CPU while "onloaded";
        # without one they are kept on the pipeline device.
        resident_device = "cpu" if vram_limit is not None else self.device

        if self.text_encoder is not None:
            dtype = next(iter(self.text_encoder.parameters())).dtype
            enable_vram_management(
                self.text_encoder,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Embedding: AutoWrappedModule,
                    T5RelativeEmbedding: AutoWrappedModule,
                    T5LayerNorm: AutoWrappedModule,
                },
                module_config=wrap_config(dtype, "cpu", self.torch_dtype),
                vram_limit=vram_limit,
            )
        if self.dit is not None:
            dtype = next(iter(self.dit.parameters())).dtype
            enable_vram_management(
                self.dit,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.LayerNorm: WanAutoCastLayerNorm,
                    RMSNorm: AutoWrappedModule,
                    torch.nn.Conv2d: AutoWrappedModule,
                },
                module_config=wrap_config(dtype, resident_device, self.torch_dtype),
                max_num_param=num_persistent_param_in_dit,
                overflow_module_config=wrap_config(dtype, "cpu", self.torch_dtype),
                vram_limit=vram_limit,
            )
        if self.vae is not None:
            dtype = next(iter(self.vae.parameters())).dtype
            enable_vram_management(
                self.vae,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    RMS_norm: AutoWrappedModule,
                    CausalConv3d: AutoWrappedModule,
                    Upsample: AutoWrappedModule,
                    torch.nn.SiLU: AutoWrappedModule,
                    torch.nn.Dropout: AutoWrappedModule,
                },
                module_config=wrap_config(dtype, self.device, self.torch_dtype),
            )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                # The image encoder computes in its own stored dtype.
                module_config=wrap_config(dtype, "cpu", dtype),
            )
        if self.motion_controller is not None:
            dtype = next(iter(self.motion_controller.parameters())).dtype
            enable_vram_management(
                self.motion_controller,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                },
                # The motion controller computes in its own stored dtype.
                module_config=wrap_config(dtype, "cpu", dtype),
            )
        if self.vace is not None:
            # Read the dtype from the VACE model itself. Previously this branch
            # reused whatever `dtype` an earlier branch happened to leave
            # behind (or raised NameError when no earlier model was present).
            dtype = next(iter(self.vace.parameters())).dtype
            enable_vram_management(
                self.vace,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                    RMSNorm: AutoWrappedModule,
                },
                module_config=wrap_config(dtype, resident_device, self.torch_dtype),
                vram_limit=vram_limit,
            )

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: Optional[list[ModelConfig]] = None,
        tokenizer_config: Optional[ModelConfig] = None,
        local_model_path: str = "./models",
        skip_download: bool = False,
        redirect_common_files: bool = True,
    ):
        """Download (if needed), load and assemble a ready-to-use pipeline.

        Args:
            torch_dtype: Default dtype for models and intermediate tensors.
            device: Default device for models and intermediate tensors.
            model_configs: Models to load. ``None`` means no models.
            tokenizer_config: Tokenizer files. ``None`` selects
                ``ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")``.
            local_model_path: Root folder for downloaded repositories.
            skip_download: Use local files only.
            redirect_common_files: Point well-known shared files at one
                canonical repository to avoid downloading them repeatedly.

        Returns:
            The assembled ``WanVideoPipeline``.
        """
        # Fresh objects per call: `download_if_necessary` caches the resolved
        # path on the config, so a shared default instance would leak the
        # first call's `local_model_path` into every later call.
        if model_configs is None:
            model_configs = []
        if tokenizer_config is None:
            tokenizer_config = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")

        # Redirect model path
        if redirect_common_files:
            redirect_dict = {
                "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
            }
            for model_config in model_configs:
                if model_config.origin_file_pattern is None or model_config.model_id is None:
                    continue
                # Only a single string pattern can name one of the shared
                # files (a list is not hashable and cannot be looked up).
                if not isinstance(model_config.origin_file_pattern, str):
                    continue
                if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
                    print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
                    model_config.model_id = redirect_dict[model_config.origin_file_pattern]

        # Download and load models
        model_manager = ModelManager()
        for model_config in model_configs:
            model_config.download_if_necessary(local_model_path, skip_download=skip_download)
            model_manager.load_model(
                model_config.path,
                device=model_config.offload_device or device,
                torch_dtype=model_config.offload_dtype or torch_dtype
            )

        # Initialize pipeline
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
        pipe.dit = model_manager.fetch_model("wan_video_dit")
        pipe.vae = model_manager.fetch_model("wan_video_vae")
        pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
        pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
        pipe.vace = model_manager.fetch_model("wan_video_vace")

        # Initialize tokenizer
        tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
        pipe.prompter.fetch_models(pipe.text_encoder)
        pipe.prompter.fetch_tokenizer(tokenizer_config.path)
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Video-to-video
        input_video: Optional[list[Image.Image]] = None,
        denoising_strength: Optional[float] = 1.0,
        # ControlNet
        control_video: Optional[list[Image.Image]] = None,
        reference_image: Optional[Image.Image] = None,
        # VACE
        vace_video: Optional[list[Image.Image]] = None,
        vace_video_mask: Optional[list[Image.Image]] = None,
        vace_reference_image: Optional[Image.Image] = None,
        vace_scale: Optional[float] = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 480,
        width: Optional[int] = 832,
        num_frames=81,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        cfg_merge: Optional[bool] = False,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # Speed control
        motion_bucket_id: Optional[int] = None,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # Sliding window
        sliding_window_size: Optional[int] = None,
        sliding_window_stride: Optional[int] = None,
        # Teacache
        tea_cache_l1_thresh: Optional[float] = None,
        tea_cache_model_id: Optional[str] = "",
        # progress_bar
        progress_bar_cmd=tqdm,
        # Camera control
        control_camera_video: Optional[torch.Tensor] = None
    ):
        """Generate a video and return it as a list of PIL images.

        The arguments are split into positive-only, negative-only and shared
        input dicts, run through every unit in ``self.units``, denoised with
        classifier-free guidance over the scheduler's timesteps and finally
        decoded by the VAE.
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Inputs
        inputs_posi = {
            "prompt": prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "input_video": input_video, "denoising_strength": denoising_strength,
            "control_video": control_video, "reference_image": reference_image,
            "control_camera_video": control_camera_video,
            "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
            # WanVideoUnit_VACE looks the mask up under "vace_mask"; without
            # this alias the user-supplied mask was silently ignored.
            "vace_mask": vace_video_mask,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames,
            "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
            "num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift,
            "motion_bucket_id": motion_bucket_id,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
            "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

            # Inference
            noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
            if cfg_scale != 1.0:
                if cfg_merge:
                    # Positive and negative predictions came back stacked on the batch axis.
                    noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
                else:
                    noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            # Scheduler
            inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])

        # VACE (TODO: remove it)
        # Drop the extra leading latent frame that was added for the reference image.
        if vace_reference_image is not None:
            inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]

        # Decode
        self.load_models_to_device(['vae'])
        video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        video = self.vae_output_to_video(video)
        self.load_models_to_device([])

        return video
|
|
|
|
|
|
|
|
class PipelineUnit:
    """Base class of a preprocessing step executed by ``PipelineUnitRunner``.

    The constructor arguments only describe how the runner should feed the
    unit; the actual work happens in ``process``, which subclasses override.
    """

    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None
    ):
        # How the runner dispatches this unit.
        self.seperate_cfg, self.take_over = seperate_cfg, take_over
        # Which entries of the input dicts are handed to `process`.
        self.input_params = input_params
        self.input_params_posi, self.input_params_nega = input_params_posi, input_params_nega
        # Names of the pipeline models the unit wants loaded.
        self.onload_model_names = onload_model_names

    def process(self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs) -> dict:
        """Subclasses return a dict of new or updated inputs; the base always raises."""
        raise NotImplementedError("`process` is not implemented.")
|
|
|
|
|
|
|
|
class PipelineUnitRunner:
    """Executes one ``PipelineUnit`` against the three input dictionaries."""

    def __init__(self):
        pass

    def __call__(
        self,
        unit: "PipelineUnit",
        pipe: "WanVideoPipeline",
        inputs_shared: dict,
        inputs_posi: dict,
        inputs_nega: dict,
    ) -> tuple[dict, dict, dict]:
        """Run ``unit`` and merge its outputs into the input dicts.

        Dispatch depends on the unit's flags:

        * ``take_over``: the unit receives and returns all three dicts itself.
        * ``seperate_cfg``: the unit runs once on values picked from
          ``inputs_posi`` and, when ``inputs_shared["cfg_scale"] != 1``, once
          more on values picked from ``inputs_nega``; otherwise the positive
          outputs are copied to the negative side.
        * otherwise: the unit runs once on values picked from
          ``inputs_shared``.

        Missing keys are passed to the unit as ``None``.

        Returns:
            ``(inputs_shared, inputs_posi, inputs_nega)``; unless the unit
            takes over, these are the same dict objects, updated in place.
        """
        if unit.take_over:
            # Let the pipeline unit take over this function.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive side
            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative side
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
                processor_outputs = unit.process(pipe, **processor_inputs)
            # Without guidance `processor_outputs` still holds the positive
            # result here, so the negative side simply mirrors it.
            inputs_nega.update(processor_outputs)
        else:
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega
|
|
|
|
|
|
|
|
class WanVideoUnit_ShapeChecker(PipelineUnit):
    """Replaces the requested height, width and frame count with checked values."""

    def __init__(self):
        shape_keys = ("height", "width", "num_frames")
        super().__init__(input_params=shape_keys)

    def process(self, pipe: WanVideoPipeline, height, width, num_frames):
        """Delegate to ``pipe.check_resize_height_width`` and return its three results."""
        checked_height, checked_width, checked_frames = pipe.check_resize_height_width(height, width, num_frames)
        return {
            "height": checked_height,
            "width": checked_width,
            "num_frames": checked_frames,
        }
|
|
|
|
|
|
|
|
class WanVideoUnit_NoiseInitializer(PipelineUnit):
    """Samples the initial latent noise for the requested video shape."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"),
        )

    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
        """Return ``{"noise": tensor}`` of shape ``(1, 16, frames, height//8, width//8)``.

        ``frames`` is ``(num_frames - 1) // 4 + 1``, plus one extra frame when a
        VACE reference image is given; in that case the last noise frame is
        rotated to the front.
        """
        has_reference = vace_reference_image is not None
        latent_frames = (num_frames - 1) // 4 + 1
        if has_reference:
            latent_frames += 1
        noise_shape = (1, 16, latent_frames, height//8, width//8)
        noise = pipe.generate_noise(noise_shape, seed=seed, rand_device=rand_device)
        if has_reference:
            last_frame, leading_frames = noise[:, :, -1:], noise[:, :, :-1]
            noise = torch.concat((last_frame, leading_frames), dim=2)
        return {"noise": noise}
|
|
|
|
|
|
|
|
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
    """Turns an optional input video into the starting latents."""

    def __init__(self):
        super().__init__(
            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
        """Return the latents to start denoising from.

        Without an input video the latents are the noise itself. With one,
        the video is VAE-encoded; when the scheduler is in training mode the
        encoded latents are returned next to the plain noise, otherwise noise
        is added to them at the scheduler's first timestep.
        """
        if input_video is None:
            return {"latents": noise}

        pipe.load_models_to_device(["vae"])
        frames = pipe.preprocess_video(input_video)
        encoded = pipe.vae.encode(frames, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        input_latents = encoded.to(dtype=pipe.torch_dtype, device=pipe.device)

        if not pipe.scheduler.training:
            first_timestep = pipe.scheduler.timesteps[0]
            noisy_latents = pipe.scheduler.add_noise(input_latents, noise, timestep=first_timestep)
            return {"latents": noisy_latents}
        return {"latents": noise, "input_latents": input_latents}
|
|
|
|
|
|
|
|
class WanVideoUnit_PromptEmbedder(PipelineUnit):
    """Encodes the positive and the negative prompt into ``context`` embeddings."""

    def __init__(self):
        # Runs once per CFG side: the positive side reads "prompt", the
        # negative side reads "negative_prompt".
        super().__init__(
            seperate_cfg=True,
            input_params_posi=dict(prompt="prompt", positive="positive"),
            input_params_nega=dict(prompt="negative_prompt", positive="positive"),
            onload_model_names=("text_encoder",),
        )

    def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:
        """Load the text encoder and return ``{"context": embedding}``."""
        pipe.load_models_to_device(self.onload_model_names)
        return {"context": pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device)}
|
|
|
|
|
|
|
|
class WanVideoUnit_ImageEmbedder(PipelineUnit):
    """Builds the image conditioning (``clip_feature`` and ``y``) for image-to-video."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "control_camera_video","latents"),
            onload_model_names=("image_encoder", "vae")
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride, control_camera_video,latents):
        """Encode the first (and optionally last) frame into conditioning tensors.

        Args:
            pipe: Pipeline providing the image encoder, VAE, device and dtype.
            input_image: First frame; when ``None`` the unit returns ``{}``.
            end_image: Optional last frame.
            num_frames, height, width: Shape of the video to generate.
            tiled, tile_size, tile_stride: Forwarded to ``pipe.vae.encode``.
            control_camera_video, latents: Accepted because they are listed in
                ``input_params``; not used here.

        Returns:
            ``{"clip_feature": ..., "y": ...}`` where ``y`` is the frame mask
            concatenated in front of the VAE latents of the conditioning
            video, with a leading batch axis.
        """
        if input_image is None:
            return {}

        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)

        clip_context = pipe.image_encoder.encode_image([image])
        # Mask is 1 for frames that are given and 0 for frames to generate.
        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
            if pipe.dit.has_image_pos_emb:
                clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
            msk[:, -1:] = 1
        else:
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

        # Repeat the first frame's mask four times, then fold every four
        # frames into one group of four channels.
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        # The mask was previously computed and then thrown away. Prepend it to
        # the latents on the channel axis so the model can tell conditioned
        # frames from frames it has to generate.
        # NOTE(review): assumes the VAE latents have the same frame count and
        # spatial size as the folded mask — confirm against WanVideoVAE.
        y = torch.concat([msk.to(dtype=y.dtype, device=y.device), y])
        y = y.unsqueeze(0)
        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context, "y": y}
|
|
|
|
|
|
|
|
class WanVideoUnit_FunControl(PipelineUnit):
    """Adds the VAE latents of a control video in front of the conditioning ``y``."""

    def __init__(self):
        super().__init__(
            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
            # Trailing comma matters: ("vae") is just the string "vae", which
            # turns the `name in model_names` checks into substring tests.
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
        """Encode the control video and merge it into the conditioning.

        Args:
            pipe: Pipeline providing the VAE, device and dtype.
            control_video: Control frames; when ``None`` the unit returns ``{}``.
            num_frames, height, width: Shape of the video to generate, used to
                build zero conditioning when none exists yet.
            tiled, tile_size, tile_stride: Forwarded to ``pipe.vae.encode``.
            clip_feature, y: Existing image conditioning, or ``None``.

        Returns:
            ``{"clip_feature": ..., "y": ...}`` with the control latents
            concatenated in front of ``y`` on the channel axis.
        """
        if control_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        control_video = pipe.preprocess_video(control_video)
        control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        if clip_feature is None or y is None:
            # No image conditioning available: fall back to all-zero tensors.
            clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
            y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
        else:
            # Keep only the last 16 channels of the existing conditioning.
            y = y[:, -16:]
        y = torch.concat([control_latents, y], dim=1)
        return {"clip_feature": clip_feature, "y": y}
|
|
|
|
|
|
|
|
class WanVideoUnit_FunReference(PipelineUnit):
    """Encodes a reference image into VAE latents and an image-encoder feature."""

    def __init__(self):
        super().__init__(
            # "reference_image" used to be listed twice; once is enough.
            input_params=("reference_image", "height", "width"),
            # A real tuple (not the string "vae"), and including the image
            # encoder because `process` calls it as well.
            onload_model_names=("vae", "image_encoder")
        )

    def process(self, pipe: WanVideoPipeline, reference_image, height, width):
        """Return reference latents and ``clip_feature`` for ``reference_image``.

        Args:
            pipe: Pipeline providing the VAE, image encoder and device.
            reference_image: Reference picture; when ``None`` the unit returns ``{}``.
            height, width: Size the picture is resized to before encoding.

        Returns:
            ``{"reference_latents": ..., "clip_feature": ...}``.
        """
        if reference_image is None:
            return {}
        # Both models are used below, so both must be on the device; loading
        # only the VAE would offload the image encoder under VRAM management.
        pipe.load_models_to_device(self.onload_model_names)
        reference_image = reference_image.resize((width, height))
        reference_latents = pipe.preprocess_video([reference_image])
        reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
        clip_feature = pipe.preprocess_image(reference_image)
        clip_feature = pipe.image_encoder.encode_image([clip_feature])
        return {"reference_latents": reference_latents, "clip_feature": clip_feature}
|
|
|
|
class WanVideoUnit_FunCamera(PipelineUnit):
    """Prepares camera-control latents and the matching first-frame conditioning."""

    def __init__(self):
        super().__init__(
            input_params=("control_camera_video", "cfg_merge", "num_frames", "height", "width", "input_image", "latents"),
            # Trailing comma matters: ("vae") is just the string "vae".
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, control_camera_video, cfg_merge, num_frames, height, width, input_image, latents):
        """Fold the camera-control tensor in time and encode the first frame.

        Args:
            pipe: Pipeline providing the VAE, device and dtype.
            control_camera_video: Camera-control tensor; when ``None`` the
                unit returns ``{}``.
                NOTE(review): the permute below implies a 4-D tensor with the
                frame axis first and the channel axis last — confirm with caller.
            cfg_merge: Accepted because it is listed in ``input_params``; unused.
            num_frames: Number of leading control frames to keep.
            height, width: Size the input image is resized to.
            input_image: First frame of the video; required.
            latents: Current latents; ``y`` is created with the same shape.

        Returns:
            ``{"control_camera_latents": ..., "control_camera_latents_input": ..., "y": ...}``.

        Raises:
            ValueError: If ``control_camera_video`` is given without ``input_image``.
        """
        if control_camera_video is None:
            return {}
        if input_image is None:
            raise ValueError("`input_image` is required when `control_camera_video` is provided.")

        control_camera_video = control_camera_video[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
        # Repeat the first frame four times, then fold every four frames into
        # the channel axis.
        control_camera_latents = torch.concat(
            [
                torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
                control_camera_video[:, :, 1:]
            ], dim=2
        ).transpose(1, 2)
        b, f, c, h, w = control_camera_latents.shape
        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
        control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)

        # The VAE is used below; make sure it is on the device instead of
        # relying on an earlier unit having loaded it.
        pipe.load_models_to_device(self.onload_model_names)
        input_image = input_image.resize((width, height))
        input_latents = pipe.preprocess_video([input_image])
        input_latents = pipe.vae.encode(input_latents, device=pipe.device)
        y = torch.zeros_like(latents).to(pipe.device)
        # Only the first latent frame carries the image; the rest stays zero.
        if latents.size()[2] != 1:
            y[:, :, :1] = input_latents
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)

        return {"control_camera_latents": control_camera_latents, "control_camera_latents_input": control_camera_latents_input, "y":y}
|
|
|
|
|
|
class WanVideoUnit_SpeedControl(PipelineUnit):
    """Converts the optional ``motion_bucket_id`` into a one-element tensor."""

    def __init__(self):
        super().__init__(input_params=("motion_bucket_id",))

    def process(self, pipe: WanVideoPipeline, motion_bucket_id):
        """Return ``{}`` when no id is given, else the id as a tensor on the pipeline device."""
        if motion_bucket_id is not None:
            bucket = torch.Tensor((motion_bucket_id,))
            bucket = bucket.to(dtype=pipe.torch_dtype, device=pipe.device)
            return {"motion_bucket_id": bucket}
        return {}
|
|
|
|
|
|
|
|
class WanVideoUnit_VACE(PipelineUnit):
    """Pipeline unit that assembles the ``vace_context`` conditioning tensor.

    The context is the channel-wise concatenation of VAE-encoded video
    latents (masked-out and masked-in parts) and a mask rearranged to the
    latent resolution, optionally preceded by a reference-image frame.
    """

    def __init__(self):
        super().__init__(
            input_params=("vace_video", "vace_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(
        self,
        pipe: WanVideoPipeline,
        vace_video, vace_mask, vace_reference_image, vace_scale,
        height, width, num_frames,
        tiled, tile_size, tile_stride
    ):
        """Build ``vace_context``; it is None when no VACE input is supplied.

        Returns a dict with the keys ``vace_context`` and ``vace_scale``.
        """
        # Nothing to condition on: pass the scale through untouched.
        if vace_video is None and vace_mask is None and vace_reference_image is None:
            return {"vace_context": None, "vace_scale": vace_scale}

        pipe.load_models_to_device(["vae"])

        def encode_to_latents(frames):
            # Shared VAE call: same tiling options and output dtype/device every time.
            encoded = pipe.vae.encode(frames, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            return encoded.to(dtype=pipe.torch_dtype, device=pipe.device)

        # Missing video -> all-zero frames; missing mask -> all-ones mask.
        if vace_video is not None:
            vace_video = pipe.preprocess_video(vace_video)
        else:
            vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)

        if vace_mask is not None:
            vace_mask = pipe.preprocess_video(vace_mask)
        else:
            vace_mask = torch.ones_like(vace_video)

        # Split the video into the part outside the mask and the part inside it.
        inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
        reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
        inactive = encode_to_latents(inactive)
        reactive = encode_to_latents(reactive)
        vace_video_latents = torch.concat((inactive, reactive), dim=1)

        # Fold each 8x8 spatial patch of the mask into 64 channels, then
        # resample the time axis to (T + 3) // 4 frames.
        vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
        latent_frames = (vace_mask_latents.shape[2] + 3) // 4
        vace_mask_latents = torch.nn.functional.interpolate(
            vace_mask_latents,
            size=(latent_frames, vace_mask_latents.shape[3], vace_mask_latents.shape[4]),
            mode='nearest-exact',
        )

        if vace_reference_image is not None:
            # The reference image becomes an extra leading frame: its latents
            # are padded with zeros on the channel axis and the mask gets an
            # all-zero frame to match.
            vace_reference_image = pipe.preprocess_video([vace_reference_image])
            vace_reference_latents = encode_to_latents(vace_reference_image)
            vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
            vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
            vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)

        vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
        return {"vace_context": vace_context, "vace_scale": vace_scale}
|
|
|
|
|
|
|
|
class WanVideoUnit_TeaCache(PipelineUnit):
    """Pipeline unit that creates a ``TeaCache`` for the positive and negative branches."""

    def __init__(self):
        # Both branches read the same three inputs under their own names.
        param_names = ("num_inference_steps", "tea_cache_l1_thresh", "tea_cache_model_id")
        super().__init__(
            seperate_cfg=True,
            input_params_posi={name: name for name in param_names},
            input_params_nega={name: name for name in param_names},
        )

    def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id):
        """Return ``{"tea_cache": TeaCache(...)}``, or ``{}`` when no threshold is set."""
        if tea_cache_l1_thresh is not None:
            tea_cache = TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)
            return {"tea_cache": tea_cache}
        return {}
|
|
|
|
|
|
|
|
class WanVideoUnit_CfgMerger(PipelineUnit):
    """Pipeline unit that merges positive and negative inputs into one batch.

    Only active when ``inputs_shared["cfg_merge"]`` is truthy.
    """

    def __init__(self):
        super().__init__(take_over=True)
        self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        """Concatenate per-branch tensors along dim 0 into ``inputs_shared``.

        For each name in ``concat_tensor_names``: if both branches provide the
        tensor, positive and negative are stacked; otherwise a shared tensor
        (if any) is duplicated. The per-branch dicts are emptied afterwards.
        """
        if inputs_shared["cfg_merge"]:
            for name in self.concat_tensor_names:
                posi_value = inputs_posi.get(name)
                nega_value = inputs_nega.get(name)
                shared_value = inputs_shared.get(name)
                if posi_value is None or nega_value is None:
                    # No complete positive/negative pair: fall back to the shared tensor.
                    if shared_value is not None:
                        inputs_shared[name] = torch.concat((shared_value, shared_value), dim=0)
                else:
                    inputs_shared[name] = torch.concat((posi_value, nega_value), dim=0)
            inputs_posi.clear()
            inputs_nega.clear()
        return inputs_shared, inputs_posi, inputs_nega
|
|
|
|
|
|
|
|
class TeaCache:
    """Decides per denoising step whether the transformer blocks can be skipped.

    ``check`` accumulates a polynomial-rescaled relative L1 change of the
    time modulation between steps. While the accumulated value stays below
    ``rel_l1_thresh`` it reports that the cached residual can be reused
    (``update``); otherwise the blocks must run and ``store`` records the
    new residual.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """Initialise the counters and pick the rescaling coefficients.

        Raises:
            ValueError: if ``model_id`` has no entry in the coefficient table.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        # Polynomial coefficients (highest degree first, as consumed by np.poly1d).
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Advance one step and return True when the cached residual should be reused.

        When the blocks must be recomputed, a copy of ``x`` is kept so that
        ``store`` can later derive the residual.
        """
        current_modulation = t_mod.clone()
        at_first_or_last_step = self.step == 0 or self.step == self.num_inference_steps - 1

        if at_first_or_last_step:
            # The first and last steps are always computed.
            recompute = True
            self.accumulated_rel_l1_distance = 0
        else:
            rescale = np.poly1d(self.coefficients)
            previous_modulation = self.previous_modulated_input
            relative_change = (current_modulation - previous_modulation).abs().mean() / previous_modulation.abs().mean()
            self.accumulated_rel_l1_distance += rescale(relative_change.cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                recompute = False
            else:
                recompute = True
                self.accumulated_rel_l1_distance = 0

        self.previous_modulated_input = current_modulation

        # Wrap the step counter so the cache can serve another full run.
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0

        if recompute:
            self.previous_hidden_states = x.clone()
        return not recompute

    def store(self, hidden_states):
        """Record the residual between ``hidden_states`` and the input saved by ``check``."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` plus the most recently stored residual."""
        return hidden_states + self.previous_residual
|
|
|
|
|
|
|
|
class TemporalTiler_BCTHW:
    """Runs a model over overlapping windows along dim 2 and blends the outputs.

    Tensors are indexed as (B, C, T, H, W). Overlapping regions are combined
    with linear ramp weights and normalised by the accumulated weight.
    """

    def __init__(self):
        pass

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a 1-D blending weight of size ``length``.

        The weight is 1 everywhere except for a linear ramp of
        ``border_width`` entries on each side that is not a sequence
        boundary (``left_bound`` / ``right_bound`` False).
        """
        x = torch.ones((length,))
        if border_width == 0:
            # Windows do not overlap, so there is nothing to ramp. Without this
            # guard `x[-0:]` would address the whole tensor and the assignment
            # of an empty ramp would raise.
            return x
        ramp = (torch.arange(border_width) + 1) / border_width
        if not left_bound:
            x[:border_width] = ramp
        if not right_bound:
            x[-border_width:] = torch.flip(ramp, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """Return the temporal blending mask for ``data`` with shape (1, 1, T, 1, 1).

        ``is_bound`` is a (left, right) pair of flags and ``border_width`` a
        one-element sequence holding the overlap length.
        """
        _, _, T, _, _ = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        # Pure reshape to broadcast over batch, channel and spatial axes.
        return t.reshape(1, 1, T, 1, 1)

    def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None):
        """Evaluate ``model_fn`` window by window and blend the results.

        Args:
            model_fn: callable invoked as ``model_fn(**model_kwargs)``.
            sliding_window_size: window length along dim 2.
            sliding_window_stride: step between window starts.
            computation_device, computation_dtype: device/dtype the sliced
                tensors are moved to before calling ``model_fn``.
            model_kwargs: keyword arguments for ``model_fn``; the entries named
                in ``tensor_names`` are temporarily replaced by their slices
                and restored before returning.
            tensor_names: names of the entries to slice along dim 2; entries
                whose value is None are ignored. The first remaining entry
                defines the output shape, device and dtype.
            batch_size: optional multiplier applied to the output batch size.

        Returns:
            The blended output tensor of shape (B, C, T, H, W).
        """
        tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None]
        tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names}
        reference = tensor_dict[tensor_names[0]]
        B, C, T, H, W = reference.shape
        if batch_size is not None:
            B *= batch_size
        data_device, data_dtype = reference.device, reference.dtype
        value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
        weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
        border_width = sliding_window_size - sliding_window_stride
        try:
            for t in range(0, T, sliding_window_stride):
                # Skip this window if the previous one already reached the end.
                if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T:
                    continue
                t_ = min(t + sliding_window_size, T)
                model_kwargs.update({
                    tensor_name: tensor_dict[tensor_name][:, :, t:t_].to(device=computation_device, dtype=computation_dtype)
                    for tensor_name in tensor_names
                })
                model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype)
                mask = self.build_mask(
                    model_output,
                    is_bound=(t == 0, t_ == T),
                    border_width=(border_width,)
                ).to(device=data_device, dtype=data_dtype)
                value[:, :, t:t_, :, :] += model_output * mask
                weight[:, :, t:t_, :, :] += mask
        finally:
            # Hand the caller's full-length tensors back even if model_fn raised.
            model_kwargs.update(tensor_dict)
        value /= weight
        return value
|
|
|
|
|
|
|
|
def model_fn_wan_video(
    dit: WanModel,
    motion_controller: WanMotionControllerModel = None,
    vace: VaceWanModel = None,
    latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    reference_latents = None,
    vace_context = None,
    vace_scale = 1.0,
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
    sliding_window_size: Optional[int] = None,
    sliding_window_stride: Optional[int] = None,
    cfg_merge: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    control_camera_latents = None,
    control_camera_latents_input = None,
    **kwargs,
):
    """Run one forward pass of the Wan DiT with its optional add-ons.

    When both ``sliding_window_size`` and ``sliding_window_stride`` are set,
    the call is delegated to ``TemporalTiler_BCTHW``, which re-invokes this
    function on temporal slices of ``latents`` and ``y`` and blends the
    outputs. Otherwise the function embeds the timestep and text context,
    patchifies the latents, optionally prepends reference tokens, applies
    VACE hints and TeaCache skipping, runs the transformer blocks, and
    unpatchifies the result.

    Returns:
        The tensor produced by ``dit.unpatchify`` (or the blended tensor from
        the temporal tiler in sliding-window mode).
    """
    if sliding_window_size is not None and sliding_window_stride is not None:
        model_kwargs = dict(
            dit=dit,
            motion_controller=motion_controller,
            vace=vace,
            latents=latents,
            timestep=timestep,
            context=context,
            clip_feature=clip_feature,
            y=y,
            reference_latents=reference_latents,
            vace_context=vace_context,
            vace_scale=vace_scale,
            tea_cache=tea_cache,
            use_unified_sequence_parallel=use_unified_sequence_parallel,
            motion_bucket_id=motion_bucket_id,
            # Forward the checkpointing flags; otherwise every window would
            # silently run without gradient checkpointing.
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
        )
        # NOTE(review): control_camera_latents_input is not forwarded here, so
        # camera control is inactive in sliding-window mode. Forwarding it
        # would also require slicing it along time - confirm intended behavior.
        return TemporalTiler_BCTHW().run(
            model_fn_wan_video,
            sliding_window_size, sliding_window_stride,
            latents.device, latents.dtype,
            model_kwargs=model_kwargs,
            tensor_names=["latents", "y"],
            batch_size=2 if cfg_merge else 1
        )

    if use_unified_sequence_parallel:
        # Imported lazily so xfuser is only required when sequence parallelism is used.
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)

    # Timestep embedding and the six modulation vectors derived from it.
    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
    if motion_bucket_id is not None and motion_controller is not None:
        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)

    x = latents
    # Merged cfg
    if x.shape[0] != context.shape[0]:
        x = torch.concat([x] * context.shape[0], dim=0)
    if timestep.shape[0] != context.shape[0]:
        timestep = torch.concat([timestep] * context.shape[0], dim=0)

    if dit.has_image_input:
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    # Add camera control
    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)

    # Reference image
    if reference_latents is not None:
        if len(reference_latents.shape) == 5:
            reference_latents = reference_latents[:, :, 0]
        reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
        # Reference tokens are prepended and counted as one extra frame for
        # the positional frequencies below.
        x = torch.concat([reference_latents, x], dim=1)
        f += 1

    # Per-token positional frequencies built from the frame/height/width tables.
    freqs = torch.cat([
        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, x, t_mod)
    else:
        tea_cache_update = False

    if vace_context is not None:
        vace_hints = vace(x, vace_context, context, t_mod, freqs)

    # blocks
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            # Each rank keeps only its own chunk of the token sequence.
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
    if tea_cache_update:
        # Skip the blocks and reuse the residual stored by an earlier step.
        x = tea_cache.update(x)
    else:
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block_id, block in enumerate(dit.blocks):
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
            else:
                x = block(x, context, t_mod, freqs)
            if vace_context is not None and block_id in vace.vace_layers_mapping:
                x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
        if tea_cache is not None:
            tea_cache.store(x)

    if reference_latents is not None:
        # Drop the prepended reference tokens again.
        x = x[:, reference_latents.shape[1]:]
        f -= 1

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x
|