update template framework

This commit is contained in:
Artiprocher
2026-04-15 14:07:51 +08:00
parent 9f8c352a15
commit 59b4bbb62c
7 changed files with 85 additions and 24 deletions

View File

@@ -3,7 +3,7 @@ import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
@@ -354,6 +354,23 @@ class BasePipeline(torch.nn.Module):
self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
self.template_data_processor = load_template_data_processor(model_config.path)()
def enable_lora_hot_loading(self, model: torch.nn.Module):
    """Wrap ``model``'s Linear layers so LoRA weights can be hot-swapped.

    Replaces every ``torch.nn.Linear`` with ``AutoWrappedLinear`` via
    ``enable_vram_management``. Idempotent: a model that was already
    wrapped (flagged by ``vram_management_enabled``) is returned as-is.

    Args:
        model: The module to prepare for LoRA hot-loading.

    Returns:
        The wrapped module (or the original, if already wrapped).
    """
    # Idiomatic single lookup instead of hasattr()+getattr() pair.
    if getattr(model, "vram_management_enabled", False):
        return model
    module_map = {torch.nn.Linear: AutoWrappedLinear}
    # Every stage uses the pipeline's own dtype/device, so no actual
    # offloading happens here — the wrapping is only needed so LoRA
    # weights can be injected at runtime.
    vram_config = {
        "offload_dtype": self.torch_dtype,
        "offload_device": self.device,
        "onload_dtype": self.torch_dtype,
        "onload_device": self.device,
        "preparing_dtype": self.torch_dtype,
        "preparing_device": self.device,
        "computation_dtype": self.torch_dtype,
        "computation_device": self.device,
    }
    model = enable_vram_management(model, module_map, vram_config=vram_config)
    return model
class PipelineUnitGraph:

View File

@@ -3,6 +3,11 @@ import torch
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
if "lora" in inputs:
# Image-to-LoRA models need to load lora here.
pipe.clear_lora(verbose=0)
pipe.load_lora(pipe.dit, state_dict=inputs["lora"], hotload=True, verbose=0)
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

View File

@@ -62,6 +62,7 @@ def add_gradient_config(parser: argparse.ArgumentParser):
def add_template_model_config(parser: argparse.ArgumentParser):
    """Register template-model CLI options on ``parser`` and return it.

    Adds:
        --template_model_id_or_path: model ID or local path of template models.
        --enable_lora_hot_loading: flag enabling LoRA hot-loading
            (only meaningful for image-to-LoRA models).
    """
    # Fixed typo in help text: "of path" -> "or path".
    parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID or path of template models.")
    parser.add_argument("--enable_lora_hot_loading", default=False, action="store_true", help="Whether to enable LoRA hot-loading. Only available for image-to-lora models.")
    return parser
def add_general_config(parser: argparse.ArgumentParser):

View File

@@ -2,6 +2,7 @@ import torch, os, importlib, warnings, json, inspect
from typing import Dict, List, Tuple, Union
from ..core import ModelConfig, load_model
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora.merge import merge_lora
KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
@@ -74,9 +75,28 @@ def load_template_data_processor(path):
class TemplatePipeline(torch.nn.Module):
def __init__(
    self,
    torch_dtype: torch.dtype = torch.bfloat16,
    device: Union[str, torch.device] = get_device_type(),
    model_configs: list[ModelConfig] = [],
    lazy_loading: bool = False,
):
    """Build a template pipeline from model configs.

    Removed the stale pre-refactor signature (``__init__(self, models)``)
    and its ``self.models = torch.nn.ModuleList(models)`` assignment that
    were left interleaved with the new constructor.

    Args:
        torch_dtype: dtype used when loading each template model.
        device: device the models are loaded onto.
        model_configs: configs of the template models to manage.
            NOTE(review): the mutable ``[]`` default is only read, never
            mutated, so it is safe here; kept to match the project-wide
            signature convention (see ``from_pretrained``).
        lazy_loading: when True, defer loading; models are fetched one at
            a time via ``fetch_model``.
    """
    super().__init__()
    self.torch_dtype = torch_dtype
    self.device = device
    self.model_configs = model_configs
    self.lazy_loading = lazy_loading
    if lazy_loading:
        # Nothing loaded up front; fetch_model() loads on demand.
        self.models = None
    else:
        models = []
        for model_config in model_configs:
            TemplatePipeline.check_vram_config(model_config)
            model_config.download_if_necessary()
            model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
            models.append(model)
        self.models = torch.nn.ModuleList(models)
