From 9cb887015be703f9344542efe8f6c723d14a29f9 Mon Sep 17 00:00:00 2001
From: "lzw478614@alibaba-inc.com" <lzw478614@alibaba-inc.com>
Date: Wed, 2 Jul 2025 13:32:24 +0800
Subject: [PATCH] lora hotload and merge

---
 diffsynth/lora/flux_lora.py           | 60 ++++++++++++++++++++++++++-
 diffsynth/pipelines/flux_image_new.py | 42 ++++++++++++++++++-
 diffsynth/vram_management/layers.py   | 15 ++++++-
 3 files changed, 114 insertions(+), 3 deletions(-)

diff --git a/diffsynth/lora/flux_lora.py b/diffsynth/lora/flux_lora.py
index 899160f..b0f17aa 100644
--- a/diffsynth/lora/flux_lora.py
+++ b/diffsynth/lora/flux_lora.py
@@ -10,4 +10,62 @@ class FluxLoRALoader(GeneralLoRALoader):
 
     def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
         lora_prefix, model_resource = self.loader.match(model, state_dict_lora)
-        self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
\ No newline at end of file
+        self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
+
+class LoraMerger(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
+        self.bias = torch.nn.Parameter(torch.randn((dim,)))
+        self.activation = torch.nn.Sigmoid()
+        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
+        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
+
+    def forward(self, base_output, lora_outputs):
+        norm_base_output = self.norm_base(base_output)
+        norm_lora_outputs = self.norm_lora(lora_outputs)
+        gate = self.activation(
+            norm_base_output * self.weight_base
+            + norm_lora_outputs * self.weight_lora
+            + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
+        )
+        output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
+        return output
+
+class LoraPatcher(torch.nn.Module):
+    def __init__(self, lora_patterns=None):
+        super().__init__()
+        if lora_patterns is None:
+            lora_patterns = self.default_lora_patterns()
+        model_dict = {}
+        for lora_pattern in lora_patterns:
+            name, dim = lora_pattern["name"], lora_pattern["dim"]
+            model_dict[name.replace(".", "___")] = LoraMerger(dim)
+        self.model_dict = torch.nn.ModuleDict(model_dict)
+
+    def default_lora_patterns(self):
+        lora_patterns = []
+        lora_dict = {
+            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
+            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
+        }
+        for i in range(19):
+            for suffix in lora_dict:
+                lora_patterns.append({
+                    "name": f"blocks.{i}.{suffix}",
+                    "dim": lora_dict[suffix]
+                })
+        lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
+        for i in range(38):
+            for suffix in lora_dict:
+                lora_patterns.append({
+                    "name": f"single_blocks.{i}.{suffix}",
+                    "dim": lora_dict[suffix]
+                })
+        return lora_patterns
+
+    def forward(self, base_output, lora_outputs, name):
+        return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py
index fe651f9..9c07f89 100644
--- a/diffsynth/pipelines/flux_image_new.py
+++ b/diffsynth/pipelines/flux_image_new.py
@@ -21,7 +21,8 @@ from ..models.flux_ipadapter import FluxIpAdapter
 from ..models.flux_infiniteyou import InfiniteYouImageProjector
 from ..models.tiler import FastTileWorker
 from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
-from ..lora.flux_lora import FluxLoRALoader
+from ..lora.flux_lora import FluxLoRALoader, LoraPatcher
+from ..models.lora import FluxLoRAConverter
 from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
 from ..models.flux_dit import RMSNorm
 
@@ -121,6 +122,45 @@ class FluxImagePipeline(BasePipeline):
         lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
         loader.load(module, lora, alpha=alpha)
 
+    def enable_lora_hotload(self, lora_paths):
+        # load lora state dict and align format
+        lora_state_dicts = [
+            FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)) for path in lora_paths
+        ]
+        lora_state_dicts = [l for l in lora_state_dicts if l != {}]
+
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                lora_a_name = f'{name}.lora_A.default.weight'
+                lora_b_name = f'{name}.lora_B.default.weight'
+                lora_A_weights = []
+                lora_B_weights = []
+                for lora_dict in lora_state_dicts:
+                    if lora_a_name in lora_dict and lora_b_name in lora_dict:
+                        lora_A_weights.append(lora_dict[lora_a_name])
+                        lora_B_weights.append(lora_dict[lora_b_name])
+                module.lora_A_weights = lora_A_weights
+                module.lora_B_weights = lora_B_weights
+
+
+    def enable_lora_patcher(self, lora_patcher_path):
+        # load lora patcher
+        lora_patcher = LoraPatcher().to(dtype=self.torch_dtype, device=self.device)
+        lora_patcher.load_state_dict(load_state_dict(lora_patcher_path))
+
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                merger_name = name.replace(".", "___")
+                if merger_name in lora_patcher.model_dict:
+                    module.lora_merger = lora_patcher.model_dict[merger_name]
+
+
+    def off_lora_hotload(self):
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module.lora_A_weights = []
+                module.lora_B_weights = []
+
     def training_loss(self, **inputs):
         timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
         timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py
index c0beaf8..4dfec12 100644
--- a/diffsynth/vram_management/layers.py
+++ b/diffsynth/vram_management/layers.py
@@ -107,6 +107,9 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         self.vram_limit = vram_limit
         self.state = 0
         self.name = name
+        self.lora_A_weights = []
+        self.lora_B_weights = []
+        self.lora_merger = None
 
     def forward(self, x, *args, **kwargs):
         if self.state == 2:
@@ -120,7 +123,17 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         else:
             weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
             bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
-        return torch.nn.functional.linear(x, weight, bias)
+        out = torch.nn.functional.linear(x, weight, bias)
+        lora_output = []
+        for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+            out_lora = x @ lora_A.T @ lora_B.T
+            if self.lora_merger is None:
+                out = out + out_lora
+            lora_output.append(out_lora)
+        if self.lora_merger is not None and len(lora_output) > 0:
+            lora_output = torch.stack(lora_output)
+            out = self.lora_merger(out, lora_output)
+        return out
 
 
 def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):