Support VRAM management in FLUX

This commit is contained in:
Artiprocher
2025-02-13 15:11:39 +08:00
parent 46d4616e23
commit 0699212665
8 changed files with 246 additions and 6 deletions

View File

@@ -101,12 +101,22 @@ class BasePipeline(torch.nn.Module):
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.cpu()
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.to(self.device)
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
# empty the CUDA cache
torch.cuda.empty_cache()