mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 10:48:11 +00:00
lora hotload and merge
This commit is contained in:
@@ -11,3 +11,61 @@ class FluxLoRALoader(GeneralLoRALoader):
|
|||||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||||
lora_prefix, model_resource = self.loader.match(model, state_dict_lora)
|
lora_prefix, model_resource = self.loader.match(model, state_dict_lora)
|
||||||
self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
|
self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
|
||||||
|
|
||||||
|
class LoraMerger(torch.nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
|
||||||
|
self.bias = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.activation = torch.nn.Sigmoid()
|
||||||
|
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||||
|
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||||
|
|
||||||
|
def forward(self, base_output, lora_outputs):
|
||||||
|
norm_base_output = self.norm_base(base_output)
|
||||||
|
norm_lora_outputs = self.norm_lora(lora_outputs)
|
||||||
|
gate = self.activation(
|
||||||
|
norm_base_output * self.weight_base \
|
||||||
|
+ norm_lora_outputs * self.weight_lora \
|
||||||
|
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
|
||||||
|
)
|
||||||
|
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
class LoraPatcher(torch.nn.Module):
|
||||||
|
def __init__(self, lora_patterns=None):
|
||||||
|
super().__init__()
|
||||||
|
if lora_patterns is None:
|
||||||
|
lora_patterns = self.default_lora_patterns()
|
||||||
|
model_dict = {}
|
||||||
|
for lora_pattern in lora_patterns:
|
||||||
|
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||||
|
model_dict[name.replace(".", "___")] = LoraMerger(dim)
|
||||||
|
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||||
|
|
||||||
|
def default_lora_patterns(self):
|
||||||
|
lora_patterns = []
|
||||||
|
lora_dict = {
|
||||||
|
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
|
||||||
|
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
|
||||||
|
}
|
||||||
|
for i in range(19):
|
||||||
|
for suffix in lora_dict:
|
||||||
|
lora_patterns.append({
|
||||||
|
"name": f"blocks.{i}.{suffix}",
|
||||||
|
"dim": lora_dict[suffix]
|
||||||
|
})
|
||||||
|
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
|
||||||
|
for i in range(38):
|
||||||
|
for suffix in lora_dict:
|
||||||
|
lora_patterns.append({
|
||||||
|
"name": f"single_blocks.{i}.{suffix}",
|
||||||
|
"dim": lora_dict[suffix]
|
||||||
|
})
|
||||||
|
return lora_patterns
|
||||||
|
|
||||||
|
def forward(self, base_output, lora_outputs, name):
|
||||||
|
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ from ..models.flux_ipadapter import FluxIpAdapter
|
|||||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||||
from ..models.tiler import FastTileWorker
|
from ..models.tiler import FastTileWorker
|
||||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||||
from ..lora.flux_lora import FluxLoRALoader
|
from ..lora.flux_lora import FluxLoRALoader,LoraPatcher
|
||||||
|
from ..models.lora import FluxLoRAConverter
|
||||||
|
|
||||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||||
from ..models.flux_dit import RMSNorm
|
from ..models.flux_dit import RMSNorm
|
||||||
@@ -121,6 +122,45 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
||||||
loader.load(module, lora, alpha=alpha)
|
loader.load(module, lora, alpha=alpha)
|
||||||
|
|
||||||
|
def enable_lora_hotload(self, lora_paths):
|
||||||
|
# load lora state dict and align format
|
||||||
|
lora_state_dicts = [
|
||||||
|
FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)) for path in lora_paths
|
||||||
|
]
|
||||||
|
lora_state_dicts = [l for l in lora_state_dicts if l != {}]
|
||||||
|
|
||||||
|
for name, module in self.dit.named_modules():
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
lora_a_name = f'{name}.lora_A.default.weight'
|
||||||
|
lora_b_name = f'{name}.lora_B.default.weight'
|
||||||
|
lora_A_weights = []
|
||||||
|
lora_B_weights = []
|
||||||
|
for lora_dict in lora_state_dicts:
|
||||||
|
if lora_a_name in lora_dict and lora_b_name in lora_dict:
|
||||||
|
lora_A_weights.append(lora_dict[lora_a_name])
|
||||||
|
lora_B_weights.append(lora_dict[lora_b_name])
|
||||||
|
module.lora_A_weights = lora_A_weights
|
||||||
|
module.lora_B_weights = lora_B_weights
|
||||||
|
|
||||||
|
|
||||||
|
def enable_lora_patcher(self, lora_patcher_path):
|
||||||
|
# load lora patcher
|
||||||
|
lora_patcher = LoraPatcher().to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
lora_patcher.load_state_dict(load_state_dict(lora_patcher_path))
|
||||||
|
|
||||||
|
for name, module in self.dit.named_modules():
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
merger_name = name.replace(".", "___")
|
||||||
|
if merger_name in lora_patcher.model_dict:
|
||||||
|
module.lora_merger = lora_patcher.model_dict[merger_name]
|
||||||
|
|
||||||
|
|
||||||
|
def off_lora_hotload(self):
|
||||||
|
for name, module in self.dit.named_modules():
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
module.lora_A_weights = []
|
||||||
|
module.lora_B_weights = []
|
||||||
|
|
||||||
|
|
||||||
def training_loss(self, **inputs):
|
def training_loss(self, **inputs):
|
||||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||||
|
|||||||
@@ -107,6 +107,9 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
|||||||
self.vram_limit = vram_limit
|
self.vram_limit = vram_limit
|
||||||
self.state = 0
|
self.state = 0
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.lora_A_weights = []
|
||||||
|
self.lora_B_weights = []
|
||||||
|
self.lora_merger = None
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
def forward(self, x, *args, **kwargs):
|
||||||
if self.state == 2:
|
if self.state == 2:
|
||||||
@@ -120,7 +123,17 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
|||||||
else:
|
else:
|
||||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||||
return torch.nn.functional.linear(x, weight, bias)
|
out = torch.nn.functional.linear(x, weight, bias)
|
||||||
|
lora_output = []
|
||||||
|
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||||
|
out_lora = x @ lora_A.T @ lora_B.T
|
||||||
|
if self.lora_merger is None:
|
||||||
|
out = out + out_lora
|
||||||
|
lora_output.append(out_lora)
|
||||||
|
if self.lora_merger is not None and len(lora_output) > 0:
|
||||||
|
lora_output = torch.stack(lora_output)
|
||||||
|
out = self.lora_merger(out, lora_output)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
|
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
|
||||||
|
|||||||
Reference in New Issue
Block a user