Support LoRA fusion

Artiprocher
2025-07-03 18:49:46 +08:00
parent 9cb887015b
commit 8a9dbbd3ba
5 changed files with 175 additions and 54 deletions

@@ -21,8 +21,7 @@ from ..models.flux_ipadapter import FluxIpAdapter
 from ..models.flux_infiniteyou import InfiniteYouImageProjector
 from ..models.tiler import FastTileWorker
 from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
-from ..lora.flux_lora import FluxLoRALoader,LoraPatcher
-from ..models.lora import FluxLoRAConverter
+from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
 from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
 from ..models.flux_dit import RMSNorm
@@ -92,12 +91,13 @@ class FluxImagePipeline(BasePipeline):
         self.controlnet: MultiControlNet = None
         self.ipadapter: FluxIpAdapter = None
         self.ipadapter_image_encoder = None
-        self.unit_runner = PipelineUnitRunner()
         self.qwenvl = None
         self.step1x_connector: Qwen2Connector = None
         self.infinityou_processor: InfinitYou = None
         self.image_proj_model: InfiniteYouImageProjector = None
-        self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
+        self.lora_patcher: FluxLoraPatcher = None
+        self.unit_runner = PipelineUnitRunner()
+        self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
         self.units = [
             FluxImageUnit_ShapeChecker(),
             FluxImageUnit_NoiseInitializer(),
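
Note that "lora_patcher" now sits alongside the DiT and ControlNet in `in_iteration_models`, meaning the fusion modules take part in the pipeline's per-step device management rather than being merged into the DiT. A minimal sketch of what membership in that tuple implies, assuming a `load_models_to_device` helper in the style of `BasePipeline` (the loop body is a placeholder):

    # Sketch only: models named in in_iteration_models stay on the compute
    # device for the whole denoising loop and are offloaded afterwards.
    def run_denoising(pipe, inputs):
        pipe.load_models_to_device(pipe.in_iteration_models)  # now includes "lora_patcher"
        for timestep in pipe.scheduler.timesteps:
            inputs = denoise_step(pipe, inputs, timestep)     # hypothetical per-step update
        pipe.load_models_to_device([])                        # offload everything again
        return inputs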
@@ -117,49 +117,55 @@ class FluxImagePipeline(BasePipeline):
         self.model_fn = model_fn_flux_image
 
-    def load_lora(self, module, path, alpha=1):
+    def load_lora(
+        self,
+        module: torch.nn.Module,
+        lora_config: Union[ModelConfig, str],
+        alpha=1,
+        hotload=False,
+        local_model_path="./models",
+        skip_download=False
+    ):
+        if isinstance(lora_config, str):
+            lora_config = ModelConfig(path=lora_config)
+        else:
+            lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
         loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
-        loader.load(module, lora, alpha=alpha)
-
-    def enable_lora_hotload(self, lora_paths):
-        # load lora state dict and align format
-        lora_state_dicts = [
-            FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)) for path in lora_paths
-        ]
-        lora_state_dicts = [l for l in lora_state_dicts if l != {}]
+        lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        lora = loader.convert_state_dict(lora)
+        if hotload:
+            for name, module in module.named_modules():
+                if isinstance(module, AutoWrappedLinear):
+                    lora_a_name = f'{name}.lora_A.default.weight'
+                    lora_b_name = f'{name}.lora_B.default.weight'
+                    if lora_a_name in lora and lora_b_name in lora:
+                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
+                        module.lora_B_weights.append(lora[lora_b_name])
+        else:
+            loader.load(module, lora, alpha=alpha)
+
+    def enable_lora_patcher(self):
+        if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled):
+            print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
+            return
+        if self.lora_patcher is None:
+            print("Please load lora patcher models before `enable_lora_patcher()`.")
+            return
         for name, module in self.dit.named_modules():
-            if isinstance(module, torch.nn.Linear):
-                lora_a_name = f'{name}.lora_A.default.weight'
-                lora_b_name = f'{name}.lora_B.default.weight'
-                lora_A_weights = []
-                lora_B_weights = []
-                for lora_dict in lora_state_dicts:
-                    if lora_a_name in lora_dict and lora_b_name in lora_dict:
-                        lora_A_weights.append(lora_dict[lora_a_name])
-                        lora_B_weights.append(lora_dict[lora_b_name])
-                module.lora_A_weights = lora_A_weights
-                module.lora_B_weights = lora_B_weights
-
-    def enable_lora_patcher(self, lora_patcher_path):
-        # load lora patcher
-        lora_patcher = LoraPatcher().to(dtype=self.torch_dtype, device=self.device)
-        lora_patcher.load_state_dict(load_state_dict(lora_patcher_path))
-        for name, module in self.dit.named_modules():
-            if isinstance(module, torch.nn.Linear):
+            if isinstance(module, AutoWrappedLinear):
                 merger_name = name.replace(".", "___")
-                if merger_name in lora_patcher.model_dict:
-                    module.lora_merger = lora_patcher.model_dict[merger_name]
-
-    def off_lora_hotload(self):
-        for name, module in self.dit.named_modules():
-            if isinstance(module, torch.nn.Linear):
-                module.lora_A_weights = []
-                module.lora_B_weights = []
+                if merger_name in self.lora_patcher.model_dict:
+                    module.lora_merger = self.lora_patcher.model_dict[merger_name]
+
+    def clear_lora(self):
+        for name, module in self.named_modules():
+            if isinstance(module, AutoWrappedLinear):
+                if hasattr(module, "lora_A_weights"):
+                    module.lora_A_weights.clear()
+                if hasattr(module, "lora_B_weights"):
+                    module.lora_B_weights.clear()
 
     def training_loss(self, **inputs):
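
Taken together, the reworked methods imply a call pattern like the one below. A usage sketch, not code from this commit: the checkpoint paths and prompt are placeholders, and `enable_vram_management()` is assumed to set the `vram_management_enabled` flag that `enable_lora_patcher()` checks:

    pipe.enable_vram_management()

    # Hot-load two LoRAs: their (A, B) factors are appended to each matching
    # AutoWrappedLinear rather than fused into the base weights.
    pipe.load_lora(pipe.dit, "style_a.safetensors", alpha=1.0, hotload=True)
    pipe.load_lora(pipe.dit, "style_b.safetensors", alpha=1.0, hotload=True)

    # Attach the trained fusion modules (fetched as "flux_lora_patcher")
    # to the linear layers whose names they were trained for.
    pipe.enable_lora_patcher()

    image = pipe(prompt="a photo of a cat")  # placeholder prompt

    # Detach every hot-loaded LoRA; the base weights were never modified.
    pipe.clear_lora()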
@@ -325,10 +331,10 @@ class FluxImagePipeline(BasePipeline):
         pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
         pipe.qwenvl = model_manager.fetch_model("qwenvl")
         pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
         pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
         if pipe.image_proj_model is not None:
             pipe.infinityou_processor = InfinitYou(device=device)
+        pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
         # ControlNet
         controlnets = []
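
For context on the "fusion" in the commit title: each hot-loaded LoRA contributes x @ A^T @ B^T on top of the base projection (alpha is already folded into A at load time via `lora[lora_a_name] * alpha`), and the `lora_merger` assigned above suggests a small trained module that combines several such branches instead of simply summing them. A sketch of the implied forward pass; the merger's input format is an assumption, not the repository's actual `AutoWrappedLinear` implementation:

    import torch
    import torch.nn.functional as F

    def lora_linear_forward(x, weight, bias,
                            lora_A_weights, lora_B_weights, lora_merger=None):
        y = F.linear(x, weight, bias)            # base projection: x @ W^T + b
        branches = [
            F.linear(F.linear(x, A), B)          # low-rank update: x @ A^T @ B^T
            for A, B in zip(lora_A_weights, lora_B_weights)
        ]
        if not branches:
            return y
        if lora_merger is None:
            return y + sum(branches)             # plain hot-load: additive LoRA
        # LoRA fusion: a trained merger weighs and combines the stacked
        # branches (stacking on a new leading dim is an assumed interface).
        return y + lora_merger(torch.stack(branches))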