mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
support vram management in flux
This commit is contained in:
@@ -101,12 +101,22 @@ class BasePipeline(torch.nn.Module):
|
||||
if model_name not in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.cpu()
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
else:
|
||||
model.cpu()
|
||||
# load the needed models to device
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.to(self.device)
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
else:
|
||||
model.to(self.device)
|
||||
# fresh the cuda cache
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -11,6 +11,9 @@ from PIL import Image
|
||||
from ..models.tiler import FastTileWorker
|
||||
from transformers import SiglipVisionModel
|
||||
from copy import deepcopy
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
from ..models.flux_dit import RMSNorm
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
@@ -31,6 +34,105 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
T5LayerNorm: AutoWrappedModule,
|
||||
T5DenseActDense: AutoWrappedModule,
|
||||
T5DenseGatedActDense: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cuda",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_decoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
@@ -62,10 +164,10 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||
pipe = FluxImagePipeline(
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
|
||||
return pipe
|
||||
|
||||
Reference in New Issue
Block a user