Merge pull request #667 from modelscope/lora_merge

Lora merge
This commit is contained in:
Zhongjie Duan
2025-07-07 13:30:34 +08:00
committed by GitHub
5 changed files with 244 additions and 12 deletions

View File

@@ -21,7 +21,7 @@ from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
from ..models.flux_dit import RMSNorm
@@ -91,12 +91,13 @@ class FluxImagePipeline(BasePipeline):
self.controlnet: MultiControlNet = None
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder = None
self.unit_runner = PipelineUnitRunner()
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.infinityou_processor: InfinitYou = None
self.image_proj_model: InfiniteYouImageProjector = None
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
self.lora_patcher: FluxLoraPatcher = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
self.units = [
FluxImageUnit_ShapeChecker(),
FluxImageUnit_NoiseInitializer(),
@@ -116,10 +117,55 @@ class FluxImagePipeline(BasePipeline):
self.model_fn = model_fn_flux_image
def load_lora(self, module, path, alpha=1):
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str],
alpha=1,
hotload=False,
local_model_path="./models",
skip_download=False
):
if isinstance(lora_config, str):
lora_config = ModelConfig(path=lora_config)
else:
lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
lora = loader.convert_state_dict(lora)
if hotload:
for name, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
lora_a_name = f'{name}.lora_A.default.weight'
lora_b_name = f'{name}.lora_B.default.weight'
if lora_a_name in lora and lora_b_name in lora:
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
else:
loader.load(module, lora, alpha=alpha)
def enable_lora_patcher(self):
if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled):
print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
return
if self.lora_patcher is None:
print("Please load lora patcher models before `enable_lora_patcher()`.")
return
for name, module in self.dit.named_modules():
if isinstance(module, AutoWrappedLinear):
merger_name = name.replace(".", "___")
if merger_name in self.lora_patcher.model_dict:
module.lora_merger = self.lora_patcher.model_dict[merger_name]
def clear_lora(self):
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
def training_loss(self, **inputs):
@@ -285,10 +331,10 @@ class FluxImagePipeline(BasePipeline):
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
pipe.qwenvl = model_manager.fetch_model("qwenvl")
pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
if pipe.image_proj_model is not None:
pipe.infinityou_processor = InfinitYou(device=device)
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
# ControlNet
controlnets = []