mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
support lora fusion
This commit is contained in:
@@ -21,8 +21,7 @@ from ..models.flux_ipadapter import FluxIpAdapter
|
||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||
from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader,LoraPatcher
|
||||
from ..models.lora import FluxLoRAConverter
|
||||
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
|
||||
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
from ..models.flux_dit import RMSNorm
|
||||
@@ -92,12 +91,13 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.controlnet: MultiControlNet = None
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.qwenvl = None
|
||||
self.step1x_connector: Qwen2Connector = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.image_proj_model: InfiniteYouImageProjector = None
|
||||
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
|
||||
self.lora_patcher: FluxLoraPatcher = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
|
||||
self.units = [
|
||||
FluxImageUnit_ShapeChecker(),
|
||||
FluxImageUnit_NoiseInitializer(),
|
||||
@@ -117,49 +117,55 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.model_fn = model_fn_flux_image
|
||||
|
||||
|
||||
def load_lora(self, module, path, alpha=1):
|
||||
def load_lora(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
lora_config: Union[ModelConfig, str],
|
||||
alpha=1,
|
||||
hotload=False,
|
||||
local_model_path="./models",
|
||||
skip_download=False
|
||||
):
|
||||
if isinstance(lora_config, str):
|
||||
lora_config = ModelConfig(path=lora_config)
|
||||
else:
|
||||
lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
|
||||
loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
loader.load(module, lora, alpha=alpha)
|
||||
|
||||
def enable_lora_hotload(self, lora_paths):
|
||||
# load lora state dict and align format
|
||||
lora_state_dicts = [
|
||||
FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)) for path in lora_paths
|
||||
]
|
||||
lora_state_dicts = [l for l in lora_state_dicts if l != {}]
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = loader.convert_state_dict(lora)
|
||||
if hotload:
|
||||
for name, module in module.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
lora_a_name = f'{name}.lora_A.default.weight'
|
||||
lora_b_name = f'{name}.lora_B.default.weight'
|
||||
if lora_a_name in lora and lora_b_name in lora:
|
||||
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||
module.lora_B_weights.append(lora[lora_b_name])
|
||||
else:
|
||||
loader.load(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def enable_lora_patcher(self):
|
||||
if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled):
|
||||
print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
|
||||
return
|
||||
if self.lora_patcher is None:
|
||||
print("Please load lora patcher models before `enable_lora_patcher()`.")
|
||||
return
|
||||
for name, module in self.dit.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
lora_a_name = f'{name}.lora_A.default.weight'
|
||||
lora_b_name = f'{name}.lora_B.default.weight'
|
||||
lora_A_weights = []
|
||||
lora_B_weights = []
|
||||
for lora_dict in lora_state_dicts:
|
||||
if lora_a_name in lora_dict and lora_b_name in lora_dict:
|
||||
lora_A_weights.append(lora_dict[lora_a_name])
|
||||
lora_B_weights.append(lora_dict[lora_b_name])
|
||||
module.lora_A_weights = lora_A_weights
|
||||
module.lora_B_weights = lora_B_weights
|
||||
|
||||
|
||||
def enable_lora_patcher(self, lora_patcher_path):
|
||||
# load lora patcher
|
||||
lora_patcher = LoraPatcher().to(dtype=self.torch_dtype, device=self.device)
|
||||
lora_patcher.load_state_dict(load_state_dict(lora_patcher_path))
|
||||
|
||||
for name, module in self.dit.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
merger_name = name.replace(".", "___")
|
||||
if merger_name in lora_patcher.model_dict:
|
||||
module.lora_merger = lora_patcher.model_dict[merger_name]
|
||||
if merger_name in self.lora_patcher.model_dict:
|
||||
module.lora_merger = self.lora_patcher.model_dict[merger_name]
|
||||
|
||||
|
||||
def off_lora_hotload(self):
|
||||
for name, module in self.dit.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
module.lora_A_weights = []
|
||||
module.lora_B_weights = []
|
||||
def clear_lora(self):
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
if hasattr(module, "lora_A_weights"):
|
||||
module.lora_A_weights.clear()
|
||||
if hasattr(module, "lora_B_weights"):
|
||||
module.lora_B_weights.clear()
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
@@ -325,10 +331,10 @@ class FluxImagePipeline(BasePipeline):
|
||||
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
pipe.qwenvl = model_manager.fetch_model("qwenvl")
|
||||
pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
|
||||
|
||||
pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
||||
if pipe.image_proj_model is not None:
|
||||
pipe.infinityou_processor = InfinitYou(device=device)
|
||||
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
|
||||
|
||||
# ControlNet
|
||||
controlnets = []
|
||||
|
||||
Reference in New Issue
Block a user