mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-15 14:48:23 +00:00
update template framework
This commit is contained in:
@@ -2,6 +2,7 @@ import torch, os, importlib, warnings, json, inspect
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from ..core import ModelConfig, load_model
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..utils.lora.merge import merge_lora
|
||||
|
||||
|
||||
# Type alias for a model's cached tensors: maps an entry name to a pair of
# tensors (presumably attention key/value states, per the "KV cache" name —
# confirm against the model implementations that populate it).
KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
|
||||
@@ -74,9 +75,28 @@ def load_template_data_processor(path):
|
||||
|
||||
|
||||
class TemplatePipeline(torch.nn.Module):
|
||||
def __init__(
    self,
    torch_dtype: torch.dtype = torch.bfloat16,
    device: Union[str, torch.device] = get_device_type(),
    model_configs: list[ModelConfig] = None,
    lazy_loading: bool = False,
):
    """Build a pipeline over the models described by `model_configs`.

    Args:
        torch_dtype: dtype every model is loaded with.
        device: target device for the loaded models (defaults to the
            detected device type; note the default is evaluated once,
            at import time).
        model_configs: configs of the models to manage. Defaults to an
            empty list (passed as ``None`` to avoid the shared
            mutable-default pitfall).
        lazy_loading: when True, skip loading here; models are loaded
            on demand by ``fetch_model``.
    """
    # NOTE(review): removed a stale duplicate signature
    # (`def __init__(self, models)`) and a stale
    # `self.models = torch.nn.ModuleList(models)` line left over from a
    # merge — it referenced an undefined name and was overwritten below.
    super().__init__()
    self.torch_dtype = torch_dtype
    self.device = device
    # Normalize the None sentinel so attribute consumers always see a list.
    self.model_configs = [] if model_configs is None else model_configs
    self.lazy_loading = lazy_loading
    if lazy_loading:
        # Defer all loading; fetch_model() loads weights per request.
        self.models = None
    else:
        models = []
        for model_config in self.model_configs:
            TemplatePipeline.check_vram_config(model_config)
            model_config.download_if_necessary()
            model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
            models.append(model)
        self.models = torch.nn.ModuleList(models)
|
||||
|
||||
def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
|
||||
names = {}
|
||||
@@ -100,6 +120,8 @@ class TemplatePipeline(torch.nn.Module):
|
||||
data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
|
||||
if param == "kv_cache":
|
||||
data = self.merge_kv_cache(data)
|
||||
elif param == "lora":
|
||||
data = merge_lora(data)
|
||||
elif len(data) == 1:
|
||||
data = data[0]
|
||||
else:
|
||||
@@ -125,30 +147,32 @@ class TemplatePipeline(torch.nn.Module):
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
lazy_loading: bool = False,
|
||||
):
|
||||
models = []
|
||||
for model_config in model_configs:
|
||||
TemplatePipeline.check_vram_config(model_config)
|
||||
model_config.download_if_necessary()
|
||||
model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
|
||||
models.append(model)
|
||||
pipe = TemplatePipeline(models)
|
||||
pipe = TemplatePipeline(torch_dtype, device, model_configs, lazy_loading)
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
def process_inputs(self, inputs: List[Dict], pipe=None, **kwargs):
    """Preprocess each raw input dict with the model that owns it.

    Each input may carry a "model_id" key (defaults to 0). Returns a list
    of (model_id, processed_inputs) pairs, preserving input order.
    """
    processed = []
    for raw in inputs:
        owner_id = raw.get("model_id", 0)
        owner = self.models[owner_id]
        processed.append((owner_id, owner.process_inputs(pipe=pipe, **raw)))
    return processed
|
||||
|
||||
def forward(self, inputs: List[Tuple[int, Dict]], pipe=None, **kwargs):
    """Run every (model_id, model_inputs) pair through its model.

    Returns the per-input caches in the same order as `inputs`.
    """
    return [
        self.models[model_id](pipe=pipe, **model_inputs)
        for model_id, model_inputs in inputs
    ]
|
||||
def fetch_model(self, model_id):
    """Return the model for `model_id`.

    In lazy mode the model is (re)loaded from its ModelConfig on every
    call; otherwise the preloaded module is returned directly.
    """
    # Fast path: models were preloaded in the constructor.
    if not self.lazy_loading:
        return self.models[model_id]
    config = self.model_configs[model_id]
    config.download_if_necessary()
    return load_template_model(config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
def call_single_side(self, pipe=None, inputs: List[Dict] = None):
    """Process a batch of raw input dicts and return the merged cache.

    Each input dict may carry a "model_id" key (defaults to 0). The model
    is fetched once per contiguous run of the same id — this matters in
    lazy-loading mode, where fetch_model reloads weights each call. Every
    input is preprocessed and forwarded by its model; the per-input caches
    are merged via merge_template_cache.

    Args:
        pipe: optional pipeline context forwarded to the models.
        inputs: list of raw input dicts; None is treated as empty.

    Returns:
        The merged template cache.
    """
    # NOTE(review): removed two stale removed-diff lines that reassigned
    # `inputs` to (model_id, dict) tuples — which would break
    # `item.get("model_id")` below — and computed a cache that was
    # immediately discarded.
    if inputs is None:
        # Original would crash iterating None; treat as an empty batch.
        inputs = []
    model = None
    onload_model_id = -1  # id of the currently fetched model
    template_cache = []
    for item in inputs:
        model_id = item.get("model_id", 0)
        if model_id != onload_model_id:
            # Only re-fetch when the owning model changes.
            model = self.fetch_model(model_id)
            onload_model_id = model_id
        cache = model.process_inputs(pipe=pipe, **item)
        cache = model.forward(pipe=pipe, **cache)
        template_cache.append(cache)
    return self.merge_template_cache(template_cache)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user