update template framework

This commit is contained in:
Artiprocher
2026-04-15 14:07:51 +08:00
parent 9f8c352a15
commit 59b4bbb62c
7 changed files with 85 additions and 24 deletions

View File

@@ -3,7 +3,7 @@ import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
@@ -354,6 +354,23 @@ class BasePipeline(torch.nn.Module):
self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
self.template_data_processor = load_template_data_processor(model_config.path)()
def enable_lora_hot_loading(self, model: torch.nn.Module) -> torch.nn.Module:
    """Prepare `model` for LoRA hot loading by wrapping its Linear layers.

    Routes the model through `enable_vram_management` so that every
    `torch.nn.Linear` is replaced with an `AutoWrappedLinear`, which is the
    hook point the LoRA loader uses to patch weights at runtime.

    Args:
        model: The module to prepare. If VRAM management was already
            enabled on it, it is returned unchanged.

    Returns:
        The (possibly wrapped) model.
    """
    # Idempotence guard: a second wrap would nest AutoWrappedLinear layers.
    # Single getattr with a default replaces the hasattr+getattr pair.
    if getattr(model, "vram_management_enabled", False):
        return model
    module_map = {torch.nn.Linear: AutoWrappedLinear}
    # All stages share the pipeline's dtype/device: no actual offloading is
    # requested here — the wrapping itself is what enables hot loading.
    vram_config = {
        "offload_dtype": self.torch_dtype,
        "offload_device": self.device,
        "onload_dtype": self.torch_dtype,
        "onload_device": self.device,
        "preparing_dtype": self.torch_dtype,
        "preparing_device": self.device,
        "computation_dtype": self.torch_dtype,
        "computation_device": self.device,
    }
    model = enable_vram_management(model, module_map, vram_config=vram_config)
    return model
class PipelineUnitGraph: