Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-18 22:08:13 +00:00
Commit: lora hotload and merge
@@ -21,7 +21,8 @@ from ..models.flux_ipadapter import FluxIpAdapter
 from ..models.flux_infiniteyou import InfiniteYouImageProjector
 from ..models.tiler import FastTileWorker
 from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
-from ..lora.flux_lora import FluxLoRALoader
+from ..lora.flux_lora import FluxLoRALoader,LoraPatcher
+from ..models.lora import FluxLoRAConverter

 from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
 from ..models.flux_dit import RMSNorm
@@ -121,6 +122,45 @@ class FluxImagePipeline(BasePipeline):
         lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
         loader.load(module, lora, alpha=alpha)

+    def enable_lora_hotload(self, lora_paths):
+        # load lora state dict and align format
+        lora_state_dicts = [
+            FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)) for path in lora_paths
+        ]
+        lora_state_dicts = [l for l in lora_state_dicts if l != {}]
+
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                lora_a_name = f'{name}.lora_A.default.weight'
+                lora_b_name = f'{name}.lora_B.default.weight'
+                lora_A_weights = []
+                lora_B_weights = []
+                for lora_dict in lora_state_dicts:
+                    if lora_a_name in lora_dict and lora_b_name in lora_dict:
+                        lora_A_weights.append(lora_dict[lora_a_name])
+                        lora_B_weights.append(lora_dict[lora_b_name])
+                module.lora_A_weights = lora_A_weights
+                module.lora_B_weights = lora_B_weights
+
+
+    def enable_lora_patcher(self, lora_patcher_path):
+        # load lora patcher
+        lora_patcher = LoraPatcher().to(dtype=self.torch_dtype, device=self.device)
+        lora_patcher.load_state_dict(load_state_dict(lora_patcher_path))
+
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                merger_name = name.replace(".", "___")
+                if merger_name in lora_patcher.model_dict:
+                    module.lora_merger = lora_patcher.model_dict[merger_name]
+
+
+    def off_lora_hotload(self):
+        for name, module in self.dit.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module.lora_A_weights = []
+                module.lora_B_weights = []
+

     def training_loss(self, **inputs):
         timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
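Usage sketch (not part of the commit): assuming a FluxImagePipeline has already been constructed as `pipe` with its DiT loaded, and assuming the usual `pipe(prompt=...)` text-to-image call, the new methods could be exercised roughly as below. The LoRA file names are placeholders; the files need to be in a format that FluxLoRAConverter.align_to_diffsynth_format accepts.

    # hypothetical LoRA checkpoints
    lora_paths = ["lora_style_a.safetensors", "lora_style_b.safetensors"]

    # attach LoRA A/B matrices to every torch.nn.Linear in the DiT without merging into the base weights
    pipe.enable_lora_hotload(lora_paths)
    image_with_lora = pipe(prompt="a cat wearing sunglasses", seed=0)

    # optionally attach a trained LoraPatcher, which presumably controls how the stacked LoRAs are merged per layer
    pipe.enable_lora_patcher("lora_patcher.safetensors")
    image_patched = pipe(prompt="a cat wearing sunglasses", seed=0)

    # detach the hot-loaded weights again; the base weights were never modified
    pipe.off_lora_hotload()
    image_base = pipe(prompt="a cat wearing sunglasses", seed=0)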
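Note that this hunk only attaches `lora_A_weights`, `lora_B_weights`, and `lora_merger` to the Linear modules; the patched forward that consumes them is not shown here. For orientation only, a standard hot-load forward over stacked LoRA pairs would look roughly like the sketch below (ignoring any per-LoRA scaling, and ignoring the optional `lora_merger`, which would presumably weight or combine the per-LoRA deltas instead of taking a plain sum).

    import torch
    import torch.nn.functional as F

    def hotloaded_linear_forward(module: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
        # base projection with the original, unmodified weight
        y = F.linear(x, module.weight, module.bias)
        # add each hot-loaded low-rank update: (x @ A_i^T) @ B_i^T
        for lora_A, lora_B in zip(module.lora_A_weights, module.lora_B_weights):
            y = y + F.linear(F.linear(x, lora_A), lora_B)
        return y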