Files
DiffSynth-Studio/diffsynth/models/z_image_image2lora.py
2026-01-07 20:36:53 +08:00

190 lines
7.5 KiB
Python

import torch
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
class LoRATrainerBlock(torch.nn.Module):
    """Predict LoRA A/B matrices for the linear sub-modules of one block.

    One ``ImageEmbeddingToLoraMatrix`` head is built per entry of
    ``lora_patterns``. Each entry is a ``(name, lora_a_dim, lora_b_dim)`` tuple
    naming a sub-module of block ``block_id`` under ``prefix``.

    Args:
        lora_patterns: sequence of ``(name, lora_a_dim, lora_b_dim)`` tuples.
        in_dim: input size passed to every head.
        compress_dim: compressed size passed to every head; it is also the
            output size of the optional residual projection.
        rank: LoRA rank passed to every head.
        block_id: block index, used only to build the output keys.
        use_residual: if True, build ``proj_residual`` (a ``SequencialMLP``) to
            project the residual input before it reaches the heads.
        residual_length, residual_dim, residual_mid_dim: constructor arguments
            for the residual ``SequencialMLP``.
        prefix: leading component of every produced key.
    """

    def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
        super().__init__()
        self.prefix = prefix
        self.lora_patterns = lora_patterns
        self.block_id = block_id
        self.layers = torch.nn.ModuleList([
            ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)
            for _name, lora_a_dim, lora_b_dim in self.lora_patterns
        ])
        if use_residual:
            self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
        else:
            self.proj_residual = None

    def forward(self, x, residual=None):
        """Run every head and collect its A/B outputs.

        Args:
            x: input forwarded unchanged to every head.
            residual: optional residual input. It is projected by
                ``proj_residual`` when that module exists AND a residual was
                supplied; otherwise it is passed through as-is (possibly None).

        Returns:
            dict mapping ``"{prefix}.{block_id}.{name}.lora_A.default.weight"``
            and ``"...lora_B.default.weight"`` to the tensors returned by the
            head for ``name``.
        """
        # Callers may pass residual=None even when use_residual=True; projecting
        # None would fail, so only project an actual residual tensor.
        if self.proj_residual is not None and residual is not None:
            residual = self.proj_residual(residual)
        lora = {}
        for lora_pattern, layer in zip(self.lora_patterns, self.layers):
            name = lora_pattern[0]
            lora_a, lora_b = layer(x, residual=residual)
            key_prefix = f"{self.prefix}.{self.block_id}.{name}"
            lora[f"{key_prefix}.lora_A.default.weight"] = lora_a
            lora[f"{key_prefix}.lora_B.default.weight"] = lora_b
        return lora
class ZImageImage2LoRAComponent(torch.nn.Module):
    """Collection of ``LoRATrainerBlock`` modules sharing one key prefix.

    For every pattern group in ``lora_patterns`` and every block index in
    ``range(num_blocks)``, one ``LoRATrainerBlock`` is created. The order is
    group-major, so ``blocks[g * num_blocks + i]`` serves group ``g`` of block ``i``.

    Args:
        lora_patterns: list of pattern groups; each group is a sequence of
            ``(name, lora_a_dim, lora_b_dim)`` tuples.
        prefix: key prefix forwarded to every block.
        num_blocks: number of block indices to cover.
        use_residual: whether blocks build a residual projection and whether
            ``forward`` keeps the residual input.
        compress_dim, rank, residual_length, residual_mid_dim: forwarded to
            every ``LoRATrainerBlock``.
    """

    def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
        super().__init__()
        self.lora_patterns = lora_patterns
        self.num_blocks = num_blocks
        # Construction order fixes the "blocks.<index>" state-dict keys; keep it group-major.
        self.blocks = torch.nn.ModuleList([
            LoRATrainerBlock(
                pattern_group,
                block_id=block_index,
                use_residual=use_residual,
                compress_dim=compress_dim,
                rank=rank,
                residual_length=residual_length,
                residual_mid_dim=residual_mid_dim,
                prefix=prefix,
            )
            for pattern_group in self.lora_patterns
            for block_index in range(self.num_blocks)
        ])
        self.residual_scale = 0.05
        self.use_residual = use_residual

    def forward(self, x, residual=None):
        """Merge the LoRA dictionaries produced by every block.

        A supplied ``residual`` is multiplied by ``residual_scale`` when
        ``use_residual`` is set and discarded (replaced by None) otherwise.
        """
        if residual is not None:
            residual = residual * self.residual_scale if self.use_residual else None
        merged = {}
        for block in self.blocks:
            merged.update(block(x, residual))
        return merged
