mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 08:40:47 +00:00
Merge pull request #324 from modelscope/vram_management
support vram management in flux
This commit is contained in:
@@ -80,7 +80,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
|
|||||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
||||||
loaded_model_names, loaded_models = [], []
|
loaded_model_names, loaded_models = [], []
|
||||||
for model_name, model_class in zip(model_names, model_classes):
|
for model_name, model_class in zip(model_names, model_classes):
|
||||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||||
|
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||||
|
else:
|
||||||
|
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
||||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ class SD3TextEncoder1(SDTextEncoder):
|
|||||||
super().__init__(vocab_size=vocab_size)
|
super().__init__(vocab_size=vocab_size)
|
||||||
|
|
||||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
embeds = self.token_embedding(input_ids)
|
||||||
|
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
|
||||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||||
if extra_mask is not None:
|
if extra_mask is not None:
|
||||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||||
|
|||||||
@@ -101,12 +101,22 @@ class BasePipeline(torch.nn.Module):
|
|||||||
if model_name not in loadmodel_names:
|
if model_name not in loadmodel_names:
|
||||||
model = getattr(self, model_name)
|
model = getattr(self, model_name)
|
||||||
if model is not None:
|
if model is not None:
|
||||||
model.cpu()
|
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||||
|
for module in model.modules():
|
||||||
|
if hasattr(module, "offload"):
|
||||||
|
module.offload()
|
||||||
|
else:
|
||||||
|
model.cpu()
|
||||||
# load the needed models to device
|
# load the needed models to device
|
||||||
for model_name in loadmodel_names:
|
for model_name in loadmodel_names:
|
||||||
model = getattr(self, model_name)
|
model = getattr(self, model_name)
|
||||||
if model is not None:
|
if model is not None:
|
||||||
model.to(self.device)
|
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||||
|
for module in model.modules():
|
||||||
|
if hasattr(module, "onload"):
|
||||||
|
module.onload()
|
||||||
|
else:
|
||||||
|
model.to(self.device)
|
||||||
# fresh the cuda cache
|
# fresh the cuda cache
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ from PIL import Image
|
|||||||
from ..models.tiler import FastTileWorker
|
from ..models.tiler import FastTileWorker
|
||||||
from transformers import SiglipVisionModel
|
from transformers import SiglipVisionModel
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||||
|
from ..models.flux_dit import RMSNorm
|
||||||
|
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||||
|
|
||||||
|
|
||||||
class FluxImagePipeline(BasePipeline):
|
class FluxImagePipeline(BasePipeline):
|
||||||
@@ -31,6 +34,105 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
|
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.text_encoder_1,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Embedding: AutoWrappedModule,
|
||||||
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.text_encoder_2,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Embedding: AutoWrappedModule,
|
||||||
|
T5LayerNorm: AutoWrappedModule,
|
||||||
|
T5DenseActDense: AutoWrappedModule,
|
||||||
|
T5DenseGatedActDense: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
dtype = next(iter(self.dit.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.dit,
|
||||||
|
module_map = {
|
||||||
|
RMSNorm: AutoWrappedModule,
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cuda",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
max_num_param=num_persistent_param_in_dit,
|
||||||
|
overflow_module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.vae_decoder,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
|
torch.nn.GroupNorm: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.vae_encoder,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
|
torch.nn.GroupNorm: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.enable_cpu_offload()
|
||||||
|
|
||||||
|
|
||||||
def denoising_model(self):
|
def denoising_model(self):
|
||||||
return self.dit
|
return self.dit
|
||||||
|
|
||||||
@@ -62,10 +164,10 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
|
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||||
pipe = FluxImagePipeline(
|
pipe = FluxImagePipeline(
|
||||||
device=model_manager.device if device is None else device,
|
device=model_manager.device if device is None else device,
|
||||||
torch_dtype=model_manager.torch_dtype,
|
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
|
||||||
)
|
)
|
||||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
|
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
|
||||||
return pipe
|
return pipe
|
||||||
|
|||||||
1
diffsynth/vram_management/__init__.py
Normal file
1
diffsynth/vram_management/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .layers import *
|
||||||
95
diffsynth/vram_management/layers.py
Normal file
95
diffsynth/vram_management/layers.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import torch, copy
|
||||||
|
from ..models.utils import init_weights_on_device
|
||||||
|
|
||||||
|
|
||||||
|
def cast_to(weight, dtype, device):
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight)
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedModule(torch.nn.Module):
|
||||||
|
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||||
|
super().__init__()
|
||||||
|
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
||||||
|
self.offload_dtype = offload_dtype
|
||||||
|
self.offload_device = offload_device
|
||||||
|
self.onload_dtype = onload_dtype
|
||||||
|
self.onload_device = onload_device
|
||||||
|
self.computation_dtype = computation_dtype
|
||||||
|
self.computation_device = computation_device
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||||
|
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||||
|
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||||
|
module = self.module
|
||||||
|
else:
|
||||||
|
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
||||||
|
return module(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedLinear(torch.nn.Linear):
|
||||||
|
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||||
|
with init_weights_on_device(device=torch.device("meta")):
|
||||||
|
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||||
|
self.weight = module.weight
|
||||||
|
self.bias = module.bias
|
||||||
|
self.offload_dtype = offload_dtype
|
||||||
|
self.offload_device = offload_device
|
||||||
|
self.onload_dtype = onload_dtype
|
||||||
|
self.onload_device = onload_device
|
||||||
|
self.computation_dtype = computation_dtype
|
||||||
|
self.computation_device = computation_device
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||||
|
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||||
|
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||||
|
weight, bias = self.weight, self.bias
|
||||||
|
else:
|
||||||
|
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||||
|
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||||
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
||||||
|
for name, module in model.named_children():
|
||||||
|
for source_module, target_module in module_map.items():
|
||||||
|
if isinstance(module, source_module):
|
||||||
|
num_param = sum(p.numel() for p in module.parameters())
|
||||||
|
if max_num_param is not None and total_num_param + num_param > max_num_param:
|
||||||
|
module_config_ = overflow_module_config
|
||||||
|
else:
|
||||||
|
module_config_ = module_config
|
||||||
|
module_ = target_module(module, **module_config_)
|
||||||
|
setattr(model, name, module_)
|
||||||
|
total_num_param += num_param
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
|
||||||
|
return total_num_param
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
|
||||||
|
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
|
||||||
|
model.vram_management_enabled = True
|
||||||
|
|
||||||
3
examples/vram_management/README.md
Normal file
3
examples/vram_management/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# VRAM Management
|
||||||
|
|
||||||
|
Experimental feature. Still under development.
|
||||||
25
examples/vram_management/flux_text_to_image.py
Normal file
25
examples/vram_management/flux_text_to_image.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, FluxImagePipeline
|
||||||
|
|
||||||
|
|
||||||
|
model_manager = ModelManager(
|
||||||
|
file_path_list=[
|
||||||
|
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||||
|
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||||
|
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
|
||||||
|
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||||
|
],
|
||||||
|
torch_dtype=torch.float8_e4m3fn,
|
||||||
|
device="cpu"
|
||||||
|
)
|
||||||
|
pipe = FluxImagePipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||||
|
|
||||||
|
# Enable VRAM management
|
||||||
|
# `num_persistent_param_in_dit` indicates the number of parameters that reside persistently in VRAM within the DiT model.
|
||||||
|
# When `num_persistent_param_in_dit=None`, it means all parameters reside persistently in memory.
|
||||||
|
# When `num_persistent_param_in_dit=7*10**9`, it indicates that 7 billion parameters reside persistently in memory.
|
||||||
|
# When `num_persistent_param_in_dit=0`, it means no parameters reside persistently in memory, and they are loaded layer by layer during inference.
|
||||||
|
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||||
|
|
||||||
|
image = pipe(prompt="a beautiful orange cat", seed=0)
|
||||||
|
image.save("image.jpg")
|
||||||
Reference in New Issue
Block a user