mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-15 14:48:23 +00:00
update template framework
This commit is contained in:
@@ -3,7 +3,7 @@ import torch
|
||||
import numpy as np
|
||||
from einops import repeat, reduce
|
||||
from typing import Union
|
||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..utils.lora import GeneralLoRALoader
|
||||
from ..models.model_loader import ModelPool
|
||||
@@ -354,6 +354,23 @@ class BasePipeline(torch.nn.Module):
|
||||
self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
self.template_data_processor = load_template_data_processor(model_config.path)()
|
||||
|
||||
|
||||
def enable_lora_hot_loading(self, model: torch.nn.Module) -> torch.nn.Module:
    """Wrap the model's ``torch.nn.Linear`` layers so LoRA weights can be hot-loaded.

    Routes every ``torch.nn.Linear`` through ``AutoWrappedLinear`` via
    ``enable_vram_management``. All dtype/device slots in the VRAM config are
    pinned to the pipeline's own ``self.torch_dtype`` / ``self.device``, so no
    actual offloading happens — the wrapping exists only to make linear weights
    swappable at runtime (LoRA hot loading).

    Args:
        model: The module to prepare. If VRAM management was already enabled on
            it (truthy ``vram_management_enabled`` attribute), it is returned
            unchanged to avoid double-wrapping.

    Returns:
        The (possibly wrapped) model.
    """
    # Idiomatic single-lookup form of `hasattr(...) and getattr(...)`.
    if getattr(model, "vram_management_enabled", False):
        return model
    module_map = {torch.nn.Linear: AutoWrappedLinear}
    # Every stage uses the same dtype/device; build the 8-key config in the
    # same insertion order as the explicit literal it replaces.
    vram_config = {}
    for stage in ("offload", "onload", "preparing", "computation"):
        vram_config[f"{stage}_dtype"] = self.torch_dtype
        vram_config[f"{stage}_device"] = self.device
    model = enable_vram_management(model, module_map, vram_config=vram_config)
    return model
|
||||
|
||||
|
||||
class PipelineUnitGraph:
|
||||
|
||||
Reference in New Issue
Block a user