update doc

This commit is contained in:
Artiprocher
2025-11-05 16:08:01 +08:00
parent d27917ad41
commit 3afecc65fc
7 changed files with 336 additions and 8 deletions

View File

@@ -158,7 +158,7 @@ class AutoWrappedModule(AutoTorchModule):
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
else:
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
@@ -167,7 +167,7 @@ class AutoWrappedModule(AutoTorchModule):
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
else:
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
@@ -308,7 +308,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
else:
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
@@ -317,7 +317,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
else:
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2

View File

@@ -1,7 +1,7 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
import importlib, json
import importlib, json, torch
class ModelPool:
@@ -46,8 +46,23 @@ class ModelPool:
)
return model
def auto_load_model(self, path, vram_config, vram_limit=None):
# Build the fallback VRAM configuration as a new dict on every call.
# auto_load_model calls this when it is invoked with vram_config=None.
# Returns a dict with a dtype/device pair for each of the four stages:
# offload, onload, preparing and computation.
def default_vram_config(self):
vram_config = {
# Offload dtype/device are left as None.
# NOTE(review): presumably None means "no offload stage" - confirm against the code that consumes this dict.
"offload_dtype": None,
"offload_device": None,
# The remaining three stages all use torch.bfloat16 on the "cpu" device.
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cpu",
"computation_dtype": torch.bfloat16,
"computation_device": "cpu",
}
return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
model_hash = hash_model_file(path)
loaded = False
for config in MODEL_CONFIGS: