diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py
index 5c6ec2c..cbbadcd 100644
--- a/diffsynth/configs/vram_management_module_maps.py
+++ b/diffsynth/configs/vram_management_module_maps.py
@@ -1,3 +1,14 @@
+flux_general_vram_config = {
+    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+    "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+    "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+    "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
+    "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+    "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+    "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
+    "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
+}
+
 VRAM_MANAGEMENT_MODULE_MAPS = {
     "diffsynth.models.qwen_image_dit.QwenImageDiT": {
         "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
@@ -115,4 +126,28 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
         "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
         "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
     },
+    "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
+    "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
+    "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
+    "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
+    "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
+    "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
+    "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
+    "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
+    "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
+    "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
+        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+        "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+        "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
+        "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
+    },
+    "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
+        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
+        "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
+        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
+        "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
+        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+    },
 }
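Note on the map above: both keys and values are dotted import paths rather than classes, so they have to be resolved at runtime before `enable_vram_management` can match them against module instances. A minimal sketch of how such a map can be materialized, assuming diffsynth is installed; the `resolve` helper is illustrative, not a diffsynth API:

```python
import importlib
from diffsynth.configs.vram_management_module_maps import flux_general_vram_config

def resolve(dotted_path: str):
    # "torch.nn.Linear" -> the torch.nn.Linear class itself.
    module_name, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# Hypothetical usage: turn the string-to-string map into a class-to-class map.
module_map = {resolve(k): resolve(v) for k, v in flux_general_vram_config.items()}
```
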
diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py
index 1ab0fbd..56fa7d3 100644
--- a/diffsynth/core/loader/model.py
+++ b/diffsynth/core/loader/model.py
@@ -17,11 +17,19 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
     if module_map is not None:
         devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
         device = [d for d in devices if d != "disk"][0]
-        disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
+        dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
+        dtype = [d for d in dtypes if d != "disk"][0]
         if vram_config["offload_device"] != "disk":
-            state_dict = {i: disk_map[i].to(vram_config["offload_dtype"]) for i in disk_map}
+            state_dict = DiskMap(path, device, torch_dtype=dtype)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
+            else:
+                state_dict = {i: state_dict[i] for i in state_dict}
             model.load_state_dict(state_dict, assign=True)
-            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
+            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
+        else:
+            disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
+            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
     else:
         # Why do we use `DiskMap`?
         # Sometimes a model file contains multiple models,
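The new dtype selection mirrors the existing device selection: both lists are scanned in the same priority order (offload, onload, preparing, computation) and the first entry that is not the `"disk"` sentinel wins. A small self-contained illustration; the config values here are hypothetical:

```python
import torch

# Hypothetical vram_config where offload lives on disk, so the "disk"
# sentinel must be skipped when picking the working device/dtype.
vram_config = {
    "offload_device": "disk",     "offload_dtype": "disk",
    "onload_device": "cpu",       "onload_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",   "preparing_dtype": torch.float8_e4m3fn,
    "computation_device": "cuda", "computation_dtype": torch.bfloat16,
}

devices = [vram_config[k] for k in ("offload_device", "onload_device", "preparing_device", "computation_device")]
device = [d for d in devices if d != "disk"][0]  # -> "cpu"

dtypes = [vram_config[k] for k in ("offload_dtype", "onload_dtype", "preparing_dtype", "computation_dtype")]
dtype = [d for d in dtypes if d != "disk"][0]    # -> torch.float8_e4m3fn
```

With that in hand, the two branches differ only in where the weights live afterwards: non-disk offload materializes the state dict once and passes `disk_map=None`, while disk offload keeps the `DiskMap` alive so layers can be streamed in on demand.
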
diff --git a/diffsynth/models/flux_value_control.py b/diffsynth/models/flux_value_control.py
index 691f9ca..549dbc9 100644
--- a/diffsynth/models/flux_value_control.py
+++ b/diffsynth/models/flux_value_control.py
@@ -30,12 +30,6 @@ class SingleValueEncoder(torch.nn.Module):
         self.positional_embedding = torch.nn.Parameter(
             torch.randn(self.prefer_len, dim_out)
         )
-        self._initialize_weights()
-
-    def _initialize_weights(self):
-        last_linear = self.prefer_value_embedder[-1]
-        torch.nn.init.zeros_(last_linear.weight)
-        torch.nn.init.zeros_(last_linear.bias)
 
     def forward(self, value, dtype):
         value = value * 1000
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index a198e06..d5fc30e 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -105,6 +105,8 @@ class FluxImagePipeline(BasePipeline):
         self.lora_loader = FluxLoRALoader
 
     def enable_lora_merger(self):
+        if not (hasattr(self.dit, "vram_management_enabled") and getattr(self.dit, "vram_management_enabled")):
+            raise ValueError("DiT VRAM management is not enabled.")
         if self.lora_patcher is not None:
             for name, module in self.dit.named_modules():
                 if isinstance(module, AutoWrappedLinear):
@@ -141,7 +143,9 @@ class FluxImagePipeline(BasePipeline):
         pipe.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_config.path)
 
         value_controllers = model_pool.fetch_model("flux_value_controller")
-        if value_controllers is not None: pipe.value_controller = MultiValueEncoder(value_controllers)
+        if value_controllers is not None:
+            pipe.value_controller = MultiValueEncoder(value_controllers)
+            pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled
         controlnets = model_pool.fetch_model("flux_controlnet")
         if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets)
         pipe.ipadapter = model_pool.fetch_model("flux_ipadapter")
diff --git a/diffsynth/utils/lora/merge.py b/diffsynth/utils/lora/merge.py
index d75ff90..61904ff 100644
--- a/diffsynth/utils/lora/merge.py
+++ b/diffsynth/utils/lora/merge.py
@@ -8,13 +8,13 @@ def merge_lora_weight(tensors_A, tensors_B):
     return lora_A, lora_B
 
 
-def merge_lora(loras: List[Dict[str, torch.Tensor]]):
+def merge_lora(loras: List[Dict[str, torch.Tensor]], alpha=1):
     lora_merged = {}
     keys = [i for i in loras[0].keys() if ".lora_A." in i]
     for key in keys:
         tensors_A = [lora[key] for lora in loras]
         tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras]
         lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
-        lora_merged[key] = lora_A
+        lora_merged[key] = lora_A * alpha
        lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B
     return lora_merged
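Since only `lora_A` is scaled, the effective delta `lora_B @ lora_A` scales by the same factor, so `alpha` acts as a global strength knob for the fused LoRA. A hedged usage example; the key names and shapes are made up, only `merge_lora`'s signature comes from the diff:

```python
import torch
from diffsynth.utils.lora.merge import merge_lora

# Two hypothetical LoRA state dicts with matching lora_A/lora_B key layouts.
lora_1 = {
    "blocks.0.attn.to_q.lora_A.weight": torch.randn(16, 3072),
    "blocks.0.attn.to_q.lora_B.weight": torch.randn(3072, 16),
}
lora_2 = {
    "blocks.0.attn.to_q.lora_A.weight": torch.randn(16, 3072),
    "blocks.0.attn.to_q.lora_B.weight": torch.randn(3072, 16),
}

# alpha=0.5 halves the strength of the fused LoRA.
merged = merge_lora([lora_1, lora_2], alpha=0.5)
```
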
"preparing_dtype": torch.float8_e4m3fn, diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py index 5603858..148e7ef 100644 --- a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py @@ -5,8 +5,8 @@ from modelscope import snapshot_download vram_config = { - "onload_dtype": torch.float8_e4m3fn, - "onload_device": "cpu", + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", "onload_dtype": torch.float8_e4m3fn, "onload_device": "cpu", "preparing_dtype": torch.float8_e4m3fn, diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py index f6a95fe..ca7c72c 100644 --- a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py @@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, Contr vram_config = { - "onload_dtype": torch.float8_e4m3fn, - "onload_device": "cpu", + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", "onload_dtype": torch.float8_e4m3fn, "onload_device": "cpu", "preparing_dtype": torch.float8_e4m3fn, diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py index ef2b5c4..da0d7ca 100644 --- a/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py @@ -6,8 +6,8 @@ from modelscope import dataset_snapshot_download vram_config = { - "onload_dtype": torch.float8_e4m3fn, - "onload_device": "cpu", + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", "onload_dtype": torch.float8_e4m3fn, "onload_device": "cpu", "preparing_dtype": torch.float8_e4m3fn, diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py index ac7ed8e..59f2e9f 100644 --- a/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py @@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig vram_config = { - "onload_dtype": torch.float8_e4m3fn, - "onload_device": "cpu", + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", "onload_dtype": torch.float8_e4m3fn, "onload_device": "cpu", "preparing_dtype": torch.float8_e4m3fn, diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py index c3be51c..119856a 100644 --- a/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py @@ -10,8 +10,8 @@ import numpy as np # Please install the following packages. 
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py b/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py
index 1922b24..2994a33 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py
@@ -4,8 +4,8 @@ from PIL import Image
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py b/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py
index 077f6d2..2ceb064 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py
@@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-AttriCtrl.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-AttriCtrl.py
index c4d94cd..e0226ba 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-AttriCtrl.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-AttriCtrl.py
@@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py
index 93b410c..61ac25f 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py
@@ -5,8 +5,8 @@ from PIL import Image
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py
index 5603858..148e7ef 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py
@@ -5,8 +5,8 @@ from modelscope import snapshot_download
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py
index f6a95fe..ca7c72c 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py
@@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, Contr
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py
index ef2b5c4..da0d7ca 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py
@@ -6,8 +6,8 @@ from modelscope import dataset_snapshot_download
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py
index ac7ed8e..59f2e9f 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py
@@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py
index c3be51c..119856a 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py
@@ -10,8 +10,8 @@ import numpy as np
 # Please install the following packages.
 # pip install facexlib insightface onnxruntime
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py
index 1fd3502..5928af0 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py
@@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Fusion.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Fusion.py
index 5371830..ce587cd 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Fusion.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Fusion.py
@@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev.py b/examples/flux/model_inference_low_vram/FLUX.1-dev.py
index adb168c..ffaf181 100644
--- a/examples/flux/model_inference_low_vram/FLUX.1-dev.py
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev.py
@@ -3,8 +3,8 @@ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py
index 7c2be84..1b3050f 100644
--- a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py
+++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py
@@ -13,8 +13,8 @@ else:
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py
index 05c2caf..8372fcb 100644
--- a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py
+++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py
@@ -11,8 +11,8 @@ else:
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/flux/model_inference_low_vram/Step1X-Edit.py b/examples/flux/model_inference_low_vram/Step1X-Edit.py
index 47c1535..9a3bde8 100644
--- a/examples/flux/model_inference_low_vram/Step1X-Edit.py
+++ b/examples/flux/model_inference_low_vram/Step1X-Edit.py
@@ -5,8 +5,8 @@ import numpy as np
 
 
 vram_config = {
-    "onload_dtype": torch.float8_e4m3fn,
-    "onload_device": "cpu",
+    "offload_dtype": torch.float8_e4m3fn,
+    "offload_device": "cpu",
     "onload_dtype": torch.float8_e4m3fn,
     "onload_device": "cpu",
     "preparing_dtype": torch.float8_e4m3fn,
diff --git a/examples/test/run.py b/examples/test/run.py
index 4e66ba1..93d4488 100644
--- a/examples/test/run.py
+++ b/examples/test/run.py
@@ -1,4 +1,5 @@
 import os, shutil, multiprocessing, time
+NUM_GPUS = 7
 
 
 def script_is_processed(output_path, script):
@@ -63,7 +64,7 @@ def run_train_multi_GPU(script_path):
 
 def run_train_single_GPU(script_path):
     tasks = filter_unprocessed_tasks(script_path)
-    processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, tasks, i, 8)) for i in range(8)]
+    processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, tasks, i, NUM_GPUS)) for i in range(NUM_GPUS)]
     for p in processes:
         p.start()
     for p in processes:
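One more note on run.py: the worker count is now a single `NUM_GPUS` constant instead of a hard-coded 8. The sketch below shows how a worker with `run_tasks_on_single_GPU`'s calling convention might shard the task list; the strided split and the function body are assumptions, only the `(script_path, tasks, index, count)` argument order comes from the diff:

```python
import multiprocessing

def run_tasks_demo(script_path, tasks, worker_id, num_workers):
    # Hypothetical sharding: worker i takes every num_workers-th task.
    for task in tasks[worker_id::num_workers]:
        print(f"[worker {worker_id}] would run {task} from {script_path}")

if __name__ == "__main__":
    NUM_GPUS = 7
    tasks = [f"script_{i}.py" for i in range(20)]
    processes = [multiprocessing.Process(target=run_tasks_demo, args=("examples/test", tasks, i, NUM_GPUS)) for i in range(NUM_GPUS)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```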