update template framework

This commit is contained in:
Artiprocher
2026-04-15 14:07:51 +08:00
parent 9f8c352a15
commit 59b4bbb62c
7 changed files with 85 additions and 24 deletions

View File

@@ -3,7 +3,7 @@ import torch
import numpy as np import numpy as np
from einops import repeat, reduce from einops import repeat, reduce
from typing import Union from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
from ..core.device.npu_compatible_device import get_device_type from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool from ..models.model_loader import ModelPool
@@ -355,6 +355,23 @@ class BasePipeline(torch.nn.Module):
self.template_data_processor = load_template_data_processor(model_config.path)() self.template_data_processor = load_template_data_processor(model_config.path)()
def enable_lora_hot_loading(self, model: torch.nn.Module):
    """Wrap every ``torch.nn.Linear`` in ``model`` with ``AutoWrappedLinear``
    so LoRA weights can be hot-loaded at runtime.

    Idempotent: a model whose ``vram_management_enabled`` flag is already set
    is returned unchanged.

    Args:
        model: the module (typically the DiT) to prepare for LoRA hot-loading.

    Returns:
        The wrapped model (or the original model if already wrapped).
    """
    # getattr with a default collapses the hasattr()+getattr() pair.
    if getattr(model, "vram_management_enabled", False):
        return model
    module_map = {torch.nn.Linear: AutoWrappedLinear}
    # Every stage uses the pipeline's own dtype/device — no actual offloading
    # is configured here; the wrapping is only used to enable hot-loading.
    vram_config = {
        "offload_dtype": self.torch_dtype,
        "offload_device": self.device,
        "onload_dtype": self.torch_dtype,
        "onload_device": self.device,
        "preparing_dtype": self.torch_dtype,
        "preparing_device": self.device,
        "computation_dtype": self.torch_dtype,
        "computation_device": self.device,
    }
    model = enable_vram_management(model, module_map, vram_config=vram_config)
    return model
class PipelineUnitGraph: class PipelineUnitGraph:
def __init__(self): def __init__(self):

View File

@@ -3,6 +3,11 @@ import torch
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
if "lora" in inputs:
# Image-to-LoRA models need to load lora here.
pipe.clear_lora(verbose=0)
pipe.load_lora(pipe.dit, state_dict=inputs["lora"], hotload=True, verbose=0)
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps)) max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps)) min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

View File

@@ -62,6 +62,7 @@ def add_gradient_config(parser: argparse.ArgumentParser):
def add_template_model_config(parser: argparse.ArgumentParser): def add_template_model_config(parser: argparse.ArgumentParser):
parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID of path of template models.") parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID of path of template models.")
parser.add_argument("--enable_lora_hot_loading", default=False, action="store_true", help="Whether to enable LoRA hot-loading. Only available for image-to-lora models.")
return parser return parser
def add_general_config(parser: argparse.ArgumentParser): def add_general_config(parser: argparse.ArgumentParser):

View File

@@ -2,6 +2,7 @@ import torch, os, importlib, warnings, json, inspect
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
from ..core import ModelConfig, load_model from ..core import ModelConfig, load_model
from ..core.device.npu_compatible_device import get_device_type from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora.merge import merge_lora
KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]] KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
@@ -74,9 +75,28 @@ def load_template_data_processor(path):
class TemplatePipeline(torch.nn.Module): class TemplatePipeline(torch.nn.Module):
def __init__(
    self,
    torch_dtype: torch.dtype = torch.bfloat16,
    device: Union[str, torch.device] = get_device_type(),
    model_configs: list[ModelConfig] = None,
    lazy_loading: bool = False,
):
    """Build a template pipeline from a list of model configs.

    Args:
        torch_dtype: dtype used when loading the template models.
        device: target device for the template models.
            NOTE: the default is resolved once, at class-definition time.
        model_configs: configs describing the template models to load.
        lazy_loading: when True, defer loading; models are then fetched on
            demand via ``fetch_model()`` and ``self.models`` stays None.
    """
    super().__init__()
    self.torch_dtype = torch_dtype
    self.device = device
    # Avoid the shared-mutable-default pitfall (`= []` would be one list
    # object reused across every instantiation).
    self.model_configs = model_configs if model_configs is not None else []
    self.lazy_loading = lazy_loading
    if lazy_loading:
        self.models = None
    else:
        models = []
        for model_config in self.model_configs:
            TemplatePipeline.check_vram_config(model_config)
            model_config.download_if_necessary()
            model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
            models.append(model)
        self.models = torch.nn.ModuleList(models)
def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache: def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
names = {} names = {}
@@ -100,6 +120,8 @@ class TemplatePipeline(torch.nn.Module):
data = [template_cache[param] for template_cache in template_cache_list if param in template_cache] data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
if param == "kv_cache": if param == "kv_cache":
data = self.merge_kv_cache(data) data = self.merge_kv_cache(data)
elif param == "lora":
data = merge_lora(data)
elif len(data) == 1: elif len(data) == 1:
data = data[0] data = data[0]
else: else:
@@ -125,30 +147,32 @@ class TemplatePipeline(torch.nn.Module):
torch_dtype: torch.dtype = torch.bfloat16, torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(), device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [], model_configs: list[ModelConfig] = [],
lazy_loading: bool = False,
): ):
models = [] pipe = TemplatePipeline(torch_dtype, device, model_configs, lazy_loading)
for model_config in model_configs:
TemplatePipeline.check_vram_config(model_config)
model_config.download_if_necessary()
model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
models.append(model)
pipe = TemplatePipeline(models)
return pipe return pipe
def fetch_model(self, model_id):
    """Return the template model for ``model_id``.

    With lazy loading enabled, the model is downloaded (if necessary) and
    instantiated on demand; otherwise it is taken from the pre-loaded
    ``self.models`` ModuleList.
    """
    if self.lazy_loading:
        model_config = self.model_configs[model_id]
        model_config.download_if_necessary()
        model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
    else:
        model = self.models[model_id]
    return model