class ZImageImage2LoRAModel(torch.nn.Module):
    """Image-to-LoRA generator covering the main layers and both refiners.

    Three ``ZImageImage2LoRAComponent`` instances share one configuration:
    ``layers`` (30 blocks), ``context_refiner`` (2 blocks) and ``noise_refiner``
    (2 blocks). Each block targets the attention projections and the three
    feed-forward linears.

    Args:
        use_residual: forwarded to every component.
        compress_dim, rank, residual_length, residual_mid_dim: forwarded to
            every component.
    """

    def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
        super().__init__()
        dim, ffn_dim = 3840, 10240
        attention_group = [
            (f"attention.{proj_name}", dim, dim)
            for proj_name in ("to_q", "to_k", "to_v", "to_out.0")
        ]
        feed_forward_group = [
            ("feed_forward.w1", dim, ffn_dim),
            ("feed_forward.w2", ffn_dim, dim),
            ("feed_forward.w3", dim, ffn_dim),
        ]
        shared_kwargs = dict(
            lora_patterns=[attention_group, feed_forward_group],
            use_residual=use_residual,
            compress_dim=compress_dim,
            rank=rank,
            residual_length=residual_length,
            residual_mid_dim=residual_mid_dim,
        )
        self.layers_lora = ZImageImage2LoRAComponent(prefix="layers", num_blocks=30, **shared_kwargs)
        self.context_refiner_lora = ZImageImage2LoRAComponent(prefix="context_refiner", num_blocks=2, **shared_kwargs)
        self.noise_refiner_lora = ZImageImage2LoRAComponent(prefix="noise_refiner", num_blocks=2, **shared_kwargs)

    def forward(self, x, residual=None):
        """Return the merged LoRA dict of the three components (same x and residual for all)."""
        components = (self.layers_lora, self.context_refiner_lora, self.noise_refiner_lora)
        merged = {}
        for component in components:
            merged.update(component(x, residual=residual))
        return merged

    def initialize_weights(self):
        """Rescale selected state-dict entries and load them back.

        The substrings below refer to parameter names inside the imported head /
        MLP modules (not visible here). The first matching rule wins, as in an
        if/elif chain; unmatched entries are left unchanged.
        """
        scale_rules = (
            (".proj_a.", 0.3),
            (".proj_b.proj_out.", 0),
            (".proj_residual.proj_out.", 0.3),
        )
        params = self.state_dict()
        for key in list(params):
            factor = next((value for pattern, value in scale_rules if pattern in key), None)
            if factor is not None:
                params[key] = params[key] * factor
        self.load_state_dict(params)
class ImageEmb2LoRAWeightCompressed(torch.nn.Module):
    """Produce one LoRA A/B pair from a single embedding via a rank x rank mixing matrix.

    ``lora_a`` (rank, in_dim) and ``lora_b`` (out_dim, rank) are learned
    directly. ``proj`` maps the embedding to ``rank * rank`` values, which are
    reshaped into a mixing matrix applied to ``lora_a``.

    Args:
        in_dim: column count of the produced A matrix.
        out_dim: row count of the produced B matrix.
        emb_dim: size of the input embedding.
        rank: LoRA rank.
    """

    def __init__(self, in_dim, out_dim, emb_dim, rank):
        super().__init__()
        # Creation order (lora_a, lora_b, proj) determines RNG consumption; keep it.
        self.lora_a = torch.nn.Parameter(torch.randn(rank, in_dim))
        self.lora_b = torch.nn.Parameter(torch.randn(out_dim, rank))
        self.proj = torch.nn.Linear(in_features=emb_dim, out_features=rank * rank, bias=True)
        self.rank = rank

    def forward(self, x):
        """Return ``(mixing @ lora_a, lora_b)``.

        ``x`` must hold exactly one embedding (e.g. shape ``(emb_dim,)`` or
        ``(1, emb_dim)``) so the projection has exactly rank*rank elements.
        """
        r = self.rank
        mixing = self.proj(x).view(r, r)
        return torch.matmul(mixing, self.lora_a), self.lora_b
