Files
DiffSynth-Studio/diffsynth/diffusion/base_pipeline.py
Hong Zhang 681df93a85 Mova (#1337)
* support mova inference

* mova media_io

* add unified audio_video api & fix bug of mono audio input for ltx

* support mova train

* mova docs

* fix bug
2026-03-13 13:06:07 +08:00

469 lines
21 KiB
Python

from PIL import Image
import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE
class PipelineUnit:
    """Declarative description of one processing step of a pipeline.

    A unit lists the names of the values it reads and writes; the actual
    work is done by `process`, which subclasses override.
    """

    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        output_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None
    ):
        # Execution mode flags.
        self.seperate_cfg = seperate_cfg
        self.take_over = take_over
        # Names of the values read from / written to the input dictionaries.
        self.input_params = input_params
        self.output_params = output_params
        # Mappings {argument name of `process`: key in the positive / negative inputs}.
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        # Names of the models this unit declares it needs.
        self.onload_model_names = onload_model_names

    def fetch_input_params(self):
        """Return the sorted, de-duplicated names of every value this unit reads."""
        names = set()
        if self.input_params is not None:
            names.update(self.input_params)
        # For the positive / negative mappings only the mapped keys (dict values) count.
        for mapping in (self.input_params_posi, self.input_params_nega):
            if mapping is not None:
                names.update(mapping.values())
        return sorted(names)

    def fetch_output_params(self):
        """Return the declared output names as a new list (empty if none are declared)."""
        if self.output_params is None:
            return []
        return list(self.output_params)

    def process(self, pipe, **kwargs) -> dict:
        """Do the unit's work; the base implementation produces nothing."""
        return {}

    def post_process(self, pipe, **kwargs) -> dict:
        """Hook with the same signature as `process`; the base implementation produces nothing."""
        return {}