def call_single_side(self, pipe=None, inputs: List[Dict] = None):
    """Run each input dict through its template model and merge the results.

    Consecutive inputs sharing the same ``model_id`` reuse the already
    fetched model — this matters under lazy loading, where ``fetch_model``
    may load weights from disk on every call.
    """
    model = None
    onload_model_id = -1  # sentinel: no model fetched yet
    template_cache = []
    for i in inputs:
        model_id = i.get("model_id", 0)
        if model_id != onload_model_id:
            model = self.fetch_model(model_id)
            onload_model_id = model_id
        cache = model.process_inputs(pipe=pipe, **i)
        cache = model.forward(pipe=pipe, **cache)
        template_cache.append(cache)
    template_cache = self.merge_template_cache(template_cache)
    return template_cache

View File

@@ -38,7 +38,7 @@ class GeneralUnit_TemplateProcessInputs(PipelineUnit):
self.data_processor = data_processor self.data_processor = data_processor
def process(self, pipe, template_inputs): def process(self, pipe, template_inputs):
if not hasattr(pipe, "template_model"): if not hasattr(pipe, "template_model") or template_inputs is None:
return {} return {}
if self.data_processor is not None: if self.data_processor is not None:
template_inputs = self.data_processor(**template_inputs) template_inputs = self.data_processor(**template_inputs)
@@ -58,7 +58,7 @@ class GeneralUnit_TemplateForward(PipelineUnit):
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def process(self, pipe, template_inputs): def process(self, pipe, template_inputs):
if not hasattr(pipe, "template_model"): if not hasattr(pipe, "template_model") or template_inputs is None:
return {} return {}
template_cache = pipe.template_model.forward( template_cache = pipe.template_model.forward(
**template_inputs, **template_inputs,

View File

@@ -100,6 +100,9 @@ class Flux2ImagePipeline(BasePipeline):
# LoRA # LoRA
lora = None, lora = None,
negative_lora = None, negative_lora = None,
# Text Embedding
extra_text_embedding = None,
negative_extra_text_embedding = None,
# Inpaint # Inpaint
inpaint_mask: Image.Image = None, inpaint_mask: Image.Image = None,
inpaint_blur_size: int = None, inpaint_blur_size: int = None,
@@ -113,10 +116,12 @@ class Flux2ImagePipeline(BasePipeline):
inputs_posi = { inputs_posi = {
"prompt": prompt, "prompt": prompt,
"kv_cache": kv_cache, "kv_cache": kv_cache,
"extra_text_embedding": extra_text_embedding,
} }
inputs_nega = { inputs_nega = {
"negative_prompt": negative_prompt, "negative_prompt": negative_prompt,
"kv_cache": negative_kv_cache, "kv_cache": negative_kv_cache,
"extra_text_embedding": negative_extra_text_embedding,
} }
inputs_shared = { inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
@@ -607,6 +612,7 @@ def model_fn_flux2(
edit_latents=None, edit_latents=None,
edit_image_ids=None, edit_image_ids=None,
kv_cache=None, kv_cache=None,
extra_text_embedding=None,
use_gradient_checkpointing=False, use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False, use_gradient_checkpointing_offload=False,
**kwargs, **kwargs,
@@ -617,6 +623,11 @@ def model_fn_flux2(
latents = torch.concat([latents, edit_latents], dim=1) latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1) image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device) embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
if extra_text_embedding is not None:
extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device)
extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1])
prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
model_output = dit( model_output = dit(
hidden_states=latents, hidden_states=latents,
timestep=timestep / 1000, timestep=timestep / 1000,

View File

@@ -19,6 +19,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
fp8_models=None, fp8_models=None,
offload_models=None, offload_models=None,
template_model_id_or_path=None, template_model_id_or_path=None,
enable_lora_hot_loading=False,
device="cpu", device="cpu",
task="sft", task="sft",
): ):
@@ -29,6 +30,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.load_training_template_model(self.pipe, template_model_id_or_path, args.use_gradient_checkpointing, args.use_gradient_checkpointing_offload) self.pipe = self.load_training_template_model(self.pipe, template_model_id_or_path, args.use_gradient_checkpointing, args.use_gradient_checkpointing_offload)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
if enable_lora_hot_loading: self.pipe.dit = self.pipe.enable_lora_hot_loading(self.pipe.dit)
# Training mode # Training mode
self.switch_pipe_to_training_mode( self.switch_pipe_to_training_mode(
@@ -129,6 +131,7 @@ if __name__ == "__main__":
fp8_models=args.fp8_models, fp8_models=args.fp8_models,
offload_models=args.offload_models, offload_models=args.offload_models,
template_model_id_or_path=args.template_model_id_or_path, template_model_id_or_path=args.template_model_id_or_path,
enable_lora_hot_loading=args.enable_lora_hot_loading,
task=args.task, task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device, device="cpu" if args.initialize_model_on_cpu else accelerator.device,
) )