diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 96ee86a..c15d940 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -80,7 +80,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device): loaded_model_names, loaded_models = [], [] for model_name, model_class in zip(model_names, model_classes): - model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval() + if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]: + model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval() + else: + model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype) if torch_dtype == torch.float16 and hasattr(model, "half"): model = model.half() try: diff --git a/diffsynth/models/sd3_text_encoder.py b/diffsynth/models/sd3_text_encoder.py index cb9bdcd..efe29ca 100644 --- a/diffsynth/models/sd3_text_encoder.py +++ b/diffsynth/models/sd3_text_encoder.py @@ -9,7 +9,8 @@ class SD3TextEncoder1(SDTextEncoder): super().__init__(vocab_size=vocab_size) def forward(self, input_ids, clip_skip=2, extra_mask=None): - embeds = self.token_embedding(input_ids) + self.position_embeds + embeds = self.token_embedding(input_ids) + embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device) attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) if extra_mask is not None: attn_mask[:, extra_mask[0]==0] = float("-inf") diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 8b247d6..2a4f01c 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -101,12 +101,22 @@ class BasePipeline(torch.nn.Module): if model_name not in loadmodel_names: model = getattr(self, model_name) if model is not None: - model.cpu() + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + for module in model.modules(): + if hasattr(module, "offload"): + module.offload() + else: + model.cpu() # load the needed models to device for model_name in loadmodel_names: model = getattr(self, model_name) if model is not None: - model.to(self.device) + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + for module in model.modules(): + if hasattr(module, "onload"): + module.onload() + else: + model.to(self.device) # fresh the cuda cache torch.cuda.empty_cache() diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 8cd009f..7303dff 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -11,6 +11,9 @@ from PIL import Image from ..models.tiler import FastTileWorker from transformers import SiglipVisionModel from copy import deepcopy +from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense +from ..models.flux_dit import RMSNorm +from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear class FluxImagePipeline(BasePipeline): @@ -31,6 +34,105 @@ class FluxImagePipeline(BasePipeline): self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder'] + def enable_vram_management(self, num_persistent_param_in_dit=None): + dtype = next(iter(self.text_encoder_1.parameters())).dtype + enable_vram_management( + self.text_encoder_1, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Embedding: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.text_encoder_2.parameters())).dtype + enable_vram_management( + self.text_encoder_2, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Embedding: AutoWrappedModule, + T5LayerNorm: AutoWrappedModule, + T5DenseActDense: AutoWrappedModule, + T5DenseGatedActDense: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.dit.parameters())).dtype + enable_vram_management( + self.dit, + module_map = { + RMSNorm: AutoWrappedModule, + torch.nn.Linear: AutoWrappedLinear, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cuda", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.vae_decoder.parameters())).dtype + enable_vram_management( + self.vae_decoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + torch.nn.GroupNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.vae_encoder.parameters())).dtype + enable_vram_management( + self.vae_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + torch.nn.GroupNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + self.enable_cpu_offload() + + def denoising_model(self): return self.dit @@ -62,10 +164,10 @@ class FluxImagePipeline(BasePipeline): @staticmethod - def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None): + def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None): pipe = FluxImagePipeline( device=model_manager.device if device is None else device, - torch_dtype=model_manager.torch_dtype, + torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype, ) pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes) return pipe diff --git a/diffsynth/vram_management/__init__.py b/diffsynth/vram_management/__init__.py new file mode 100644 index 0000000..69a388d --- /dev/null +++ b/diffsynth/vram_management/__init__.py @@ -0,0 +1 @@ +from .layers import * diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py new file mode 100644 index 0000000..a9df39e --- /dev/null +++ b/diffsynth/vram_management/layers.py @@ -0,0 +1,95 @@ +import torch, copy +from ..models.utils import init_weights_on_device + + +def cast_to(weight, dtype, device): + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + +class AutoWrappedModule(torch.nn.Module): + def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): + super().__init__() + self.module = module.to(dtype=offload_dtype, device=offload_device) + self.offload_dtype = offload_dtype + self.offload_device = offload_device + self.onload_dtype = onload_dtype + self.onload_device = onload_device + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.state = 0 + + def offload(self): + if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.module.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.module.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def forward(self, *args, **kwargs): + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + module = self.module + else: + module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) + return module(*args, **kwargs) + + +class AutoWrappedLinear(torch.nn.Linear): + def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): + with init_weights_on_device(device=torch.device("meta")): + super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) + self.weight = module.weight + self.bias = module.bias + self.offload_dtype = offload_dtype + self.offload_device = offload_device + self.onload_dtype = onload_dtype + self.onload_device = onload_device + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.state = 0 + + def offload(self): + if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def forward(self, x, *args, **kwargs): + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + weight, bias = self.weight, self.bias + else: + weight = cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) + return torch.nn.functional.linear(x, weight, bias) + + +def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): + for name, module in model.named_children(): + for source_module, target_module in module_map.items(): + if isinstance(module, source_module): + num_param = sum(p.numel() for p in module.parameters()) + if max_num_param is not None and total_num_param + num_param > max_num_param: + module_config_ = overflow_module_config + else: + module_config_ = module_config + module_ = target_module(module, **module_config_) + setattr(model, name, module_) + total_num_param += num_param + break + else: + total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param) + return total_num_param + + +def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None): + enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0) + model.vram_management_enabled = True + diff --git a/examples/vram_management/README.md b/examples/vram_management/README.md new file mode 100644 index 0000000..98f9676 --- /dev/null +++ b/examples/vram_management/README.md @@ -0,0 +1,3 @@ +# VRAM Management + +Experimental feature. Still under development. diff --git a/examples/vram_management/flux_text_to_image.py b/examples/vram_management/flux_text_to_image.py new file mode 100644 index 0000000..ea24106 --- /dev/null +++ b/examples/vram_management/flux_text_to_image.py @@ -0,0 +1,25 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline + + +model_manager = ModelManager( + file_path_list=[ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/flux1-dev.safetensors", + "models/FLUX/FLUX.1-dev/ae.safetensors", + ], + torch_dtype=torch.float8_e4m3fn, + device="cpu" +) +pipe = FluxImagePipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") + +# Enable VRAM management +# `num_persistent_param_in_dit` indicates the number of parameters that reside persistently in VRAM within the DiT model. +# When `num_persistent_param_in_dit=None`, it means all parameters reside persistently in memory. +# When `num_persistent_param_in_dit=7*10**9`, it indicates that 7 billion parameters reside persistently in memory. +# When `num_persistent_param_in_dit=0`, it means no parameters reside persistently in memory, and they are loaded layer by layer during inference. +pipe.enable_vram_management(num_persistent_param_in_dit=None) + +image = pipe(prompt="a beautiful orange cat", seed=0) +image.save("image.jpg")