class BasePipeline(torch.nn.Module):
    """Shared base class of the diffusion pipelines.

    It bundles helpers for shape checking, PIL <-> tensor conversion, noise
    generation, on/offloading of child models, LoRA loading and clearing,
    and classifier-free guidance.
    """

    def __init__(
        self,
        device=get_device_type(), torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype is used for the storage of intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        self.device_type = parse_device_type(device)
        # The following parameters are used for shape check.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # VRAM management
        self.vram_management_enabled = False
        # Pipeline Unit Runner
        self.unit_runner = PipelineUnitRunner()
        # LoRA Loader (stored as a class; instantiated in `load_lora`)
        self.lora_loader = GeneralLoRALoader

    def to(self, *args, **kwargs):
        """Move the pipeline like `torch.nn.Module.to` and remember the new device / dtype.

        Returns:
            The pipeline itself.
        """
        # Reuse torch's own argument parser so every call form accepted by Module.to is understood.
        device, dtype, _non_blocking, _convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
        """Round the requested sizes up to values that satisfy the division factors.

        Args:
            height: Requested height.
            width: Requested width.
            num_frames: Optional requested number of frames.
            verbose: Print a message for every adjusted value when > 0.

        Returns:
            `(height, width)` when `num_frames` is None, otherwise
            `(height, width, num_frames)`.
        """
        # Shape check
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            if verbose > 0:
                print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            if verbose > 0:
                print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        if self.time_division_factor is None:
            # No temporal constraint is configured (the constructor default), so there is
            # nothing to check; previously this path failed with `int % None`.
            return height, width, num_frames
        if num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
            if verbose > 0:
                print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
        return height, width, num_frames

    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        """Transform a PIL.Image to torch.Tensor.

        Pixel values in [0, 255] are mapped linearly to [min_value, max_value]
        and the "H W C" array is rearranged to `pattern`. `torch_dtype` and
        `device` fall back to the pipeline's own settings.
        """
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        # A batch axis of size 1 is created only when the pattern asks for one.
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image

    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        """Transform a list of PIL.Image to torch.Tensor.

        Every frame is converted with the default image pattern and the frames
        are stacked along the position of "T" in `pattern`.
        """
        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        # Axis names are single characters separated by one space, so the character
        # offset of "T" divided by 2 is its axis index.
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        """Transform a torch.Tensor to PIL.Image.

        Axes of `pattern` other than H, W and C are averaged away, then values
        are mapped from [min_value, max_value] to [0, 255] and clipped.
        """
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        """Transform a torch.Tensor to list of PIL.Image (one image per index along T)."""
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

    def output_audio_format_check(self, audio_output):
        """Normalize an audio tensor to the standard output format.

        Standard format: [C, T], dtype float32.
        """
        # Remove the batch dim. `squeeze(0)` only drops it when its size is 1.
        if audio_output.ndim == 3:
            audio_output = audio_output.squeeze(0)
        return audio_output.float()

    def load_models_to_device(self, model_names):
        """Offload every managed child model not in `model_names`, then onload those that are.

        Does nothing unless `self.vram_management_enabled` is set. Only child
        models whose own `vram_management_enabled` attribute is truthy are touched.
        """
        if self.vram_management_enabled:
            # offload models
            for name, model in self.named_children():
                if name not in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        if hasattr(model, "offload"):
                            model.offload()
                        else:
                            # The model has no `offload` itself: delegate to its submodules.
                            for module in model.modules():
                                if hasattr(module, "offload"):
                                    module.offload()
            # Release cached memory before loading the requested models.
            getattr(torch, self.device_type).empty_cache()
            # onload models
            for name, model in self.named_children():
                if name in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        if hasattr(model, "onload"):
                            model.onload()
                        else:
                            for module in model.modules():
                                if hasattr(module, "onload"):
                                    module.onload()

    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        """Initialize Gaussian noise.

        The noise is sampled on `rand_device` with `rand_torch_dtype` (seeded
        when `seed` is given) and then cast / moved to the requested dtype and
        device, which default to the pipeline's own.
        """
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise

    def get_vram(self):
        """Return the second value reported by the backend's `mem_get_info`, divided by 1024**3.

        For `torch.cuda.mem_get_info` that value is the total device memory in bytes.
        """
        device = self.device if not IS_NPU_AVAILABLE else get_device_name()
        return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)

    def get_module(self, model, name):
        """Resolve a dotted path such as "blocks.0.attn" starting from `model`.

        A purely numeric segment that is followed by more segments is used as
        an index; every other segment is looked up with `getattr`.
        """
        if "." in name:
            name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
            if name.isdigit():
                return self.get_module(model[int(name)], suffix)
            else:
                return self.get_module(getattr(model, name), suffix)
        else:
            return getattr(model, name)

    def freeze_except(self, model_names):
        """Put the pipeline in eval mode with gradients disabled, except for `model_names`.

        Names that resolve to None are reported and skipped.
        """
        self.eval()
        self.requires_grad_(False)
        for name in model_names:
            module = self.get_module(self, name)
            if module is None:
                print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
                continue
            module.train()
            module.requires_grad_(True)

    def blend_with_mask(self, base, addition, mask):
        """Linearly interpolate: `base` where mask is 0, `addition` where mask is 1."""
        return base * (1 - mask) + addition * mask

    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
        """Advance `latents` by one scheduler step.

        When `inpaint_mask` is given, the prediction is first blended with the
        value returned by `scheduler.return_to_timestep` for `input_latents`:
        the model prediction is used where the mask is 1, the scheduler value
        where it is 0. Extra keyword arguments are accepted and ignored.
        """
        timestep = scheduler.timesteps[progress_id]
        if inpaint_mask is not None:
            noise_pred_expected = scheduler.return_to_timestep(timestep, latents, input_latents)
            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
        latents_next = scheduler.step(noise_pred, timestep, latents)
        return latents_next

    def split_pipeline_units(self, model_names: list[str]):
        """Split `self.units` into units related / unrelated to `model_names`.

        `self.units` is not set by this base class; it is expected to be
        provided by the concrete pipeline.
        """
        return PipelineUnitGraph().split_pipeline_units(self.units, model_names)

    def flush_vram_management_device(self, device):
        """Point every device slot of all `AutoTorchModule` submodules at `device`."""
        for module in self.modules():
            if isinstance(module, AutoTorchModule):
                module.offload_device = device
                module.onload_device = device
                module.preparing_device = device
                module.computation_device = device

    def load_lora(
        self,
        module: torch.nn.Module,
        lora_config: Union[ModelConfig, str] = None,
        alpha=1,
        hotload=None,
        state_dict=None,
        verbose=1,
    ):
        """Apply a LoRA to `module`, either hot-loaded or fused into the weights.

        Args:
            module: The model that receives the LoRA.
            lora_config: Path or `ModelConfig` of the LoRA file; used only
                when `state_dict` is None.
            alpha: Scale of the LoRA.
            hotload: Append the LoRA weights to the `AutoWrappedLinear`
                layers instead of fusing them. Defaults to whether VRAM
                management is enabled on `module`.
            state_dict: Already loaded LoRA weights.
            verbose: Print the number of patched layers when >= 1.

        Raises:
            ValueError: If `hotload` is requested but VRAM management is not
                enabled on `module`.
        """
        if state_dict is None:
            if isinstance(lora_config, str):
                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
            else:
                lora_config.download_if_necessary()
                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
        else:
            lora = state_dict
        lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
        lora = lora_loader.convert_state_dict(lora)
        if hotload is None:
            hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
        if hotload:
            if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
                raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
            updated_num = 0
            # A distinct loop name keeps the `module` argument intact.
            for _, submodule in module.named_modules():
                if isinstance(submodule, AutoWrappedLinear):
                    name = submodule.name
                    lora_a_name = f'{name}.lora_A.weight'
                    lora_b_name = f'{name}.lora_B.weight'
                    if lora_a_name in lora and lora_b_name in lora:
                        updated_num += 1
                        # alpha is applied once, on the A matrix.
                        submodule.lora_A_weights.append(lora[lora_a_name] * alpha)
                        submodule.lora_B_weights.append(lora[lora_b_name])
            if verbose >= 1:
                print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
        else:
            lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)

    def clear_lora(self, verbose=1):
        """Remove all hot-loaded LoRA weights from the `AutoWrappedLinear` layers."""
        cleared_num = 0
        for name, module in self.named_modules():
            if isinstance(module, AutoWrappedLinear):
                if hasattr(module, "lora_A_weights"):
                    # Only layers that actually held a LoRA are counted.
                    if len(module.lora_A_weights) > 0:
                        cleared_num += 1
                    module.lora_A_weights.clear()
                if hasattr(module, "lora_B_weights"):
                    module.lora_B_weights.clear()
        if verbose >= 1:
            print(f"{cleared_num} LoRA layers are cleared.")

    def download_and_load_models(self, model_configs: list[ModelConfig] = None, vram_limit: float = None):
        """Download (if necessary) and load every configured model into a new `ModelPool`.

        Args:
            model_configs: Models to load; None is treated as an empty list.
            vram_limit: Forwarded to `ModelPool.auto_load_model`.

        Returns:
            The populated `ModelPool`.
        """
        # None replaces the former mutable `[]` default.
        if model_configs is None:
            model_configs = []
        model_pool = ModelPool()
        for model_config in model_configs:
            model_config.download_if_necessary()
            vram_config = model_config.vram_config()
            # Unset computation settings fall back to the pipeline's own.
            vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
            vram_config["computation_device"] = vram_config["computation_device"] or self.device
            model_pool.auto_load_model(
                model_config.path,
                vram_config=vram_config,
                vram_limit=vram_limit,
                clear_parameters=model_config.clear_parameters,
                state_dict=model_config.state_dict,
            )
        return model_pool

    def check_vram_management_state(self):
        """Return True if any direct child model has VRAM management enabled."""
        vram_management_enabled = False
        for module in self.children():
            if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
                vram_management_enabled = True
        return vram_management_enabled

    def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
        """Run `model_fn` with classifier-free guidance.

        The negative pass is skipped when `cfg_scale` equals 1.0. If
        `inputs_shared["positive_only_lora"]` is set, that LoRA is loaded on
        `self.dit` for the positive pass and cleared before the negative one.

        Returns:
            The guided prediction; a tuple when `model_fn` returns a tuple.
        """
        if inputs_shared.get("positive_only_lora", None) is not None:
            self.clear_lora(verbose=0)
            self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
        if cfg_scale != 1.0:
            if inputs_shared.get("positive_only_lora", None) is not None:
                self.clear_lora(verbose=0)
            noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
            if isinstance(noise_pred_posi, tuple):
                # Separately handling different output types of latents, eg. video and audio latents.
                noise_pred = tuple(
                    n_nega + cfg_scale * (n_posi - n_nega)
                    for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
                )
            else:
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
        else:
            noise_pred = noise_pred_posi
        return noise_pred
