diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 6d3100d..85005ca 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -41,6 +41,30 @@ class RoPEEmbedding(torch.nn.Module): emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) return emb.unsqueeze(1) +class AdaLayerNorm(torch.nn.Module): + def __init__(self, dim, single=False, dual=False): + super().__init__() + self.single = single + self.dual = dual + self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual]) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb, **kwargs): + emb = self.linear(torch.nn.functional.silu(emb),**kwargs) + if self.single: + scale, shift = emb.unsqueeze(1).chunk(2, dim=2) + x = self.norm(x) * (1 + scale) + shift + return x + elif self.dual: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2) + norm_x = self.norm(x) + x = norm_x * (1 + scale_msa) + shift_msa + norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2 + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2 + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp class FluxJointAttention(torch.nn.Module): @@ -70,17 +94,17 @@ class FluxJointAttention(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs): batch_size = hidden_states_a.shape[0] # Part A - qkv_a = self.a_to_qkv(hidden_states_a) + qkv_a = self.a_to_qkv(hidden_states_a,**kwargs) qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) q_a, k_a, v_a = qkv_a.chunk(3, dim=1) q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) # Part B - qkv_b = self.b_to_qkv(hidden_states_b) + qkv_b = self.b_to_qkv(hidden_states_b,**kwargs) qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) q_b, k_b, v_b = qkv_b.chunk(3, dim=1) q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b) @@ -97,13 +121,25 @@ class FluxJointAttention(torch.nn.Module): hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] if ipadapter_kwargs_list is not None: hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list) - hidden_states_a = self.a_to_out(hidden_states_a) + hidden_states_a = self.a_to_out(hidden_states_a,**kwargs) if self.only_out_a: return hidden_states_a else: - hidden_states_b = self.b_to_out(hidden_states_b) + hidden_states_b = self.b_to_out(hidden_states_b,**kwargs) return hidden_states_a, hidden_states_b +class AutoSequential(torch.nn.Sequential): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def forward(self, input, **kwargs): + for module in self: + + if isinstance(module, torch.nn.Linear): + # print("##"*10) + input = module(input, **kwargs) + else: + input = module(input) + return input class FluxJointTransformerBlock(torch.nn.Module): @@ -120,6 +156,11 @@ class FluxJointTransformerBlock(torch.nn.Module): torch.nn.GELU(approximate="tanh"), torch.nn.Linear(dim*4, dim) ) + # self.ff_a = AutoSequential( + # torch.nn.Linear(dim, dim*4), + # torch.nn.GELU(approximate="tanh"), + # torch.nn.Linear(dim*4, dim) + # ) self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff_b = torch.nn.Sequential( @@ -127,14 +168,18 @@ class FluxJointTransformerBlock(torch.nn.Module): torch.nn.GELU(approximate="tanh"), torch.nn.Linear(dim*4, dim) ) + # self.ff_b = AutoSequential( + # torch.nn.Linear(dim, dim*4), + # torch.nn.GELU(approximate="tanh"), + # torch.nn.Linear(dim*4, dim) + # ) - - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): - norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) - norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs): + norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb, **kwargs) + norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb, **kwargs) # Attention - attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list, **kwargs) # Part A hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a @@ -149,7 +194,6 @@ class FluxJointTransformerBlock(torch.nn.Module): return hidden_states_a, hidden_states_b - class FluxSingleAttention(torch.nn.Module): def __init__(self, dim_a, dim_b, num_heads, head_dim): super().__init__() @@ -170,10 +214,10 @@ class FluxSingleAttention(torch.nn.Module): return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - def forward(self, hidden_states, image_rotary_emb): + def forward(self, hidden_states, image_rotary_emb, **kwargs): batch_size = hidden_states.shape[0] - qkv_a = self.a_to_qkv(hidden_states) + qkv_a = self.a_to_qkv(hidden_states,**kwargs) qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) q_a, k_a, v = qkv_a.chunk(3, dim=1) q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) @@ -195,8 +239,8 @@ class AdaLayerNormSingle(torch.nn.Module): self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - def forward(self, x, emb): - emb = self.linear(self.silu(emb)) + def forward(self, x, emb, **kwargs): + emb = self.linear(self.silu(emb),**kwargs) shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa @@ -226,7 +270,7 @@ class FluxSingleTransformerBlock(torch.nn.Module): return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs): batch_size = hidden_states.shape[0] qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) @@ -243,17 +287,17 @@ class FluxSingleTransformerBlock(torch.nn.Module): return hidden_states - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs): residual = hidden_states_a - norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) - hidden_states_a = self.to_qkv_mlp(norm_hidden_states) + norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb, **kwargs) + hidden_states_a = self.to_qkv_mlp(norm_hidden_states, **kwargs) attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] - attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list, **kwargs) mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) - hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) + hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a, **kwargs) hidden_states_a = residual + hidden_states_a return hidden_states_a, hidden_states_b @@ -267,14 +311,13 @@ class AdaLayerNormContinuous(torch.nn.Module): self.linear = torch.nn.Linear(dim, dim * 2, bias=True) self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False) - def forward(self, x, conditioning): - emb = self.linear(self.silu(conditioning)) + def forward(self, x, conditioning, **kwargs): + emb = self.linear(self.silu(conditioning),**kwargs) scale, shift = torch.chunk(emb, 2, dim=1) x = self.norm(x) * (1 + scale)[:, None] + shift[:, None] return x - class FluxDiT(torch.nn.Module): def __init__(self, disable_guidance_embedder=False): super().__init__() @@ -282,6 +325,8 @@ class FluxDiT(torch.nn.Module): self.time_embedder = TimestepEmbeddings(256, 3072) self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + + # self.pooled_text_embedder = AutoSequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) self.context_embedder = torch.nn.Linear(4096, 3072) self.x_embedder = torch.nn.Linear(64, 3072) @@ -428,12 +473,12 @@ class FluxDiT(torch.nn.Module): height, width = hidden_states.shape[-2:] hidden_states = self.patchify(hidden_states) - hidden_states = self.x_embedder(hidden_states) + hidden_states = self.x_embedder(hidden_states,**kwargs) if entity_prompt_emb is not None and entity_masks is not None: prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) else: - prompt_emb = self.context_embedder(prompt_emb) + prompt_emb = self.context_embedder(prompt_emb, **kwargs) image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) attention_mask = None @@ -446,26 +491,26 @@ class FluxDiT(torch.nn.Module): if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) for block in self.single_blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs) hidden_states = hidden_states[:, prompt_emb.shape[1]:] - hidden_states = self.final_norm_out(hidden_states, conditioning) - hidden_states = self.final_proj_out(hidden_states) + hidden_states = self.final_norm_out(hidden_states, conditioning, **kwargs) + hidden_states = self.final_proj_out(hidden_states, **kwargs) hidden_states = self.unpatchify(hidden_states, height, width) return hidden_states @@ -606,6 +651,10 @@ class FluxDiTStateDictConverter: for name, param in state_dict.items(): if name.endswith(".weight") or name.endswith(".bias"): suffix = ".weight" if name.endswith(".weight") else ".bias" + if "lora_B" in name: + suffix = ".lora_B" + suffix + if "lora_A" in name: + suffix = ".lora_A" + suffix prefix = name[:-len(suffix)] if prefix in global_rename_dict: state_dict_[global_rename_dict[prefix] + suffix] = param @@ -630,29 +679,73 @@ class FluxDiTStateDictConverter: for name in list(state_dict_.keys()): if "single_blocks." in name and ".a_to_q." in name: mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: - mlp = torch.zeros(4 * state_dict_[name].shape[0], + dim = 4 + if 'lora_A' in name: + dim = 1 + mlp = torch.zeros(dim * state_dict_[name].shape[0], *state_dict_[name].shape[1:], dtype=state_dict_[name].dtype) else: + # print('$$'*10) + # mlp_name = name.replace(".a_to_q.", ".proj_in_besides_attn.") + # print(f'mlp name: {mlp_name}') + # print(f'mlp shape: {mlp.shape}') state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) - param = torch.concat([ - state_dict_.pop(name), - state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), - state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), - mlp, - ], dim=0) + # print(f'mlp shape: {mlp.shape}') + if 'lora_A' in name: + + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + elif 'lora_B' in name: + # create zreo matrix + d, r = state_dict_[name].shape + # print('--'*10) + # print(d, r) + param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device) + param[:d, :r] = state_dict_.pop(name) + param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")) + param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")) + param[3*d:, 3*r:] = mlp + else: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") state_dict_[name_] = param for name in list(state_dict_.keys()): for component in ["a", "b"]: if f".{component}_to_q." in name: name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") - param = torch.concat([ - state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], - state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], - state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], - ], dim=0) + concat_dim = 0 + if 'lora_A' in name: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + elif 'lora_B' in name: + origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + d, r = origin.shape + # print(d, r) + param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device) + param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")] + param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")] + else: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) state_dict_[name_] = param state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) @@ -718,22 +811,48 @@ class FluxDiTStateDictConverter: "norm.query_norm.scale": "norm_q_a.weight", } state_dict_ = {} + + for name, param in state_dict.items(): + # match lora load + l_name = '' + if 'lora_A' in name : + l_name = 'lora_A' + if 'lora_B' in name : + l_name = 'lora_B' + if l_name != '': + name = name.replace(l_name+'.', '') + + if name.startswith("model.diffusion_model."): name = name[len("model.diffusion_model."):] names = name.split(".") if name in rename_dict: rename = rename_dict[name] if name.startswith("final_layer.adaLN_modulation.1."): - param = torch.concat([param[3072:], param[:3072]], dim=0) - state_dict_[rename] = param + if l_name == 'lora_A': + param = torch.concat([param[:,3072:], param[:,:3072]], dim=1) + else: + param = torch.concat([param[3072:], param[:3072]], dim=0) + if l_name != '': + state_dict_[rename.replace('weight',l_name+'.weight')] = param + else: + state_dict_[rename] = param + elif names[0] == "double_blocks": rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] - state_dict_[rename] = param + if l_name != '': + state_dict_[rename.replace('weight',l_name+'.weight')] = param + else: + state_dict_[rename] = param + elif names[0] == "single_blocks": if ".".join(names[2:]) in suffix_rename_dict: rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] - state_dict_[rename] = param + if l_name != '': + state_dict_[rename.replace('weight',l_name+'.weight')] = param + else: + state_dict_[rename] = param else: pass if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_: diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index 7d4f52d..c90198e 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -26,6 +26,12 @@ class LoRAFromCivitai: return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha) return self.convert_state_dict_AB(state_dict, lora_prefix, alpha) + def convert_state_name(self, state_dict, lora_prefix="lora_unet_", alpha=1.0): + for key in state_dict: + if ".lora_up" in key: + return self.convert_state_name_up_down(state_dict, lora_prefix, alpha) + return self.convert_state_name_AB(state_dict, lora_prefix, alpha) + def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0): renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "") @@ -50,13 +56,37 @@ class LoRAFromCivitai: return state_dict_ + def convert_state_name_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0): + renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "") + state_dict_ = {} + for key in state_dict: + if ".lora_up" not in key: + continue + if not key.startswith(lora_prefix): + continue + weight_up = state_dict[key].to(device="cuda", dtype=torch.float16) + weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32) + weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32) + target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight" + for special_key in self.special_keys: + target_name = target_name.replace(special_key, self.special_keys[special_key]) + + state_dict_[target_name.replace(".weight",".lora_B.weight")] = weight_up.cpu() + state_dict_[target_name.replace(".weight",".lora_A.weight")] = weight_down.cpu() + return state_dict_ + + def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16): state_dict_ = {} + for key in state_dict: if ".lora_B." not in key: continue if not key.startswith(lora_prefix): continue + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) if len(weight_up.shape) == 4: @@ -67,11 +97,39 @@ class LoRAFromCivitai: lora_weight = alpha * torch.mm(weight_up, weight_down) keys = key.split(".") keys.pop(keys.index("lora_B")) + target_name = ".".join(keys) + target_name = target_name[len(lora_prefix):] + state_dict_[target_name] = lora_weight.cpu() return state_dict_ + def convert_state_name_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16): + state_dict_ = {} + + for key in state_dict: + if ".lora_B." not in key: + continue + if not key.startswith(lora_prefix): + continue + + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) + weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + + keys = key.split(".") + keys.pop(keys.index("lora_B")) + + target_name = ".".join(keys) + + target_name = target_name[len(lora_prefix):] + + state_dict_[target_name.replace(".weight",".lora_B.weight")] = weight_up.cpu() + state_dict_[target_name.replace(".weight",".lora_A.weight")] = weight_down.cpu() + return state_dict_ def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None): state_dict_model = model.state_dict() @@ -100,13 +158,16 @@ class LoRAFromCivitai: for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes): if not isinstance(model, model_class): continue + # print(f'lora_prefix: {lora_prefix}') state_dict_model = model.state_dict() for model_resource in ["diffusers", "civitai"]: try: state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0) + # print(f'after convert_state_dict lora state_dict:{state_dict_lora_.keys()}') converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \ else model.__class__.state_dict_converter().from_civitai state_dict_lora_ = converter_fn(state_dict_lora_) + # print(f'after converter_fn lora state_dict:{state_dict_lora_.keys()}') if isinstance(state_dict_lora_, tuple): state_dict_lora_ = state_dict_lora_[0] if len(state_dict_lora_) == 0: @@ -120,7 +181,35 @@ class LoRAFromCivitai: pass return None + def get_converted_lora_state_dict(self, model, state_dict_lora): + for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes): + if not isinstance(model, model_class): + continue + state_dict_model = model.state_dict() + for model_resource in ["diffusers","civitai"]: + try: + state_dict_lora_ = self.convert_state_name(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0) + + converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == 'diffusers' \ + else model.__class__.state_dict_converter().from_civitai + state_dict_lora_ = converter_fn(state_dict_lora_) + + if isinstance(state_dict_lora_, tuple): + state_dict_lora_ = state_dict_lora_[0] + + if len(state_dict_lora_) == 0: + continue + # return state_dict_lora_ + for name in state_dict_lora_: + if name.replace('.lora_B','').replace('.lora_A','') not in state_dict_model: + print(f" lora's {name} is not in model.") + break + else: + return state_dict_lora_ + except Exception as e: + print(f"error {str(e)}") + return None class SDLoRAFromCivitai(LoRAFromCivitai): def __init__(self): @@ -195,73 +284,85 @@ class FluxLoRAFromCivitai(LoRAFromCivitai): "txt.mod": "txt_mod", } - - + class GeneralLoRAFromPeft: def __init__(self): self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel] - - - def get_name_dict(self, lora_state_dict): - lora_name_dict = {} - for key in lora_state_dict: + + + def fetch_device_dtype_from_state_dict(self, state_dict): + device, torch_dtype = None, None + for name, param in state_dict.items(): + device, torch_dtype = param.device, param.dtype + break + return device, torch_dtype + + + def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}): + device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict) + if torch_dtype == torch.float8_e4m3fn: + torch_dtype = torch.float32 + state_dict_ = {} + for key in state_dict: if ".lora_B." not in key: continue + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) + weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_up, weight_down) keys = key.split(".") if len(keys) > keys.index("lora_B") + 2: keys.pop(keys.index("lora_B") + 1) keys.pop(keys.index("lora_B")) - if keys[0] == "diffusion_model": - keys.pop(0) target_name = ".".join(keys) - lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A.")) - return lora_name_dict + if target_name.startswith("diffusion_model."): + target_name = target_name[len("diffusion_model."):] + if target_name not in target_state_dict: + return {} + state_dict_[target_name] = lora_weight.cpu() + return state_dict_ - - def match(self, model: torch.nn.Module, state_dict_lora): - lora_name_dict = self.get_name_dict(state_dict_lora) - model_name_dict = {name: None for name, _ in model.named_parameters()} - matched_num = sum([i in model_name_dict for i in lora_name_dict]) - if matched_num == len(lora_name_dict): - return "", "" - else: - return None - - - def fetch_device_and_dtype(self, state_dict): - device, dtype = None, None - for name, param in state_dict.items(): - device, dtype = param.device, param.dtype - break - computation_device = device - computation_dtype = dtype - if computation_device == torch.device("cpu"): - if torch.cuda.is_available(): - computation_device = torch.device("cuda") - if computation_dtype == torch.float8_e4m3fn: - computation_dtype = torch.float32 - return device, dtype, computation_device, computation_dtype - def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""): state_dict_model = model.state_dict() - device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model) - lora_name_dict = self.get_name_dict(state_dict_lora) - for name in lora_name_dict: - weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype) - weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype) - if len(weight_up.shape) == 4: - weight_up = weight_up.squeeze(3).squeeze(2) - weight_down = weight_down.squeeze(3).squeeze(2) - weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - else: - weight_lora = alpha * torch.mm(weight_up, weight_down) - weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype) - weight_patched = weight_model + weight_lora - state_dict_model[name] = weight_patched.to(device=device, dtype=dtype) - print(f" {len(lora_name_dict)} tensors are updated.") - model.load_state_dict(state_dict_model) + state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model) + if len(state_dict_lora) > 0: + print(f" {len(state_dict_lora)} tensors are updated.") + for name in state_dict_lora: + if state_dict_model[name].dtype == torch.float8_e4m3fn: + weight = state_dict_model[name].to(torch.float32) + lora_weight = state_dict_lora[name].to( + dtype=torch.float32, + device=state_dict_model[name].device + ) + state_dict_model[name] = (weight + lora_weight).to( + dtype=state_dict_model[name].dtype, + device=state_dict_model[name].device + ) + else: + state_dict_model[name] += state_dict_lora[name].to( + dtype=state_dict_model[name].dtype, + device=state_dict_model[name].device + ) + model.load_state_dict(state_dict_model) + + def match(self, model, state_dict_lora): + for model_class in self.supported_model_classes: + if not isinstance(model, model_class): + continue + state_dict_model = model.state_dict() + try: + state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model) + if len(state_dict_lora_) > 0: + return "", "" + except: + pass + return None class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai): diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 7303dff..39b33f1 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -13,7 +13,7 @@ from transformers import SiglipVisionModel from copy import deepcopy from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense from ..models.flux_dit import RMSNorm -from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear +from ..vram_management import enable_vram_management, enable_auto_lora, AutoLoRALinear, AutoWrappedModule, AutoWrappedLinear class FluxImagePipeline(BasePipeline): @@ -132,6 +132,15 @@ class FluxImagePipeline(BasePipeline): ) self.enable_cpu_offload() + def enable_auto_lora(self): + enable_auto_lora( + self.dit, + module_map={ + RMSNorm: AutoWrappedModule, + torch.nn.Linear: AutoLoRALinear, + }, + name_prefix='' + ) def denoising_model(self): return self.dit @@ -391,6 +400,8 @@ class FluxImagePipeline(BasePipeline): # Progress bar progress_bar_cmd=tqdm, progress_bar_st=None, + lora_state_dicts=[], + lora_alpahs=[] ): height, width = self.check_resize_height_width(height, width) @@ -430,6 +441,8 @@ class FluxImagePipeline(BasePipeline): inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, + lora_state_dicts=lora_state_dicts, + lora_alpahs = lora_alpahs, **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, ) noise_pred_posi = self.control_noise_via_local_prompts( @@ -447,6 +460,8 @@ class FluxImagePipeline(BasePipeline): noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, + lora_state_dicts=lora_state_dicts, + lora_alpahs = lora_alpahs, **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) @@ -511,7 +526,6 @@ class TeaCache: hidden_states = hidden_states + self.previous_residual return hidden_states - def lets_dance_flux( dit: FluxDiT, controlnet: FluxMultiControlNetManager = None, @@ -532,6 +546,7 @@ def lets_dance_flux( tea_cache: TeaCache = None, **kwargs ): + if tiled: def flux_forward_fn(hl, hr, wl, wr): tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None @@ -613,7 +628,8 @@ def lets_dance_flux( conditioning, image_rotary_emb, attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None), + **kwargs ) # ControlNet if controlnet is not None and controlnet_frames is not None: @@ -629,7 +645,8 @@ def lets_dance_flux( conditioning, image_rotary_emb, attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), + **kwargs ) # ControlNet if controlnet is not None and controlnet_frames is not None: @@ -639,8 +656,8 @@ def lets_dance_flux( if tea_cache is not None: tea_cache.store(hidden_states) - hidden_states = dit.final_norm_out(hidden_states, conditioning) - hidden_states = dit.final_proj_out(hidden_states) + hidden_states = dit.final_norm_out(hidden_states, conditioning, **kwargs) + hidden_states = dit.final_proj_out(hidden_states, **kwargs) hidden_states = dit.unpatchify(hidden_states, height, width) return hidden_states diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index a9df39e..f3d7b7c 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -70,6 +70,52 @@ class AutoWrappedLinear(torch.nn.Linear): bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) return torch.nn.functional.linear(x, weight, bias) +class AutoLoRALinear(torch.nn.Linear): + def __init__(self, name='', in_features=1, out_features=2, bias = True, device=None, dtype=None): + super().__init__(in_features, out_features, bias, device, dtype) + self.name = name + + def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], **kwargs): + out = torch.nn.functional.linear(x, self.weight, self.bias) + lora_a_name = f'{self.name}.lora_A.weight' + lora_b_name = f'{self.name}.lora_B.weight' + + for i, lora_state_dict in enumerate(lora_state_dicts): + if lora_state_dict is None: + break + if lora_a_name in lora_state_dict and lora_b_name in lora_state_dict: + lora_A = lora_state_dict[lora_a_name].to(dtype=self.weight.dtype,device=self.weight.device) + lora_B = lora_state_dict[lora_b_name].to(dtype=self.weight.dtype,device=self.weight.device) + out_lora = x @ lora_A.T @ lora_B.T + out = out + out_lora * lora_alpahs[i] + return out + +def enable_auto_lora(model:torch.nn.Module, module_map: dict, name_prefix=''): + targets = list(module_map.keys()) + for name, module in model.named_children(): + if name_prefix != '': + full_name = name_prefix + '.' + name + else: + full_name = name + if isinstance(module,targets[1]): + # print(full_name) + # print(module) + # ToDo: replace the linear to the AutoLoRALinear + new_module = AutoLoRALinear( + name=full_name, + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + device=module.weight.device, + dtype=module.weight.dtype) + new_module.weight.data.copy_(module.weight.data) + new_module.bias.data.copy_(module.bias.data) + setattr(model, name, new_module) + elif isinstance(module, targets[0]): + pass + else: + enable_auto_lora(module, module_map, full_name) + def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): for name, module in model.named_children():