mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
support i2L (image-to-LoRA)
This commit is contained in:
94
diffsynth/models/dinov3_image_encoder.py
Normal file
94
diffsynth/models/dinov3_image_encoder.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
|
||||
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
||||
import torch
|
||||
|
||||
|
||||
class DINOv3ImageEncoder(DINOv3ViTModel):
    """DINOv3 ViT image encoder with a bundled preprocessor.

    Instantiates the backbone from a hard-coded large-scale config
    (hidden 4096, 40 layers, 32 heads, 224x224 input, rotary position
    embeddings) and exposes a forward pass that preprocesses a raw image
    and returns the pooled first-token embedding.
    """

    def __init__(self):
        # Hard-coded config matching the released checkpoint this encoder
        # is expected to be loaded from (weights come from load_state_dict).
        config = DINOv3ViTConfig(
            architectures=["DINOv3ViTModel"],
            attention_dropout=0.0,
            drop_path_rate=0.0,
            dtype="float32",
            hidden_act="silu",
            hidden_size=4096,
            image_size=224,
            initializer_range=0.02,
            intermediate_size=8192,
            key_bias=False,
            layer_norm_eps=1e-05,
            layerscale_value=1.0,
            mlp_bias=True,
            model_type="dinov3_vit",
            num_attention_heads=32,
            num_channels=3,
            num_hidden_layers=40,
            num_register_tokens=4,
            patch_size=16,
            pos_embed_jitter=None,
            pos_embed_rescale=2.0,
            pos_embed_shift=None,
            proj_bias=True,
            query_bias=False,
            rope_theta=100.0,
            transformers_version="4.56.1",
            use_gated_mlp=True,
            value_bias=False,
        )
        super().__init__(config)
        # Preprocessing: resize to 224x224 (resample=2), rescale by 1/255,
        # normalize with ImageNet mean/std, channels-first output.
        self.processor = DINOv3ViTImageProcessorFast(
            crop_size=None,
            data_format="channels_first",
            default_to_square=True,
            device=None,
            disable_grouping=None,
            do_center_crop=None,
            do_convert_rgb=None,
            do_normalize=True,
            do_rescale=True,
            do_resize=True,
            image_mean=[0.485, 0.456, 0.406],
            image_processor_type="DINOv3ViTImageProcessorFast",
            image_std=[0.229, 0.224, 0.225],
            input_data_format=None,
            resample=2,
            rescale_factor=0.00392156862745098,
            return_tensors=None,
            size={"height": 224, "width": 224},
        )

    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
        """Preprocess ``image`` and return its pooled embedding.

        Args:
            image: image input accepted by the processor (e.g. a PIL image).
            torch_dtype: dtype to run the backbone in.
            device: device to run the backbone on.

        Returns:
            Tensor of shape ``(batch, hidden_size)`` — the normalized
            first-token (CLS) output.
        """
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
        bool_masked_pos = None
        head_mask = None

        # Fix: the original cast pixel_values to torch_dtype a second time
        # here; the .to() above already handles both dtype and device.
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)

        for i, layer_module in enumerate(self.layer):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(
                hidden_states,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.norm(hidden_states)
        # Token 0 is the CLS token; register and patch tokens are dropped.
        pooled_output = sequence_output[:, 0, :]

        return pooled_output
|
||||
128
diffsynth/models/qwen_image_image2lora.py
Normal file
128
diffsynth/models/qwen_image_image2lora.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch
|
||||
|
||||
|
||||
class CompressedMLP(torch.nn.Module):
    """Two-layer bottleneck MLP with an optional additive residual applied
    at the bottleneck: out = proj_out(proj_in(x) [+ residual])."""

    def __init__(self, in_dim, mid_dim, out_dim, bias=False):
        super().__init__()
        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
        self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)

    def forward(self, x, residual=None):
        hidden = self.proj_in(x)
        if residual is not None:
            hidden = hidden + residual
        return self.proj_out(hidden)