class PipelineUnitGraph:
    """Dependency analysis over an ordered list of pipeline units.

    Units are identified by their index in the list. A unit depends on the
    latest earlier unit that declared one of its inputs as an output.
    """

    def __init__(self):
        pass

    def build_edges(self, units: list[PipelineUnit]):
        """Return `(producer_id, consumer_id)` pairs between units.

        Establish dependencies between units to search for subsequent
        related computation units.
        """
        last_compute_unit_id = {}
        edges = []
        for unit_id, unit in enumerate(units):
            # Inputs are resolved before this unit's outputs are registered, so a
            # unit that rewrites one of its own inputs depends on the previous writer.
            for input_param in unit.fetch_input_params():
                if input_param in last_compute_unit_id:
                    edges.append((last_compute_unit_id[input_param], unit_id))
            for output_param in unit.fetch_output_params():
                last_compute_unit_id[output_param] = unit_id
        return edges

    def build_chains(self, units: list[PipelineUnit]):
        """Return `{param: [ids of the units that output it, in order]}`.

        Establish updating chains for each variable to track their
        computation process. Parameters that are only read map to an empty
        list; keys are inserted in sorted order.
        """
        param_names = set()
        for unit in units:
            param_names.update(unit.fetch_input_params())
            param_names.update(unit.fetch_output_params())
        chains = {param: [] for param in sorted(param_names)}
        for unit_id, unit in enumerate(units):
            for param in unit.fetch_output_params():
                chains[param].append(unit_id)
        return chains

    def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
        """Search for units that directly participate in the model's computation.

        A unit qualifies when its `onload_model_names` contains at least one
        of `model_names`. Ids are returned in ascending order.
        """
        return [
            unit_id
            for unit_id, unit in enumerate(units)
            if unit.onload_model_names is not None
            and any(model_name in unit.onload_model_names for model_name in model_names)
        ]

    def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
        """Expand `start_unit_ids` along `edges` until no new unit is reached.

        Args:
            edges: `(source, target)` pairs as produced by `build_edges`.
            start_unit_ids: Ids to start from; the list is not modified.
            direction: "target" follows edges forwards (consumers), "source"
                follows them backwards (producers). Any other value performs
                no expansion.

        Returns:
            A new list of ids. It is sorted and de-duplicated once at least
            one neighbor has been added; otherwise it is a plain copy of
            `start_unit_ids`.
        """
        related_unit_ids = list(start_unit_ids)
        visited = set(related_unit_ids)
        while True:
            if direction == "target":
                neighbors = {target for source, target in edges if source in visited and target not in visited}
            elif direction == "source":
                neighbors = {source for source, target in edges if target in visited and source not in visited}
            else:
                neighbors = set()
            if not neighbors:
                break
            visited |= neighbors
            related_unit_ids = sorted(visited)
        return related_unit_ids

    def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
        """Add the units outside the subgraph that later update its inputs.

        If the input parameters of this subgraph are updated outside the
        subgraph, search for the units where these updates occur.

        Returns:
            A new sorted, de-duplicated list of ids. The `related_unit_ids`
            argument is left untouched.
        """
        # For every input parameter, the first related unit that reads it.
        first_compute_unit_id = {}
        for unit_id in related_unit_ids:
            for param in units[unit_id].fetch_input_params():
                if param not in first_compute_unit_id:
                    first_compute_unit_id[param] = unit_id
        related = set(related_unit_ids)
        updating_unit_ids = set()
        for param, unit_id in first_compute_unit_id.items():
            chain = chains[param]
            # Only relevant when that reader also writes the parameter and is not its last writer.
            if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
                for later_unit_id in chain[chain.index(unit_id) + 1:]:
                    if later_unit_id not in related:
                        updating_unit_ids.add(later_unit_id)
        # Build a fresh list instead of extending the caller's one in place.
        return sorted(related | updating_unit_ids)

    def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
        """Split the computation graph, separating all model-related computations.

        Returns:
            `(related_units, unrelated_units)`, each in the original order.
        """
        related_unit_ids = self.search_direct_unit_ids(units, model_names)
        edges = self.build_edges(units)
        chains = self.build_chains(units)
        # Iterate to a fixed point: each pass may pull in new units whose own
        # consumers / updaters have to be examined as well.
        while True:
            previous_count = len(related_unit_ids)
            related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
            related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
            if len(related_unit_ids) == previous_count:
                break
        related = set(related_unit_ids)
        related_units = [units[i] for i in related_unit_ids]
        unrelated_units = [unit for i, unit in enumerate(units) if i not in related]
        return related_units, unrelated_units