class ZImageImage2LoRAModelCompressed(torch.nn.Module):
    """Map one embedding to LoRA weights for every targeted linear layer.

    One ``ImageEmb2LoRAWeightCompressed`` module is created per
    (prefix, layer index, target layer) combination. There are 30 ``layers``
    entries plus 2 ``context_refiner`` and 2 ``noise_refiner`` entries, each with
    the seven attention / feed-forward targets listed in ``__init__``.

    Args:
        emb_dim: size of the embedding fed to every sub-module.
        rank: LoRA rank of every produced A/B pair.
    """

    def __init__(self, emb_dim=1536+4096, rank=32):
        super().__init__()
        target_layers = [
            ("attention.to_q", 3840, 3840),
            ("attention.to_k", 3840, 3840),
            ("attention.to_v", 3840, 3840),
            ("attention.to_out.0", 3840, 3840),
            ("feed_forward.w1", 3840, 10240),
            ("feed_forward.w2", 10240, 3840),
            ("feed_forward.w3", 3840, 10240),
        ]
        self.lora_patterns = [
            {
                "prefix": "layers",
                "num_layers": 30,
                "target_layers": target_layers,
            },
            {
                "prefix": "context_refiner",
                "num_layers": 2,
                "target_layers": target_layers,
            },
            {
                "prefix": "noise_refiner",
                "num_layers": 2,
                "target_layers": target_layers,
            },
        ]
        module_dict = {}
        for lora_pattern in self.lora_patterns:
            prefix = lora_pattern["prefix"]
            num_layers = lora_pattern["num_layers"]
            targets = lora_pattern["target_layers"]
            for layer_id in range(num_layers):
                for layer_name, in_dim, out_dim in targets:
                    # ModuleDict keys may not contain "."; encode dots as "___"
                    # and decode them again in forward().
                    name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___")
                    module_dict[name] = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank)
        self.module_dict = torch.nn.ModuleDict(module_dict)

    def forward(self, x, residual=None):
        """Run every sub-module on ``x`` and collect the results.

        ``residual`` is accepted for call compatibility but is not used.

        Returns:
            dict mapping ``"<prefix>.<layer_id>.<layer_name>.lora_A.default.weight"``
            and the matching ``lora_B`` key to each sub-module's outputs.
        """
        lora = {}
        for encoded_name, module in self.module_dict.items():
            name = encoded_name.replace("___", ".")
            lora_a, lora_b = module(x)
            lora[f"{name}.lora_A.default.weight"] = lora_a
            lora[f"{name}.lora_B.default.weight"] = lora_b
        return lora

    def initialize_weights(self):
        """Rescale the random initialisation and load it back.

        ``lora_b`` entries are zeroed, ``lora_a`` entries and ``proj.weight`` are
        scaled by 0.2, and everything else (e.g. ``proj.bias``) is unchanged.
        """
        state_dict = self.state_dict()
        for name in state_dict:
            if "lora_b" in name:
                state_dict[name] = state_dict[name] * 0
            elif "lora_a" in name:
                state_dict[name] = state_dict[name] * 0.2
            elif "proj.weight" in name:
                # Previously printed every matching name here (debug leftover).
                state_dict[name] = state_dict[name] * 0.2
        self.load_state_dict(state_dict)