|
||||
|
||||
|
||||
class ImageEmbeddingToLoraMatrix(torch.nn.Module):
    """Predicts one (lora_A, lora_B) weight pair from an image embedding.

    Two independent bottleneck MLPs map the embedding to flat vectors,
    which are reshaped into the rank-decomposed LoRA matrices.
    """

    def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
        super().__init__()
        self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
        self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
        self.lora_a_dim = lora_a_dim
        self.lora_b_dim = lora_b_dim
        self.rank = rank

    def forward(self, x, residual=None):
        # lora_A is (rank, in_features); lora_B is (out_features, rank).
        flat_a = self.proj_a(x, residual)
        flat_b = self.proj_b(x, residual)
        return (
            flat_a.view(self.rank, self.lora_a_dim),
            flat_b.view(self.lora_b_dim, self.rank),
        )
|
||||
|
||||
|
||||
class SequencialMLP(torch.nn.Module):
    """Maps a fixed-length token sequence to one vector: a shared per-token
    linear projection, then a linear layer over the flattened sequence."""

    def __init__(self, length, in_dim, mid_dim, out_dim, bias=False):
        super().__init__()
        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
        self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias)
        self.length = length
        self.in_dim = in_dim
        self.mid_dim = mid_dim

    def forward(self, x):
        tokens = x.view(self.length, self.in_dim)
        projected = self.proj_in(tokens)
        flattened = projected.view(1, self.length * self.mid_dim)
        return self.proj_out(flattened)
|
||||
|
||||
|
||||
class LoRATrainerBlock(torch.nn.Module):
    """Generates the LoRA weights for one transformer block.

    Holds one ImageEmbeddingToLoraMatrix head per targeted layer in
    ``lora_patterns`` (tuples of ``(name, lora_a_dim, lora_b_dim)``) and
    emits a dict keyed in the PEFT naming scheme
    ``transformer_blocks.{block_id}.{name}.lora_{A,B}.default.weight``.
    """

    def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024):
        super().__init__()
        self.lora_patterns = lora_patterns
        self.block_id = block_id
        self.layers = torch.nn.ModuleList(
            ImageEmbeddingToLoraMatrix(in_dim, compress_dim, a_dim, b_dim, rank)
            for _, a_dim, b_dim in self.lora_patterns
        )
        # Optional projection of a residual token sequence down to the
        # bottleneck width shared by every head.
        self.proj_residual = (
            SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
            if use_residual
            else None
        )

    def forward(self, x, residual=None):
        if self.proj_residual is not None:
            residual = self.proj_residual(residual)
        lora = {}
        for lora_pattern, layer in zip(self.lora_patterns, self.layers):
            name = lora_pattern[0]
            lora_a, lora_b = layer(x, residual=residual)
            prefix = f"transformer_blocks.{self.block_id}.{name}"
            lora[f"{prefix}.lora_A.default.weight"] = lora_a
            lora[f"{prefix}.lora_B.default.weight"] = lora_b
        return lora
