lora hotload and merge

lzw478614@alibaba-inc.com
2025-07-02 13:32:24 +08:00
parent d9c812818d
commit 9cb887015b
3 changed files with 114 additions and 3 deletions


@@ -21,7 +21,8 @@ from ..models.flux_ipadapter import FluxIpAdapter
 from ..models.flux_infiniteyou import InfiniteYouImageProjector
 from ..models.tiler import FastTileWorker
 from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
 from ..lora.flux_lora import FluxLoRALoader
+from ..lora.flux_lora import FluxLoRALoader, LoraPatcher
 from ..models.lora import FluxLoRAConverter
 from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
 from ..models.flux_dit import RMSNorm
@@ -121,6 +122,45 @@ class FluxImagePipeline(BasePipeline):
             lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
             loader.load(module, lora, alpha=alpha)
 
+    def enable_lora_hotload(self, lora_paths):
+        # Load each LoRA state dict and align it to the DiffSynth key format.
+        lora_state_dicts = [
+            FluxLoRAConverter.align_to_diffsynth_format(
+                load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
+            )
+            for path in lora_paths
+        ]
+        # Drop LoRAs whose keys could not be aligned.
+        lora_state_dicts = [l for l in lora_state_dicts if l != {}]
+        # Attach the matching low-rank weight pairs to every Linear layer.
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                lora_a_name = f'{name}.lora_A.default.weight'
+                lora_b_name = f'{name}.lora_B.default.weight'
+                lora_A_weights = []
+                lora_B_weights = []
+                for lora_dict in lora_state_dicts:
+                    if lora_a_name in lora_dict and lora_b_name in lora_dict:
+                        lora_A_weights.append(lora_dict[lora_a_name])
+                        lora_B_weights.append(lora_dict[lora_b_name])
+                module.lora_A_weights = lora_A_weights
+                module.lora_B_weights = lora_B_weights
+
+    def enable_lora_patcher(self, lora_patcher_path):
+        # Load the LoRA patcher and bind each per-layer merger to its Linear layer.
+        lora_patcher = LoraPatcher().to(dtype=self.torch_dtype, device=self.device)
+        lora_patcher.load_state_dict(load_state_dict(lora_patcher_path))
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                merger_name = name.replace(".", "___")  # module-dict keys cannot contain dots
+                if merger_name in lora_patcher.model_dict:
+                    module.lora_merger = lora_patcher.model_dict[merger_name]
+
+    def off_lora_hotload(self):
+        # Detach all hotloaded LoRA weights from the Linear layers.
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module.lora_A_weights = []
+                module.lora_B_weights = []
+
     def training_loss(self, **inputs):
         timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
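The code that consumes lora_A_weights, lora_B_weights, and lora_merger lives in the other changed files and is not shown in this hunk. A minimal sketch of how a patched torch.nn.Linear forward could apply the hotloaded pairs, assuming each pair is applied as a standard LoRA delta; the function name and the plain summation are illustrative, and only the attribute names come from this commit.

import torch
import torch.nn.functional as F

def linear_forward_with_hotload(module: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # Base projection; each hotloaded LoRA pair adds a low-rank delta on top.
    y = F.linear(x, module.weight, module.bias)
    for lora_A, lora_B in zip(module.lora_A_weights, module.lora_B_weights):
        # Standard LoRA update: x @ A^T @ B^T through a rank-r bottleneck.
        y = y + F.linear(F.linear(x, lora_A), lora_B)
    # If enable_lora_patcher bound a learned module.lora_merger, it would
    # presumably replace this plain sum when combining several LoRAs; its
    # interface is defined in the other changed files.
    return y

Because the weights are stored on the modules rather than merged into module.weight, swapping LoRAs is a cheap attribute update and off_lora_hotload never has to restore the original weights.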
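A usage sketch, assuming an already constructed FluxImagePipeline named pipe; the file paths and the pipeline call signature are illustrative:

pipe.enable_lora_hotload([
    "lora/style_a.safetensors",   # illustrative paths
    "lora/style_b.safetensors",
])
pipe.enable_lora_patcher("lora/patcher.safetensors")  # optional learned merger
image = pipe(prompt="a photo of a cat")               # assumed call signature
pipe.off_lora_hotload()  # detach the LoRAs without reloading base weights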