support lora fusion

This commit is contained in:
Artiprocher
2025-07-03 18:49:46 +08:00
parent 9cb887015b
commit 8a9dbbd3ba
5 changed files with 175 additions and 54 deletions

View File

@@ -64,6 +64,8 @@ from ..models.wan_video_vace import VaceWanModel
from ..models.step1x_connector import Qwen2Connector from ..models.step1x_connector import Qwen2Connector
from ..lora.flux_lora import FluxLoraPatcher
model_loader_configs = [ model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.
@@ -144,6 +146,7 @@ model_loader_configs = [
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"), (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
(None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
] ]
huggingface_model_loader_configs = [ huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.

View File

@@ -1,4 +1,4 @@
import torch import torch, math
from diffsynth.lora import GeneralLoRALoader from diffsynth.lora import GeneralLoRALoader
from diffsynth.models.lora import FluxLoRAFromCivitai from diffsynth.models.lora import FluxLoRAFromCivitai
@@ -6,11 +6,69 @@ from diffsynth.models.lora import FluxLoRAFromCivitai
class FluxLoRALoader(GeneralLoRALoader): class FluxLoRALoader(GeneralLoRALoader):
def __init__(self, device="cpu", torch_dtype=torch.float32): def __init__(self, device="cpu", torch_dtype=torch.float32):
super().__init__(device=device, torch_dtype=torch_dtype) super().__init__(device=device, torch_dtype=torch_dtype)
self.loader = FluxLoRAFromCivitai()
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
lora_prefix, model_resource = self.loader.match(model, state_dict_lora) super().load(model, state_dict_lora, alpha)
self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
def convert_state_dict(self, state_dict):
    """Rename LoRA keys to the `lora_A` / `lora_B` naming and apply the alpha scale.

    Keys shaped like `lora_unet_double_blocks_<id>_<part>.lora_down.weight`
    (likewise `.lora_up.weight`, and the `lora_unet_single_blocks_<id>_...`
    variants) are mapped to `blocks.<id>.<module>.lora_A.default.weight` /
    `...lora_B.default.weight` (resp. `single_blocks.<id>...`). Keys that have
    no entry in the rename table are copied through under their original name.

    If the state dict contains an `.alpha` entry with a matching
    `.lora_down.weight` entry, every tensor is multiplied by
    `sqrt(alpha / rank)`, where `rank` is the first dimension of that
    `lora_down` weight. Scaling both the down and the up matrix by the square
    root scales their product by `alpha / rank`.

    The input dict and the tensors it holds are never modified; scaled tensors
    are new objects.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        A new dict with converted names and (possibly scaled) tensors.
    """
    # TODO: support other lora format
    # Source module name -> target module name. "blockid" is a placeholder
    # that stands for the numeric block index on both sides.
    module_map = {
        "lora_unet_double_blocks_blockid_img_mod_lin": "blocks.blockid.norm1_a.linear",
        "lora_unet_double_blocks_blockid_txt_mod_lin": "blocks.blockid.norm1_b.linear",
        "lora_unet_double_blocks_blockid_img_attn_qkv": "blocks.blockid.attn.a_to_qkv",
        "lora_unet_double_blocks_blockid_txt_attn_qkv": "blocks.blockid.attn.b_to_qkv",
        "lora_unet_double_blocks_blockid_img_attn_proj": "blocks.blockid.attn.a_to_out",
        "lora_unet_double_blocks_blockid_txt_attn_proj": "blocks.blockid.attn.b_to_out",
        "lora_unet_double_blocks_blockid_img_mlp_0": "blocks.blockid.ff_a.0",
        "lora_unet_double_blocks_blockid_img_mlp_2": "blocks.blockid.ff_a.2",
        "lora_unet_double_blocks_blockid_txt_mlp_0": "blocks.blockid.ff_b.0",
        "lora_unet_double_blocks_blockid_txt_mlp_2": "blocks.blockid.ff_b.2",
        "lora_unet_single_blocks_blockid_modulation_lin": "single_blocks.blockid.norm.linear",
        "lora_unet_single_blocks_blockid_linear1": "single_blocks.blockid.to_qkv_mlp",
        "lora_unet_single_blocks_blockid_linear2": "single_blocks.blockid.proj_out",
    }
    # Each module contributes a down/up pair: lora_down -> lora_A, lora_up -> lora_B.
    rename_dict = {}
    for source_module, target_module in module_map.items():
        rename_dict[f"{source_module}.lora_down.weight"] = f"{target_module}.lora_A.default.weight"
        rename_dict[f"{source_module}.lora_up.weight"] = f"{target_module}.lora_B.default.weight"

    def guess_block_id(name):
        # The first purely numeric "_"-separated token is taken as the block
        # index; it is swapped for the placeholder so the name can be looked
        # up in the rename table.
        for token in name.split("_"):
            if token.isdigit():
                return token, name.replace(f"_{token}_", "_blockid_")
        return None, None

    def guess_alpha(state_dict):
        # Use the first `.alpha` entry that has a matching `lora_down` weight;
        # the resulting factor is applied to the whole state dict.
        for name, param in state_dict.items():
            if ".alpha" in name:
                down_name = name.replace(".alpha", ".lora_down.weight")
                if down_name in state_dict:
                    rank = state_dict[down_name].shape[0]
                    return math.sqrt(param.item() / rank)
        return 1

    alpha = guess_alpha(state_dict)
    state_dict_ = {}
    for name, param in state_dict.items():
        block_id, source_name = guess_block_id(name)
        if alpha != 1:
            # Out-of-place multiply: an in-place `*=` would rescale the
            # caller's tensors and compound on every repeated conversion.
            # NOTE(review): this also scales pass-through entries such as the
            # `.alpha` scalars themselves, as before — confirm nothing
            # downstream reads them.
            param = param * alpha
        if source_name in rename_dict:
            target_name = rename_dict[source_name]
            target_name = target_name.replace(".blockid.", f".{block_id}.")
            state_dict_[target_name] = param
        else:
            state_dict_[name] = param
    return state_dict_
class LoraMerger(torch.nn.Module): class LoraMerger(torch.nn.Module):
def __init__(self, dim): def __init__(self, dim):
@@ -35,7 +93,8 @@ class LoraMerger(torch.nn.Module):
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
return output return output
class LoraPatcher(torch.nn.Module):
class FluxLoraPatcher(torch.nn.Module):
def __init__(self, lora_patterns=None): def __init__(self, lora_patterns=None):
super().__init__() super().__init__()
if lora_patterns is None: if lora_patterns is None:
@@ -69,3 +128,15 @@ class LoraPatcher(torch.nn.Module):
def forward(self, base_output, lora_outputs, name):
    """Run the merger registered for `name` on the base and LoRA outputs.

    The merger is looked up in `self.model_dict` under `name` with every "."
    replaced by "___".
    """
    merger_key = name.replace(".", "___")
    merger = self.model_dict[merger_key]
    return merger(base_output, lora_outputs)
@staticmethod
def state_dict_converter():
    """Return a new `FluxLoraPatcherStateDictConverter` instance."""
    converter = FluxLoraPatcherStateDictConverter()
    return converter
class FluxLoraPatcherStateDictConverter:
    """State-dict converter whose `from_civitai` passes the weights through unchanged."""

    def __init__(self):
        """Nothing to initialise; the converter is stateless."""

    def from_civitai(self, state_dict):
        """Return `state_dict` as-is (the same object, no renaming)."""
        converted = state_dict
        return converted

View File

@@ -21,8 +21,7 @@ from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_infiniteyou import InfiniteYouImageProjector from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.tiler import FastTileWorker from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader,LoraPatcher from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
from ..models.lora import FluxLoRAConverter
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
from ..models.flux_dit import RMSNorm from ..models.flux_dit import RMSNorm
@@ -92,12 +91,13 @@ class FluxImagePipeline(BasePipeline):
self.controlnet: MultiControlNet = None self.controlnet: MultiControlNet = None
self.ipadapter: FluxIpAdapter = None self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder = None self.ipadapter_image_encoder = None
self.unit_runner = PipelineUnitRunner()
self.qwenvl = None self.qwenvl = None
self.step1x_connector: Qwen2Connector = None self.step1x_connector: Qwen2Connector = None
self.infinityou_processor: InfinitYou = None self.infinityou_processor: InfinitYou = None
self.image_proj_model: InfiniteYouImageProjector = None self.image_proj_model: InfiniteYouImageProjector = None
self.in_iteration_models = ("dit", "step1x_connector", "controlnet") self.lora_patcher: FluxLoraPatcher = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
self.units = [ self.units = [
FluxImageUnit_ShapeChecker(), FluxImageUnit_ShapeChecker(),
FluxImageUnit_NoiseInitializer(), FluxImageUnit_NoiseInitializer(),
@@ -117,49 +117,55 @@ class FluxImagePipeline(BasePipeline):
self.model_fn = model_fn_flux_image self.model_fn = model_fn_flux_image
def load_lora(
    self,
    module: torch.nn.Module,
    lora_config: Union[ModelConfig, str],
    alpha=1,
    hotload=False,
    local_model_path="./models",
    skip_download=False
):
    """Load a LoRA into `module`, either merged via the loader or hot-loaded.

    Args:
        module: Model that receives the LoRA.
        lora_config: A `ModelConfig`, or a string that is wrapped as
            `ModelConfig(path=...)`. A `ModelConfig` is asked to download
            itself first via `download_if_necessary`.
        alpha: LoRA strength. In hotload mode it multiplies the `lora_A`
            weights; otherwise it is forwarded to `FluxLoRALoader.load`.
        hotload: When true, the LoRA weights are appended to the
            `lora_A_weights` / `lora_B_weights` lists of every
            `AutoWrappedLinear` submodule that has both keys in the converted
            state dict. When false, loading is delegated to
            `FluxLoRALoader.load`.
        local_model_path: Forwarded to `download_if_necessary`.
        skip_download: Forwarded to `download_if_necessary`.
    """
    if not isinstance(lora_config, str):
        lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
    else:
        lora_config = ModelConfig(path=lora_config)

    loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
    lora = loader.convert_state_dict(
        load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
    )

    if not hotload:
        loader.load(module, lora, alpha=alpha)
        return

    # Hotload: keep the LoRA factors next to each wrapped linear layer.
    for module_name, submodule in module.named_modules():
        if not isinstance(submodule, AutoWrappedLinear):
            continue
        key_a = f'{module_name}.lora_A.default.weight'
        key_b = f'{module_name}.lora_B.default.weight'
        if key_a in lora and key_b in lora:
            submodule.lora_A_weights.append(lora[key_a] * alpha)
            submodule.lora_B_weights.append(lora[key_b])
def enable_lora_patcher(self):
    """Attach the loaded LoRA patcher's mergers to the DiT's wrapped linear layers.

    Prints a hint and returns without doing anything when VRAM management has
    not been enabled or when `self.lora_patcher` is None. Otherwise, every
    `AutoWrappedLinear` in `self.dit` whose name (with "." replaced by "___")
    is a key of `self.lora_patcher.model_dict` gets that entry assigned as its
    `lora_merger`.
    """
    if not getattr(self, "vram_management_enabled", False):
        print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
        return
    if self.lora_patcher is None:
        print("Please load lora patcher models before `enable_lora_patcher()`.")
        return
    for layer_name, layer in self.dit.named_modules():
        if not isinstance(layer, AutoWrappedLinear):
            continue
        merger_name = layer_name.replace(".", "___")
        if merger_name in self.lora_patcher.model_dict:
            layer.lora_merger = self.lora_patcher.model_dict[merger_name]
def clear_lora(self):
    """Empty the hot-loaded LoRA weight lists of every `AutoWrappedLinear` submodule.

    The `lora_A_weights` / `lora_B_weights` lists are cleared in place, and
    only when the attribute exists on the layer.
    """
    for _, layer in self.named_modules():
        if not isinstance(layer, AutoWrappedLinear):
            continue
        for attr in ("lora_A_weights", "lora_B_weights"):
            if hasattr(layer, attr):
                getattr(layer, attr).clear()
def training_loss(self, **inputs): def training_loss(self, **inputs):
@@ -325,10 +331,10 @@ class FluxImagePipeline(BasePipeline):
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
pipe.qwenvl = model_manager.fetch_model("qwenvl") pipe.qwenvl = model_manager.fetch_model("qwenvl")
pipe.step1x_connector = model_manager.fetch_model("step1x_connector") pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector") pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
if pipe.image_proj_model is not None: if pipe.image_proj_model is not None:
pipe.infinityou_processor = InfinitYou(device=device) pipe.infinityou_processor = InfinitYou(device=device)
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
# ControlNet # ControlNet
controlnets = [] controlnets = []

View File

@@ -124,13 +124,19 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
weight = cast_to(self.weight, self.computation_dtype, self.computation_device) weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
out = torch.nn.functional.linear(x, weight, bias) out = torch.nn.functional.linear(x, weight, bias)
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): if len(self.lora_A_weights) == 0:
out_lora = x @ lora_A.T @ lora_B.T # No LoRA
if self.lora_merger is None: return out
out = out + out_lora elif self.lora_merger is None:
lora_output.append(out_lora) # Native LoRA inference
if self.lora_merger is not None and len(lora_output) > 0: for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T @ lora_B.T
else:
# LoRA fusion
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
lora_output.append(x @ lora_A.T @ lora_B.T)
lora_output = torch.stack(lora_output) lora_output = torch.stack(lora_output)
out = self.lora_merger(out, lora_output) out = self.lora_merger(out, lora_output)
return out return out

View File

@@ -0,0 +1,35 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
# Example: generate an image with several LoRAs hot-loaded and fused at inference time.

# The last entry presumably provides the LoRA patcher weights that
# `enable_lora_patcher()` requires — confirm against the model loader configs.
model_sources = [
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors"),
    ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors"),
    ("black-forest-labs/FLUX.1-dev", "text_encoder_2/"),
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors"),
    ("DiffSynth-Studio/FLUX.1-dev-LoRAFusion", "model.safetensors"),
]
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id=model_id, origin_file_pattern=pattern)
        for model_id, pattern in model_sources
    ],
)

# `enable_lora_patcher()` checks that VRAM management is enabled, so the order matters.
pipe.enable_vram_management()
pipe.enable_lora_patcher()

# Hot-load each LoRA in turn.
lora_sources = [
    ("yangyufeng/fgao", "30.safetensors"),
    ("bobooblue/LoRA-bling-mai", "10.safetensors"),
    ("JIETANGAB/E", "17.safetensors"),
]
for model_id, pattern in lora_sources:
    pipe.load_lora(
        pipe.dit,
        ModelConfig(model_id=model_id, origin_file_pattern=pattern),
        hotload=True
    )

image = pipe(prompt="a beautiful Asian girl", seed=0)
image.save("flux.jpg")