mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-15 14:48:23 +00:00
update template framework
This commit is contained in:
@@ -3,7 +3,7 @@ import torch
|
||||
import numpy as np
|
||||
from einops import repeat, reduce
|
||||
from typing import Union
|
||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..utils.lora import GeneralLoRALoader
|
||||
from ..models.model_loader import ModelPool
|
||||
@@ -354,6 +354,23 @@ class BasePipeline(torch.nn.Module):
|
||||
self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
self.template_data_processor = load_template_data_processor(model_config.path)()
|
||||
|
||||
|
||||
def enable_lora_hot_loading(self, model: torch.nn.Module):
|
||||
if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"):
|
||||
return model
|
||||
module_map = {torch.nn.Linear: AutoWrappedLinear}
|
||||
vram_config = {
|
||||
"offload_dtype": self.torch_dtype,
|
||||
"offload_device": self.device,
|
||||
"onload_dtype": self.torch_dtype,
|
||||
"onload_device": self.device,
|
||||
"preparing_dtype": self.torch_dtype,
|
||||
"preparing_device": self.device,
|
||||
"computation_dtype": self.torch_dtype,
|
||||
"computation_device": self.device,
|
||||
}
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config)
|
||||
return model
|
||||
|
||||
|
||||
class PipelineUnitGraph:
|
||||
|
||||
@@ -3,6 +3,11 @@ import torch
|
||||
|
||||
|
||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
if "lora" in inputs:
|
||||
# Image-to-LoRA models need to load lora here.
|
||||
pipe.clear_lora(verbose=0)
|
||||
pipe.load_lora(pipe.dit, state_dict=inputs["lora"], hotload=True, verbose=0)
|
||||
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ def add_gradient_config(parser: argparse.ArgumentParser):
|
||||
|
||||
def add_template_model_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID of path of template models.")
|
||||
parser.add_argument("--enable_lora_hot_loading", default=False, action="store_true", help="Whether to enable LoRA hot-loading. Only available for image-to-lora models.")
|
||||
return parser
|
||||
|
||||
def add_general_config(parser: argparse.ArgumentParser):
|
||||
|
||||
@@ -2,6 +2,7 @@ import torch, os, importlib, warnings, json, inspect
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from ..core import ModelConfig, load_model
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..utils.lora.merge import merge_lora
|
||||
|
||||
|
||||
KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
|
||||
@@ -74,9 +75,28 @@ def load_template_data_processor(path):
|
||||
|
||||
|
||||
class TemplatePipeline(torch.nn.Module):
|
||||
def __init__(self, models: List[TemplateModel]):
|
||||
def __init__(
|
||||
self,
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
lazy_loading: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.models = torch.nn.ModuleList(models)
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model_configs = model_configs
|
||||
self.lazy_loading = lazy_loading
|
||||
if lazy_loading:
|
||||
self.models = None
|
||||
else:
|
||||
models = []
|
||||
for model_config in model_configs:
|
||||
TemplatePipeline.check_vram_config(model_config)
|
||||
model_config.download_if_necessary()
|
||||
model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
|
||||
models.append(model)
|
||||
self.models = torch.nn.ModuleList(models)
|
||||
|
||||
def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
|
||||
names = {}
|
||||
@@ -100,6 +120,8 @@ class TemplatePipeline(torch.nn.Module):
|
||||
data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
|
||||
if param == "kv_cache":
|
||||
data = self.merge_kv_cache(data)
|
||||
elif param == "lora":
|
||||
data = merge_lora(data)
|
||||
elif len(data) == 1:
|
||||
data = data[0]
|
||||
else:
|
||||
@@ -125,30 +147,32 @@ class TemplatePipeline(torch.nn.Module):
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
lazy_loading: bool = False,
|
||||
):
|
||||
models = []
|
||||
for model_config in model_configs:
|
||||
TemplatePipeline.check_vram_config(model_config)
|
||||
model_config.download_if_necessary()
|
||||
model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
|
||||
models.append(model)
|
||||
pipe = TemplatePipeline(models)
|
||||
pipe = TemplatePipeline(torch_dtype, device, model_configs, lazy_loading)
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def process_inputs(self, inputs: List[Dict], pipe=None, **kwargs):
|
||||
return [(i.get("model_id", 0), self.models[i.get("model_id", 0)].process_inputs(pipe=pipe, **i)) for i in inputs]
|
||||
|
||||
def forward(self, inputs: List[Tuple[int, Dict]], pipe=None, **kwargs):
|
||||
template_cache = []
|
||||
for model_id, model_inputs in inputs:
|
||||
kv_cache = self.models[model_id](pipe=pipe, **model_inputs)
|
||||
template_cache.append(kv_cache)
|
||||
return template_cache
|
||||
def fetch_model(self, model_id):
|
||||
if self.lazy_loading:
|
||||
model_config = self.model_configs[model_id]
|
||||
model_config.download_if_necessary()
|
||||
model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
model = self.models[model_id]
|
||||
return model
|
||||
|
||||
def call_single_side(self, pipe=None, inputs: List[Dict] = None):
|
||||
inputs = self.process_inputs(pipe=pipe, inputs=inputs)
|
||||
template_cache = self.forward(pipe=pipe, inputs=inputs)
|
||||
model = None
|
||||
onload_model_id = -1
|
||||
template_cache = []
|
||||
for i in inputs:
|
||||
model_id = i.get("model_id", 0)
|
||||
if model_id != onload_model_id:
|
||||
model = self.fetch_model(model_id)
|
||||
onload_model_id = model_id
|
||||
cache = model.process_inputs(pipe=pipe, **i)
|
||||
cache = model.forward(pipe=pipe, **cache)
|
||||
template_cache.append(cache)
|
||||
template_cache = self.merge_template_cache(template_cache)
|
||||
return template_cache
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class GeneralUnit_TemplateProcessInputs(PipelineUnit):
|
||||
self.data_processor = data_processor
|
||||
|
||||
def process(self, pipe, template_inputs):
|
||||
if not hasattr(pipe, "template_model"):
|
||||
if not hasattr(pipe, "template_model") or template_inputs is None:
|
||||
return {}
|
||||
if self.data_processor is not None:
|
||||
template_inputs = self.data_processor(**template_inputs)
|
||||
@@ -58,7 +58,7 @@ class GeneralUnit_TemplateForward(PipelineUnit):
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
|
||||
def process(self, pipe, template_inputs):
|
||||
if not hasattr(pipe, "template_model"):
|
||||
if not hasattr(pipe, "template_model") or template_inputs is None:
|
||||
return {}
|
||||
template_cache = pipe.template_model.forward(
|
||||
**template_inputs,
|
||||
|
||||
@@ -100,6 +100,9 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
# LoRA
|
||||
lora = None,
|
||||
negative_lora = None,
|
||||
# Text Embedding
|
||||
extra_text_embedding = None,
|
||||
negative_extra_text_embedding = None,
|
||||
# Inpaint
|
||||
inpaint_mask: Image.Image = None,
|
||||
inpaint_blur_size: int = None,
|
||||
@@ -113,10 +116,12 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
"kv_cache": kv_cache,
|
||||
"extra_text_embedding": extra_text_embedding,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
"kv_cache": negative_kv_cache,
|
||||
"extra_text_embedding": negative_extra_text_embedding,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
|
||||
@@ -607,6 +612,7 @@ def model_fn_flux2(
|
||||
edit_latents=None,
|
||||
edit_image_ids=None,
|
||||
kv_cache=None,
|
||||
extra_text_embedding=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
@@ -617,6 +623,11 @@ def model_fn_flux2(
|
||||
latents = torch.concat([latents, edit_latents], dim=1)
|
||||
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
|
||||
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
|
||||
if extra_text_embedding is not None:
|
||||
extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device)
|
||||
extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1])
|
||||
prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1)
|
||||
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
|
||||
model_output = dit(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
|
||||
Reference in New Issue
Block a user