mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
zimagei2l
This commit is contained in:
112
diffsynth/models/z_image_image2lora.py
Normal file
112
diffsynth/models/z_image_image2lora.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
|
||||
|
||||
|
||||
class LoRATrainerBlock(torch.nn.Module):
|
||||
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
self.lora_patterns = lora_patterns
|
||||
self.block_id = block_id
|
||||
self.layers = []
|
||||
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
|
||||
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
|
||||
self.layers = torch.nn.ModuleList(self.layers)
|
||||
if use_residual:
|
||||
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
|
||||
else:
|
||||
self.proj_residual = None
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
if self.proj_residual is not None: residual = self.proj_residual(residual)
|
||||
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
|
||||
name = lora_pattern[0]
|
||||
lora_a, lora_b = layer(x, residual=residual)
|
||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
|
||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
|
||||
return lora
|
||||
|
||||
|
||||
class ZImageImage2LoRAComponent(torch.nn.Module):
|
||||
def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||
super().__init__()
|
||||
self.lora_patterns = lora_patterns
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks = []
|
||||
for lora_patterns in self.lora_patterns:
|
||||
for block_id in range(self.num_blocks):
|
||||
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))
|
||||
self.blocks = torch.nn.ModuleList(self.blocks)
|
||||
self.residual_scale = 0.05
|
||||
self.use_residual = use_residual
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
if residual is not None:
|
||||
if self.use_residual:
|
||||
residual = residual * self.residual_scale
|
||||
else:
|
||||
residual = None
|
||||
lora = {}
|
||||
for block in self.blocks:
|
||||
lora.update(block(x, residual))
|
||||
return lora
|
||||
|
||||
|
||||
class ZImageImage2LoRAModel(torch.nn.Module):
|
||||
def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||
super().__init__()
|
||||
lora_patterns = [
|
||||
[
|
||||
("attention.to_q", 3840, 3840),
|
||||
("attention.to_k", 3840, 3840),
|
||||
("attention.to_v", 3840, 3840),
|
||||
("attention.to_out.0", 3840, 3840),
|
||||
],
|
||||
[
|
||||
("feed_forward.w1", 3840, 10240),
|
||||
("feed_forward.w2", 10240, 3840),
|
||||
("feed_forward.w3", 3840, 10240),
|
||||
],
|
||||
]
|
||||
config = {
|
||||
"lora_patterns": lora_patterns,
|
||||
"use_residual": use_residual,
|
||||
"compress_dim": compress_dim,
|
||||
"rank": rank,
|
||||
"residual_length": residual_length,
|
||||
"residual_mid_dim": residual_mid_dim,
|
||||
}
|
||||
self.layers_lora = ZImageImage2LoRAComponent(
|
||||
prefix="layers",
|
||||
num_blocks=30,
|
||||
**config,
|
||||
)
|
||||
self.context_refiner_lora = ZImageImage2LoRAComponent(
|
||||
prefix="context_refiner",
|
||||
num_blocks=2,
|
||||
**config,
|
||||
)
|
||||
self.noise_refiner_lora = ZImageImage2LoRAComponent(
|
||||
prefix="noise_refiner",
|
||||
num_blocks=2,
|
||||
**config,
|
||||
)
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
lora.update(self.layers_lora(x, residual=residual))
|
||||
lora.update(self.context_refiner_lora(x, residual=residual))
|
||||
lora.update(self.noise_refiner_lora(x, residual=residual))
|
||||
return lora
|
||||
|
||||
def initialize_weights(self):
|
||||
state_dict = self.state_dict()
|
||||
for name in state_dict:
|
||||
if ".proj_a." in name:
|
||||
state_dict[name] = state_dict[name] * 0.3
|
||||
elif ".proj_b.proj_out." in name:
|
||||
state_dict[name] = state_dict[name] * 0
|
||||
elif ".proj_residual.proj_out." in name:
|
||||
state_dict[name] = state_dict[name] * 0.3
|
||||
self.load_state_dict(state_dict)
|
||||
Reference in New Issue
Block a user