diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index f39b87a..0713b7c 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -64,6 +64,8 @@ from ..models.wan_video_vace import VaceWanModel from ..models.step1x_connector import Qwen2Connector +from ..lora.flux_lora import FluxLoraPatcher + model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -144,6 +146,7 @@ model_loader_configs = [ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"), (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), + (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. 
import torch, math
from diffsynth.lora import GeneralLoRALoader
from diffsynth.models.lora import FluxLoRAFromCivitai  # NOTE(review): no longer used after the refactor; kept so other importers of this module keep working


class FluxLoRALoader(GeneralLoRALoader):
    """LoRA loader for FLUX models.

    Converts Kohya-style checkpoint keys (``lora_unet_double_blocks_*`` /
    ``lora_unet_single_blocks_*``) into this project's module paths, then
    delegates application of the weights to ``GeneralLoRALoader``.
    """

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        super().__init__(device=device, torch_dtype=torch_dtype)

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        # Same behavior as the generic loader; kept as an explicit FLUX entry point.
        super().load(model, state_dict_lora, alpha)

    def convert_state_dict(self, state_dict):
        """Rename Kohya/CivitAI FLUX LoRA keys to DiffSynth module paths.

        Keys matching the known patterns are renamed (with the block index
        substituted back in); all other keys pass through unchanged. When the
        checkpoint stores an ``alpha`` value, both A and B matrices are
        pre-scaled by ``sqrt(alpha / rank)`` so their product carries the full
        ``alpha / rank`` factor.

        Returns a new dict; the input ``state_dict`` is NOT modified.
        """
        # TODO: support other lora format
        rename_dict = {
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
        }

        def guess_block_id(name):
            # First all-digit "_"-separated token is taken to be the block
            # index; it is masked out so the key can be matched against
            # rename_dict. Assumes the index only appears as "_<i>_" — TODO
            # confirm this holds for all supported checkpoints.
            names = name.split("_")
            for i in names:
                if i.isdigit():
                    return i, name.replace(f"_{i}_", "_blockid_")
            return None, None

        def guess_alpha(state_dict):
            # Derive a single scale from the first ".alpha" entry found:
            # sqrt(alpha / rank), where rank is the lora_down row count.
            # Applied to both A and B so their product is scaled by alpha/rank.
            for name, param in state_dict.items():
                if ".alpha" in name:
                    name_ = name.replace(".alpha", ".lora_down.weight")
                    if name_ in state_dict:
                        lora_alpha = param.item() / state_dict[name_].shape[0]
                        lora_alpha = math.sqrt(lora_alpha)
                        return lora_alpha
            return 1

        alpha = guess_alpha(state_dict)
        state_dict_ = {}
        for name, param in state_dict.items():
            block_id, source_name = guess_block_id(name)
            if alpha != 1:
                # Out-of-place multiply: the original `param *= alpha` mutated
                # the caller's tensors in place (including pass-through
                # ".alpha" scalars), corrupting the source state dict.
                param = param * alpha
            if source_name in rename_dict:
                target_name = rename_dict[source_name]
                target_name = target_name.replace(".blockid.", f".{block_id}.")
                state_dict_[target_name] = param
            else:
                state_dict_[name] = param
        return state_dict_


class LoraMerger(torch.nn.Module):
    """Learned gate that fuses multiple LoRA outputs with a base output.

    ``lora_outputs`` is stacked on dim 0 (one slice per active LoRA). A
    sigmoid gate computed from layer-normalized base/LoRA activations weights
    each LoRA contribution before it is summed onto the base output.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
        self.bias = torch.nn.Parameter(torch.randn((dim,)))
        self.activation = torch.nn.Sigmoid()
        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)

    def forward(self, base_output, lora_outputs):
        # Gate mixes normalized base, normalized LoRA, and their interaction.
        norm_base_output = self.norm_base(base_output)
        norm_lora_outputs = self.norm_lora(lora_outputs)
        gate = self.activation(
            norm_base_output * self.weight_base \
            + norm_lora_outputs * self.weight_lora \
            + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
        )
        output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
        return output
class FluxLoraPatcher(torch.nn.Module):
    """Container of one ``LoraMerger`` per patched FLUX linear layer.

    Mergers are stored in a ``ModuleDict`` keyed by the target module path,
    with dots replaced by ``___`` to satisfy ModuleDict key restrictions.
    """

    def __init__(self, lora_patterns=None):
        super().__init__()
        patterns = self.default_lora_patterns() if lora_patterns is None else lora_patterns
        mergers = {
            pattern["name"].replace(".", "___"): LoraMerger(pattern["dim"])
            for pattern in patterns
        }
        self.model_dict = torch.nn.ModuleDict(mergers)

    def default_lora_patterns(self):
        """Enumerate the FLUX linear layers that receive a merger, together
        with their output dimensions (19 double blocks, 38 single blocks)."""
        double_dims = {
            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
        }
        single_dims = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
        patterns = [
            {"name": f"blocks.{block}.{suffix}", "dim": dim}
            for block in range(19)
            for suffix, dim in double_dims.items()
        ]
        patterns += [
            {"name": f"single_blocks.{block}.{suffix}", "dim": dim}
            for block in range(38)
            for suffix, dim in single_dims.items()
        ]
        return patterns

    def forward(self, base_output, lora_outputs, name):
        """Fuse ``base_output`` with stacked ``lora_outputs`` using the merger
        registered for module path ``name``."""
        key = name.replace(".", "___")
        return self.model_dict[key](base_output, lora_outputs)

    @staticmethod
    def state_dict_converter():
        return FluxLoraPatcherStateDictConverter()


class FluxLoraPatcherStateDictConverter:
    """Identity converter: CivitAI checkpoints already use our key names."""

    def from_civitai(self, state_dict):
        # Keys already match the patcher's parameter names; pass through.
        return state_dict
transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense from ..models.flux_dit import RMSNorm @@ -91,12 +91,13 @@ class FluxImagePipeline(BasePipeline): self.controlnet: MultiControlNet = None self.ipadapter: FluxIpAdapter = None self.ipadapter_image_encoder = None - self.unit_runner = PipelineUnitRunner() self.qwenvl = None self.step1x_connector: Qwen2Connector = None self.infinityou_processor: InfinitYou = None self.image_proj_model: InfiniteYouImageProjector = None - self.in_iteration_models = ("dit", "step1x_connector", "controlnet") + self.lora_patcher: FluxLoraPatcher = None + self.unit_runner = PipelineUnitRunner() + self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher") self.units = [ FluxImageUnit_ShapeChecker(), FluxImageUnit_NoiseInitializer(), @@ -116,10 +117,55 @@ class FluxImagePipeline(BasePipeline): self.model_fn = model_fn_flux_image - def load_lora(self, module, path, alpha=1): + def load_lora( + self, + module: torch.nn.Module, + lora_config: Union[ModelConfig, str], + alpha=1, + hotload=False, + local_model_path="./models", + skip_download=False + ): + if isinstance(lora_config, str): + lora_config = ModelConfig(path=lora_config) + else: + lora_config.download_if_necessary(local_model_path, skip_download=skip_download) loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device) - lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) - loader.load(module, lora, alpha=alpha) + lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device) + lora = loader.convert_state_dict(lora) + if hotload: + for name, module in module.named_modules(): + if isinstance(module, AutoWrappedLinear): + lora_a_name = f'{name}.lora_A.default.weight' + lora_b_name = f'{name}.lora_B.default.weight' + if lora_a_name in lora and lora_b_name in lora: + module.lora_A_weights.append(lora[lora_a_name] * alpha) + 
module.lora_B_weights.append(lora[lora_b_name]) + else: + loader.load(module, lora, alpha=alpha) + + + def enable_lora_patcher(self): + if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled): + print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.") + return + if self.lora_patcher is None: + print("Please load lora patcher models before `enable_lora_patcher()`.") + return + for name, module in self.dit.named_modules(): + if isinstance(module, AutoWrappedLinear): + merger_name = name.replace(".", "___") + if merger_name in self.lora_patcher.model_dict: + module.lora_merger = self.lora_patcher.model_dict[merger_name] + + + def clear_lora(self): + for name, module in self.named_modules(): + if isinstance(module, AutoWrappedLinear): + if hasattr(module, "lora_A_weights"): + module.lora_A_weights.clear() + if hasattr(module, "lora_B_weights"): + module.lora_B_weights.clear() def training_loss(self, **inputs): @@ -285,10 +331,10 @@ class FluxImagePipeline(BasePipeline): pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") pipe.qwenvl = model_manager.fetch_model("qwenvl") pipe.step1x_connector = model_manager.fetch_model("step1x_connector") - pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector") if pipe.image_proj_model is not None: pipe.infinityou_processor = InfinitYou(device=device) + pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher") # ControlNet controlnets = [] diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index c0beaf8..0ebb054 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -107,6 +107,9 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): self.vram_limit = vram_limit self.state = 0 self.name = name + self.lora_A_weights = [] + self.lora_B_weights = [] + self.lora_merger = None def forward(self, x, *args, **kwargs): if 
self.state == 2: @@ -120,7 +123,23 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): else: weight = cast_to(self.weight, self.computation_dtype, self.computation_device) bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) - return torch.nn.functional.linear(x, weight, bias) + out = torch.nn.functional.linear(x, weight, bias) + + if len(self.lora_A_weights) == 0: + # No LoRA + return out + elif self.lora_merger is None: + # Native LoRA inference + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + out = out + x @ lora_A.T @ lora_B.T + else: + # LoRA fusion + lora_output = [] + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + lora_output.append(x @ lora_A.T @ lora_B.T) + lora_output = torch.stack(lora_output) + out = self.lora_merger(out, lora_output) + return out def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""): diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py new file mode 100644 index 0000000..68116d0 --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py @@ -0,0 +1,35 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + 
ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-LoRAFusion", origin_file_pattern="model.safetensors") + ], +) +pipe.enable_vram_management() +pipe.enable_lora_patcher() +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="yangyufeng/fgao", origin_file_pattern="30.safetensors"), + hotload=True +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="bobooblue/LoRA-bling-mai", origin_file_pattern="10.safetensors"), + hotload=True +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="JIETANGAB/E", origin_file_pattern="17.safetensors"), + hotload=True +) + +image = pipe(prompt="This is a digital painting in a soft, ethereal style. a beautiful Asian girl Shine like a diamond. Everywhere is shining with bling bling luster.The background is a textured blue with visible brushstrokes, giving the image an impressionistic style reminiscent of Vincent van Gogh's work", seed=0) +image.save("flux.jpg")