def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
names = {}
@@ -100,6 +120,8 @@ class TemplatePipeline(torch.nn.Module):
data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
if param == "kv_cache":
data = self.merge_kv_cache(data)
elif param == "lora":
data = merge_lora(data)
elif len(data) == 1:
data = data[0]
else:
@@ -125,30 +147,32 @@ class TemplatePipeline(torch.nn.Module):
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
lazy_loading: bool = False,
):
models = []
for model_config in model_configs:
TemplatePipeline.check_vram_config(model_config)
model_config.download_if_necessary()
model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
models.append(model)
pipe = TemplatePipeline(models)
pipe = TemplatePipeline(torch_dtype, device, model_configs, lazy_loading)
return pipe
@torch.no_grad()
def process_inputs(self, inputs: List[Dict], pipe=None, **kwargs):
    """Preprocess each request with its target template model.

    Each entry selects a model via its optional ``model_id`` key
    (defaulting to model 0); returns ``(model_id, processed_inputs)``
    pairs so forward() can route them back to the same model.
    """
    routed = []
    for entry in inputs:
        target = entry.get("model_id", 0)
        routed.append((target, self.models[target].process_inputs(pipe=pipe, **entry)))
    return routed
def forward(self, inputs: List[Tuple[int, Dict]], pipe=None, **kwargs):
    """Run each ``(model_id, model_inputs)`` pair through its model.

    Returns the per-request caches in input order.
    """
    return [
        self.models[model_id](pipe=pipe, **model_inputs)
        for model_id, model_inputs in inputs
    ]
def fetch_model(self, model_id):
    """Return the template model for ``model_id``.

    In eager mode the pre-loaded module is returned; in lazy mode the
    model is downloaded (if needed) and loaded on demand.
    """
    if not self.lazy_loading:
        return self.models[model_id]
    cfg = self.model_configs[model_id]
    cfg.download_if_necessary()
    return load_template_model(cfg.path, torch_dtype=self.torch_dtype, device=self.device)
def call_single_side(self, pipe=None, inputs: List[Dict] = None):
    """Process one side (positive or negative) of the template inputs.

    Removed the two stale pre-refactor lines (the old
    ``self.process_inputs``/``self.forward`` calls) that were left
    interleaved with the new implementation, which instead loads at most
    one model at a time via ``fetch_model`` and reuses it across
    consecutive requests targeting the same ``model_id``.

    Args:
        pipe: the owning pipeline, forwarded to each model.
        inputs: per-request dicts; each may carry a ``model_id`` key
            (default 0). ``None`` is treated as an empty list.

    Returns:
        The merged template cache for all requests.
    """
    model = None
    onload_model_id = -1  # sentinel: no model fetched yet
    template_cache = []
    for i in inputs or []:
        model_id = i.get("model_id", 0)
        # Only (re)fetch when the target model changes — important in
        # lazy-loading mode where fetch_model loads from disk.
        if model_id != onload_model_id:
            model = self.fetch_model(model_id)
            onload_model_id = model_id
        cache = model.process_inputs(pipe=pipe, **i)
        cache = model.forward(pipe=pipe, **cache)
        template_cache.append(cache)
    template_cache = self.merge_template_cache(template_cache)
    return template_cache

View File

@@ -38,7 +38,7 @@ class GeneralUnit_TemplateProcessInputs(PipelineUnit):
self.data_processor = data_processor
def process(self, pipe, template_inputs):
if not hasattr(pipe, "template_model"):
if not hasattr(pipe, "template_model") or template_inputs is None:
return {}
if self.data_processor is not None:
template_inputs = self.data_processor(**template_inputs)
@@ -58,7 +58,7 @@ class GeneralUnit_TemplateForward(PipelineUnit):
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def process(self, pipe, template_inputs):
if not hasattr(pipe, "template_model"):
if not hasattr(pipe, "template_model") or template_inputs is None:
return {}
template_cache = pipe.template_model.forward(
**template_inputs,

View File

@@ -100,6 +100,9 @@ class Flux2ImagePipeline(BasePipeline):
# LoRA
lora = None,
negative_lora = None,
# Text Embedding
extra_text_embedding = None,
negative_extra_text_embedding = None,
# Inpaint
inpaint_mask: Image.Image = None,
inpaint_blur_size: int = None,
@@ -113,10 +116,12 @@ class Flux2ImagePipeline(BasePipeline):
inputs_posi = {
"prompt": prompt,
"kv_cache": kv_cache,
"extra_text_embedding": extra_text_embedding,
}
inputs_nega = {
"negative_prompt": negative_prompt,
"kv_cache": negative_kv_cache,
"extra_text_embedding": negative_extra_text_embedding,
}
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
@@ -607,6 +612,7 @@ def model_fn_flux2(
edit_latents=None,
edit_image_ids=None,
kv_cache=None,
extra_text_embedding=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
@@ -617,6 +623,11 @@ def model_fn_flux2(
latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
if extra_text_embedding is not None:
extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device)
extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1])
prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
model_output = dit(
hidden_states=latents,
timestep=timestep / 1000,