|
||||
|
||||
|
||||
class QwenImageImage2LoRAModel(torch.nn.Module):
    """Predicts a full set of per-block LoRA weights from an image embedding.

    One LoRATrainerBlock is created for every (pattern group, block_id)
    pair; ``forward`` merges their outputs into a single state-dict-style
    mapping of LoRA parameter names to tensors.
    """

    def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
        super().__init__()
        # (target layer name, lora_A in-features, lora_B out-features),
        # grouped so each group gets its own LoRATrainerBlock per DiT block.
        self.lora_patterns = [
            [
                ("attn.to_q", 3072, 3072),
                ("attn.to_k", 3072, 3072),
                ("attn.to_v", 3072, 3072),
                ("attn.to_out.0", 3072, 3072),
            ],
            [
                ("img_mlp.net.2", 3072*4, 3072),
                ("img_mod.1", 3072, 3072*6),
            ],
            [
                ("attn.add_q_proj", 3072, 3072),
                ("attn.add_k_proj", 3072, 3072),
                ("attn.add_v_proj", 3072, 3072),
                ("attn.to_add_out", 3072, 3072),
            ],
            [
                ("txt_mlp.net.2", 3072*4, 3072),
                ("txt_mod.1", 3072, 3072*6),
            ],
        ]
        self.num_blocks = num_blocks
        self.blocks = torch.nn.ModuleList(
            LoRATrainerBlock(
                patterns,
                block_id=block_id,
                use_residual=use_residual,
                compress_dim=compress_dim,
                rank=rank,
                residual_length=residual_length,
                residual_mid_dim=residual_mid_dim,
            )
            for patterns in self.lora_patterns
            for block_id in range(self.num_blocks)
        )
        self.residual_scale = 0.05
        self.use_residual = use_residual

    def forward(self, x, residual=None):
        """Return a dict mapping LoRA parameter names to predicted tensors."""
        if residual is not None:
            # Residual input is down-scaled for stability; dropped entirely
            # when the model was built without residual support.
            residual = residual * self.residual_scale if self.use_residual else None
        lora = {}
        for block in self.blocks:
            lora.update(block(x, residual))
        return lora

    def initialize_weights(self):
        """Rescale freshly-initialized weights: zero the lora_B output
        projections (so predicted LoRA deltas start at zero) and damp the
        A-side and residual projections by 0.3."""
        state_dict = self.state_dict()
        for name in state_dict:
            if ".proj_a." in name:
                state_dict[name] = state_dict[name] * 0.3
            elif ".proj_b.proj_out." in name:
                state_dict[name] = state_dict[name] * 0
            elif ".proj_residual.proj_out." in name:
                state_dict[name] = state_dict[name] * 0.3
        self.load_state_dict(state_dict)
|
||||
70
diffsynth/models/siglip2_image_encoder.py
Normal file
70
diffsynth/models/siglip2_image_encoder.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
|
||||
from transformers import SiglipImageProcessor
|
||||
import torch
|
||||
|
||||
|
||||
class Siglip2ImageEncoder(SiglipVisionTransformer):
    """SigLIP2 vision encoder with a bundled image preprocessor.

    Instantiates the vision transformer from a hard-coded config
    (hidden 1536, 40 layers, 384x384 input) and exposes a single-image
    forward pass that returns the pooled head output.
    """

    def __init__(self):
        # Hard-coded config matching the checkpoint this encoder is
        # expected to be loaded from (weights come from load_state_dict).
        config = SiglipVisionConfig(
            attention_dropout = 0.0,
            dtype = "float32",
            hidden_act = "gelu_pytorch_tanh",
            hidden_size = 1536,
            image_size = 384,
            intermediate_size = 6144,
            layer_norm_eps = 1e-06,
            model_type = "siglip_vision_model",
            num_attention_heads = 16,
            num_channels = 3,
            num_hidden_layers = 40,
            patch_size = 16,
            transformers_version = "4.56.1",
            _attn_implementation = "sdpa"
        )
        super().__init__(config)
        # Preprocessing: resize to 384x384 (resample=2), rescale by 1/255,
        # then normalize with mean/std 0.5 (maps pixels to roughly [-1, 1]).
        self.processor = SiglipImageProcessor(
            do_convert_rgb = None,
            do_normalize = True,
            do_rescale = True,
            do_resize = True,
            image_mean = [
                0.5,
                0.5,
                0.5
            ],
            image_processor_type = "SiglipImageProcessor",
            image_std = [
                0.5,
                0.5,
                0.5
            ],
            processor_class = "SiglipProcessor",
            resample = 2,
            rescale_factor = 0.00392156862745098,
            size = {
                "height": 384,
                "width": 384
            }
        )

    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
        """Preprocess a single ``image`` and return its pooled embedding.

        Args:
            image: one image (wrapped in a list for the processor).
            torch_dtype: dtype to run the backbone in.
            device: device to run the backbone on.

        Returns:
            The output of ``self.head`` applied to the post-layernorm
            hidden states, or ``None`` when the config disables the head.
        """
        pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
        # Attention maps / intermediate states are not needed here.
        output_attentions = False
        output_hidden_states = False
        interpolate_pos_encoding = False

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # NOTE(review): callers receive None if use_head is False in the
        # config — presumably the head is enabled here; verify at load time.
        pooler_output = self.head(last_hidden_state) if self.use_head else None

        return pooler_output
|
||||
Reference in New Issue
Block a user