class PipelineUnitRunner:
    """Executes a single pipeline unit against the three input dictionaries."""

    def __init__(self):
        pass

    @staticmethod
    def _gather_inputs(unit, side_params, side_inputs, inputs_shared):
        # Build the keyword arguments of `unit.process` for one guidance side.
        # `side_params` maps argument name -> key in `side_inputs`; None means no
        # side-specific inputs, matching how PipelineUnit.fetch_input_params treats it.
        processor_inputs = {name: side_inputs.get(name_) for name, name_ in (side_params or {}).items()}
        # Shared inputs are added last, so they win on a name clash.
        for name in unit.input_params or ():
            processor_inputs[name] = inputs_shared.get(name)
        return processor_inputs

    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict, dict]:
        """Run `unit` and merge its outputs into the input dictionaries.

        Three modes, selected by the unit's flags:

        * `take_over`: the unit receives all three dictionaries and returns
          their replacements.
        * `seperate_cfg`: the unit runs on the positive side and, when
          `inputs_shared["cfg_scale"] != 1`, again on the negative side;
          otherwise the positive outputs are copied to the negative side.
        * default: the unit reads from and writes to `inputs_shared`.

        Parameter declarations left at None are treated as empty. Missing
        input keys are passed to the unit as None.

        Returns:
            `(inputs_shared, inputs_posi, inputs_nega)`. Except in
            `take_over` mode these are the same dictionaries, updated in place.
        """
        if unit.take_over:
            # Let the pipeline unit take over this function.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive side
            processor_inputs = self._gather_inputs(unit, unit.input_params_posi, inputs_posi, inputs_shared)
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative side
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = self._gather_inputs(unit, unit.input_params_nega, inputs_nega, inputs_shared)
                processor_outputs = unit.process(pipe, **processor_inputs)
            # Without guidance the positive outputs are reused for the negative side.
            inputs_nega.update(processor_outputs)
        else:
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params or ()}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega