support vram management in flux

This commit is contained in:
Artiprocher
2025-02-13 15:11:39 +08:00
parent 46d4616e23
commit 0699212665
8 changed files with 246 additions and 6 deletions

View File

@@ -11,6 +11,9 @@ from PIL import Image
from ..models.tiler import FastTileWorker
from transformers import SiglipVisionModel
from copy import deepcopy
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
from ..models.flux_dit import RMSNorm
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class FluxImagePipeline(BasePipeline):
@@ -31,6 +34,105 @@ class FluxImagePipeline(BasePipeline):
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
def enable_vram_management(self, num_persistent_param_in_dit=None):
dtype = next(iter(self.text_encoder_1.parameters())).dtype
enable_vram_management(
self.text_encoder_1,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
T5LayerNorm: AutoWrappedModule,
T5DenseActDense: AutoWrappedModule,
T5DenseGatedActDense: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cuda",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae_decoder.parameters())).dtype
enable_vram_management(
self.vae_decoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae_encoder.parameters())).dtype
enable_vram_management(
self.vae_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def denoising_model(self):
return self.dit
@@ -62,10 +164,10 @@ class FluxImagePipeline(BasePipeline):
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
pipe = FluxImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
return pipe