Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 23:08:13 +00:00)

Commit: lora merger
@@ -62,25 +62,26 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
     return state_dict
 
 
-def load_state_dict(file_path, torch_dtype=None):
+def load_state_dict(file_path, torch_dtype=None, device="cpu"):
     if file_path.endswith(".safetensors"):
-        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
+        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
     else:
-        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
+        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
 
 
-def load_state_dict_from_safetensors(file_path, torch_dtype=None):
+def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
     state_dict = {}
     with safe_open(file_path, framework="pt", device="cpu") as f:
         for k in f.keys():
             state_dict[k] = f.get_tensor(k)
             if torch_dtype is not None:
                 state_dict[k] = state_dict[k].to(torch_dtype)
+            state_dict[k] = state_dict[k].to(device)
     return state_dict
 
 
-def load_state_dict_from_bin(file_path, torch_dtype=None):
-    state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
+def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
+    state_dict = torch.load(file_path, map_location=device, weights_only=True)
     if torch_dtype is not None:
         for i in state_dict:
             if isinstance(state_dict[i], torch.Tensor):
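The new `device` argument lets callers place weights on an accelerator as they are loaded instead of staging everything on CPU first. A minimal usage sketch (the import path and checkpoint path are assumptions, not part of this commit):

    import torch
    from diffsynth.models.utils import load_state_dict  # import path assumed

    # Load a safetensors checkpoint in bfloat16 directly onto the GPU.
    state_dict = load_state_dict(
        "models/lora/example_lora.safetensors",  # hypothetical path
        torch_dtype=torch.bfloat16,
        device="cuda",
    )

Note that the safetensors branch still opens the file with device="cpu" and moves each tensor afterwards; safe_open also accepts a device string directly, which would avoid the intermediate CPU copy.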
@@ -401,7 +401,8 @@ class FluxImagePipeline(BasePipeline):
         progress_bar_cmd=tqdm,
         progress_bar_st=None,
         lora_state_dicts=[],
-        lora_alpahs=[]
+        lora_alpahs=[],
+        lora_patcher=None,
     ):
         height, width = self.check_resize_height_width(height, width)
@@ -443,6 +444,7 @@ class FluxImagePipeline(BasePipeline):
                 hidden_states=latents, timestep=timestep,
                 lora_state_dicts=lora_state_dicts,
                 lora_alpahs = lora_alpahs,
+                lora_patcher=lora_patcher,
                 **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
             )
             noise_pred_posi = self.control_noise_via_local_prompts(
@@ -462,6 +464,7 @@ class FluxImagePipeline(BasePipeline):
                 hidden_states=latents, timestep=timestep,
                 lora_state_dicts=lora_state_dicts,
                 lora_alpahs = lora_alpahs,
+                lora_patcher=lora_patcher,
                 **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
             )
             noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
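Both the positive and negative classifier-free guidance passes now thread `lora_patcher` through to `lets_dance_flux`, alongside the existing `lora_state_dicts` and `lora_alpahs` (the misspelling is the parameter's actual name in this codebase). A hedged sketch of a call site; `pipe`, the file names, and the patcher object are hypothetical:

    lora_1 = load_state_dict("style_lora.safetensors", device="cuda")
    lora_2 = load_state_dict("detail_lora.safetensors", device="cuda")
    image = pipe(
        prompt="a cat sitting on a bench",
        lora_state_dicts=[lora_1, lora_2],
        lora_alpahs=[0.8, 0.6],  # kept for API compatibility; see AutoLoRALinear.forward below
        lora_patcher=patcher,    # callable: (out, stacked_lora_outs, layer_name) -> merged out
    )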
@@ -544,6 +547,7 @@ def lets_dance_flux(
     entity_masks=None,
     ipadapter_kwargs_list={},
     tea_cache: TeaCache = None,
+    use_gradient_checkpointing=False,
     **kwargs
 ):
 
@@ -610,6 +614,11 @@ def lets_dance_flux(
     prompt_emb = dit.context_embedder(prompt_emb)
     image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
     attention_mask = None
 
+    def create_custom_forward(module):
+        def custom_forward(*inputs, **kwargs):
+            return module(*inputs, **kwargs)
+        return custom_forward
+
     # TeaCache
     if tea_cache is not None:
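The wrapper exists because `torch.utils.checkpoint.checkpoint` expects a plain callable that it can re-invoke during the backward pass to recompute activations; closing over the block preserves its signature while satisfying that interface. A self-contained illustration of the pattern (generic, not specific to this repository):

    import torch
    import torch.utils.checkpoint

    def create_custom_forward(module):
        def custom_forward(*inputs, **kwargs):
            return module(*inputs, **kwargs)
        return custom_forward

    layer = torch.nn.Linear(16, 16)
    x = torch.randn(4, 16, requires_grad=True)
    # Activations inside `layer` are dropped after the forward pass and
    # recomputed during backward, trading extra compute for lower memory.
    y = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, use_reentrant=False)
    y.sum().backward()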
@@ -622,15 +631,22 @@ def lets_dance_flux(
     else:
         # Joint Blocks
         for block_id, block in enumerate(dit.blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
-                **kwargs
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None), **kwargs,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
+                    **kwargs
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -639,15 +655,22 @@ def lets_dance_flux(
     hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
     num_joint_blocks = len(dit.blocks)
     for block_id, block in enumerate(dit.single_blocks):
-        hidden_states, prompt_emb = block(
-            hidden_states,
-            prompt_emb,
-            conditioning,
-            image_rotary_emb,
-            attention_mask,
-            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
-            **kwargs
-        )
+        if use_gradient_checkpointing:
+            hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(block),
+                hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), **kwargs,
+                use_reentrant=False,
+            )
+        else:
+            hidden_states, prompt_emb = block(
+                hidden_states,
+                prompt_emb,
+                conditioning,
+                image_rotary_emb,
+                attention_mask,
+                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
+                **kwargs
+            )
         # ControlNet
         if controlnet is not None and controlnet_frames is not None:
             hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
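Note the indexing in the single-block loop: `ipadapter_kwargs_list` is one flat dict whose keys cover the joint blocks first (`0 .. len(dit.blocks) - 1`) and then continue at `block_id + num_joint_blocks` for the single blocks, so every block in the model gets a unique entry.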
@@ -71,15 +71,16 @@ class AutoWrappedLinear(torch.nn.Linear):
         return torch.nn.functional.linear(x, weight, bias)
 
 class AutoLoRALinear(torch.nn.Linear):
-    def __init__(self, name='', in_features=1, out_features=2, bias = True, device=None, dtype=None):
+    def __init__(self, name='', in_features=1, out_features=2, bias=True, device=None, dtype=None):
         super().__init__(in_features, out_features, bias, device, dtype)
         self.name = name
 
-    def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], **kwargs):
+    def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], lora_patcher=None, **kwargs):
         out = torch.nn.functional.linear(x, self.weight, self.bias)
-        lora_a_name = f'{self.name}.lora_A.weight'
-        lora_b_name = f'{self.name}.lora_B.weight'
+        lora_a_name = f'{self.name}.lora_A.default.weight'
+        lora_b_name = f'{self.name}.lora_B.default.weight'
 
+        lora_output = []
         for i, lora_state_dict in enumerate(lora_state_dicts):
             if lora_state_dict is None:
                 break
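For a base weight W with LoRA factors A (rank × in_features) and B (out_features × rank), the branch output `x @ lora_A.T @ lora_B.T` is exactly the low-rank update B·A applied to x, without ever materializing a merged weight. A quick numerical check of that identity:

    import torch

    out_f, in_f, rank = 8, 16, 4
    x = torch.randn(2, in_f)
    W = torch.randn(out_f, in_f)
    A = torch.randn(rank, in_f)   # lora_A: down-projection
    B = torch.randn(out_f, rank)  # lora_B: up-projection

    low_rank = x @ A.T @ B.T              # as computed in AutoLoRALinear.forward
    merged = x @ (W + B @ A).T - x @ W.T  # the same update via a merged weight
    assert torch.allclose(low_rank, merged, atol=1e-4, rtol=1e-4)

The key rename to `lora_A.default.weight` / `lora_B.default.weight` appears to target state dicts taken from a PEFT-wrapped model, where the default adapter name is embedded in the keys.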
@@ -87,7 +88,10 @@ class AutoLoRALinear(torch.nn.Linear):
             lora_A = lora_state_dict[lora_a_name].to(dtype=self.weight.dtype,device=self.weight.device)
             lora_B = lora_state_dict[lora_b_name].to(dtype=self.weight.dtype,device=self.weight.device)
             out_lora = x @ lora_A.T @ lora_B.T
-            out = out + out_lora * lora_alpahs[i]
+            lora_output.append(out_lora)
+        if len(lora_output) > 0:
+            lora_output = torch.stack(lora_output)
+            out = lora_patcher(out, lora_output, self.name)
         return out
 
 def enable_auto_lora(model:torch.nn.Module, module_map: dict, name_prefix=''):
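With this commit, `forward` no longer applies `lora_alpahs` itself: each branch output is collected, stacked along a new leading dimension, and handed to `lora_patcher`, which decides how to merge the branches into the base output (hence the commit title, "lora merger"). As written, a non-empty `lora_state_dicts` with `lora_patcher=None` would raise a TypeError, so callers must supply a patcher whenever LoRAs are passed. The patcher implementation is not part of this diff; a minimal hypothetical one that reproduces fixed-alpha blending could look like:

    import torch

    class FixedAlphaPatcher:
        # Hypothetical merger: weight each LoRA branch by a fixed alpha.
        def __init__(self, alphas):
            self.alphas = alphas

        def __call__(self, out, lora_output, name):
            # out:         base linear output, shape [..., out_features]
            # lora_output: stacked branch outputs, shape [num_loras, ..., out_features]
            # name:        layer name, available for per-layer weighting rules
            alphas = torch.tensor(self.alphas, dtype=lora_output.dtype, device=lora_output.device)
            weights = alphas.view(-1, *([1] * (lora_output.dim() - 1)))
            return out + (lora_output * weights).sum(dim=0)

A trained merger, e.g. a small module that predicts per-layer weights from `name`, plugs in through the same three-argument interface.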