mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Support Z-Image-Omni-Base and its related models
Support Z-Image-Omni-Base and its related models.
This commit is contained in:
@@ -534,6 +534,32 @@ z_image_series = [
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
||||
"extra_kwargs": {"use_conv_attention": False},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
|
||||
"model_name": "z_image_dit",
|
||||
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
||||
"extra_kwargs": {"siglip_feat_dim": 1152},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
|
||||
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
|
||||
"model_name": "siglip_vision_model_428m",
|
||||
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
|
||||
"model_hash": "1677708d40029ab380a95f6c731a57d7",
|
||||
"model_name": "z_image_controlnet",
|
||||
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
|
||||
},
|
||||
{
|
||||
# Example: ???
|
||||
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
|
||||
"model_name": "z_image_image2lora_style",
|
||||
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||
"extra_kwargs": {"compress_dim": 128},
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||
|
||||
@@ -195,4 +195,19 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
|
||||
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -97,6 +97,7 @@ class ModelConfig:
|
||||
self.reset_local_model_path()
|
||||
if self.require_downloading():
|
||||
self.download()
|
||||
if self.path is None:
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||
else:
|
||||
|
||||
@@ -235,6 +235,7 @@ class BasePipeline(torch.nn.Module):
|
||||
alpha=1,
|
||||
hotload=None,
|
||||
state_dict=None,
|
||||
verbose=1,
|
||||
):
|
||||
if state_dict is None:
|
||||
if isinstance(lora_config, str):
|
||||
@@ -261,12 +262,13 @@ class BasePipeline(torch.nn.Module):
|
||||
updated_num += 1
|
||||
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||
module.lora_B_weights.append(lora[lora_b_name])
|
||||
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
||||
if verbose >= 1:
|
||||
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
||||
else:
|
||||
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def clear_lora(self):
|
||||
def clear_lora(self, verbose=1):
|
||||
cleared_num = 0
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
@@ -276,7 +278,8 @@ class BasePipeline(torch.nn.Module):
|
||||
module.lora_A_weights.clear()
|
||||
if hasattr(module, "lora_B_weights"):
|
||||
module.lora_B_weights.clear()
|
||||
print(f"{cleared_num} LoRA layers are cleared.")
|
||||
if verbose >= 1:
|
||||
print(f"{cleared_num} LoRA layers are cleared.")
|
||||
|
||||
|
||||
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
|
||||
@@ -304,8 +307,13 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
||||
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
||||
if cfg_scale != 1.0:
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
|
||||
from transformers import SiglipImageProcessor
|
||||
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
||||
import torch
|
||||
|
||||
|
||||
@@ -68,3 +68,65 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
|
||||
pooler_output = self.head(last_hidden_state) if self.use_head else None
|
||||
|
||||
return pooler_output
|
||||
|
||||
|
||||
class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
||||
def __init__(self):
|
||||
config = Siglip2VisionConfig(
|
||||
attention_dropout = 0.0,
|
||||
dtype = "bfloat16",
|
||||
hidden_act = "gelu_pytorch_tanh",
|
||||
hidden_size = 1152,
|
||||
intermediate_size = 4304,
|
||||
layer_norm_eps = 1e-06,
|
||||
model_type = "siglip2_vision_model",
|
||||
num_attention_heads = 16,
|
||||
num_channels = 3,
|
||||
num_hidden_layers = 27,
|
||||
num_patches = 256,
|
||||
patch_size = 16,
|
||||
transformers_version = "4.57.1"
|
||||
)
|
||||
super().__init__(config)
|
||||
self.processor = Siglip2ImageProcessorFast(
|
||||
**{
|
||||
"data_format": "channels_first",
|
||||
"default_to_square": True,
|
||||
"device": None,
|
||||
"disable_grouping": None,
|
||||
"do_convert_rgb": None,
|
||||
"do_normalize": True,
|
||||
"do_pad": None,
|
||||
"do_rescale": True,
|
||||
"do_resize": True,
|
||||
"image_mean": [
|
||||
0.5,
|
||||
0.5,
|
||||
0.5
|
||||
],
|
||||
"image_processor_type": "Siglip2ImageProcessorFast",
|
||||
"image_std": [
|
||||
0.5,
|
||||
0.5,
|
||||
0.5
|
||||
],
|
||||
"input_data_format": None,
|
||||
"max_num_patches": 256,
|
||||
"pad_size": None,
|
||||
"patch_size": 16,
|
||||
"processor_class": "Siglip2Processor",
|
||||
"resample": 2,
|
||||
"rescale_factor": 0.00392156862745098,
|
||||
"return_tensors": None,
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
||||
siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
|
||||
shape = siglip_inputs.spatial_shapes[0]
|
||||
hidden_state = super().forward(**siglip_inputs).last_hidden_state
|
||||
B, N, C = hidden_state.shape
|
||||
hidden_state = hidden_state[:, : shape[0] * shape[1]]
|
||||
hidden_state = hidden_state.view(shape[0], shape[1], C)
|
||||
hidden_state = hidden_state.to(torch_dtype)
|
||||
return hidden_state
|
||||
|
||||
154
diffsynth/models/z_image_controlnet.py
Normal file
154
diffsynth/models/z_image_controlnet.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from .z_image_dit import ZImageTransformerBlock
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ZImageControlTransformerBlock(ZImageTransformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int = 1000,
|
||||
dim: int = 3840,
|
||||
n_heads: int = 30,
|
||||
n_kv_heads: int = 30,
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
modulation = True,
|
||||
block_id = 0
|
||||
):
|
||||
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = nn.Linear(self.dim, self.dim)
|
||||
self.after_proj = nn.Linear(self.dim, self.dim)
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c))
|
||||
c = all_c.pop(-1)
|
||||
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
all_c += [c_skip, c]
|
||||
c = torch.stack(all_c)
|
||||
return c
|
||||
|
||||
|
||||
class ZImageControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
|
||||
control_in_dim=33,
|
||||
dim=3840,
|
||||
n_refiner_layers=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
|
||||
self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
|
||||
self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
|
||||
self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
|
||||
|
||||
def forward_layers(
|
||||
self,
|
||||
x,
|
||||
cap_feats,
|
||||
control_context,
|
||||
control_context_item_seqlens,
|
||||
kwargs,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
bsz = len(control_context)
|
||||
# unified
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
control_context_unified = []
|
||||
for i in range(bsz):
|
||||
control_context_len = control_context_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
|
||||
c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
for layer in self.control_layers:
|
||||
c = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
c=c, **new_kwargs
|
||||
)
|
||||
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
def forward_refiner(
|
||||
self,
|
||||
dit,
|
||||
x,
|
||||
cap_feats,
|
||||
control_context,
|
||||
kwargs,
|
||||
t=None,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
# embeddings
|
||||
bsz = len(control_context)
|
||||
device = control_context[0].device
|
||||
(
|
||||
control_context,
|
||||
control_context_size,
|
||||
control_context_pos_ids,
|
||||
control_context_inner_pad_mask,
|
||||
) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
|
||||
|
||||
# control_context embed & refine
|
||||
control_context_item_seqlens = [len(_) for _ in control_context]
|
||||
assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
|
||||
control_context_max_item_seqlen = max(control_context_item_seqlens)
|
||||
|
||||
control_context = torch.cat(control_context, dim=0)
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
|
||||
|
||||
# Match t_embedder output dtype to control_context for layerwise casting compatibility
|
||||
adaln_input = t.type_as(control_context)
|
||||
control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
|
||||
control_context = list(control_context.split(control_context_item_seqlens, dim=0))
|
||||
control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))
|
||||
|
||||
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
|
||||
control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(control_context_item_seqlens):
|
||||
control_context_attn_mask[i, :seq_len] = 1
|
||||
c = control_context
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(
|
||||
x=x,
|
||||
attn_mask=control_context_attn_mask,
|
||||
freqs_cis=control_context_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
for layer in self.control_noise_refiner:
|
||||
c = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
c=c, **new_kwargs
|
||||
)
|
||||
|
||||
hints = torch.unbind(c)[:-1]
|
||||
control_context = torch.unbind(c)[-1]
|
||||
|
||||
return hints, control_context, control_context_item_seqlens
|
||||
@@ -13,6 +13,7 @@ from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
X_PAD_DIM = 64
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
@@ -86,7 +87,7 @@ class Attention(torch.nn.Module):
|
||||
self.norm_q = RMSNorm(head_dim, eps=1e-5)
|
||||
self.norm_k = RMSNorm(head_dim, eps=1e-5)
|
||||
|
||||
def forward(self, hidden_states, freqs_cis):
|
||||
def forward(self, hidden_states, freqs_cis, attention_mask):
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
@@ -123,6 +124,7 @@ class Attention(torch.nn.Module):
|
||||
key,
|
||||
value,
|
||||
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
||||
attn_mask=attention_mask,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
@@ -136,6 +138,20 @@ class Attention(torch.nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def select_per_token(
|
||||
value_noisy: torch.Tensor,
|
||||
value_clean: torch.Tensor,
|
||||
noise_mask: torch.Tensor,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
|
||||
return torch.where(
|
||||
noise_mask_expanded == 1,
|
||||
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
)
|
||||
|
||||
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -180,40 +196,53 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
adaln_noisy: Optional[torch.Tensor] = None,
|
||||
adaln_clean: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation: different modulation for noisy/clean tokens
|
||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
||||
|
||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
||||
|
||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
||||
|
||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
||||
|
||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Global modulation: same modulation for all tokens (avoid double select)
|
||||
mod = self.adaLN_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa,
|
||||
freqs_cis=freqs_cis,
|
||||
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + gate_mlp * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x) * scale_mlp,
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x),
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x),
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
||||
|
||||
return x
|
||||
|
||||
@@ -229,9 +258,21 @@ class FinalLayer(nn.Module):
|
||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation
|
||||
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
|
||||
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
|
||||
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Original global modulation
|
||||
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
scale = scale.unsqueeze(1)
|
||||
|
||||
x = self.norm_final(x) * scale
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -299,6 +340,7 @@ class ZImageDiT(nn.Module):
|
||||
t_scale=1000.0,
|
||||
axes_dims=[32, 48, 48],
|
||||
axes_lens=[1024, 512, 512],
|
||||
siglip_feat_dim=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@@ -359,6 +401,32 @@ class ZImageDiT(nn.Module):
|
||||
nn.Linear(cap_feat_dim, dim, bias=True),
|
||||
)
|
||||
|
||||
# Optional SigLIP components (for Omni variant)
|
||||
self.siglip_feat_dim = siglip_feat_dim
|
||||
if siglip_feat_dim is not None:
|
||||
self.siglip_embedder = nn.Sequential(
|
||||
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
|
||||
)
|
||||
self.siglip_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
2000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
else:
|
||||
self.siglip_embedder = None
|
||||
self.siglip_refiner = None
|
||||
self.siglip_pad_token = None
|
||||
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
|
||||
@@ -375,22 +443,57 @@ class ZImageDiT(nn.Module):
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
def unpatchify(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
size: List[Tuple],
|
||||
patch_size = 2,
|
||||
f_patch_size = 1,
|
||||
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
if x_pos_offsets is not None:
|
||||
# Omni: extract target image from unified sequence (cond_images + target)
|
||||
result = []
|
||||
for i in range(bsz):
|
||||
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
|
||||
cu_len = 0
|
||||
x_item = None
|
||||
for j in range(len(size[i])):
|
||||
if size[i][j] is None:
|
||||
ori_len = 0
|
||||
pad_len = SEQ_MULTI_OF
|
||||
cu_len += pad_len + ori_len
|
||||
else:
|
||||
F, H, W = size[i][j]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
x_item = (
|
||||
unified_x[cu_len : cu_len + ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
cu_len += ori_len + pad_len
|
||||
result.append(x_item) # Return only the last (target) image
|
||||
return result
|
||||
else:
|
||||
# Original mode: simple unpatchify
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
@@ -405,8 +508,8 @@ class ZImageDiT(nn.Module):
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
@@ -490,90 +593,487 @@ class ZImageDiT(nn.Module):
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return all_image_out, all_cap_feats_out, {
|
||||
"x_size": all_image_size,
|
||||
"x_pos_ids": all_image_pos_ids,
|
||||
"cap_pos_ids": all_cap_pos_ids,
|
||||
"x_pad_mask": all_image_pad_mask,
|
||||
"cap_pad_mask": all_cap_pad_mask
|
||||
}
|
||||
# (
|
||||
# all_img_out,
|
||||
# all_cap_out,
|
||||
# all_img_size,
|
||||
# all_img_pos_ids,
|
||||
# all_cap_pos_ids,
|
||||
# all_img_pad_mask,
|
||||
# all_cap_pad_mask,
|
||||
# )
|
||||
|
||||
def patchify_controlnet(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
cap_padding_len: int = None,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
device = all_image[0].device
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
|
||||
for i, image in enumerate(all_image):
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padding_pos_ids = (
|
||||
self.create_coordinate_grid(
|
||||
size=(1, 1, 1),
|
||||
start=(0, 0, 0),
|
||||
device=device,
|
||||
)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1)
|
||||
)
|
||||
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
||||
all_image_pos_ids.append(image_padded_pos_ids)
|
||||
# pad mask
|
||||
all_image_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def _prepare_sequence(
|
||||
self,
|
||||
feats: List[torch.Tensor],
|
||||
pos_ids: List[torch.Tensor],
|
||||
inner_pad_mask: List[torch.Tensor],
|
||||
pad_token: torch.nn.Parameter,
|
||||
noise_mask: Optional[List[List[int]]] = None,
|
||||
device: torch.device = None,
|
||||
):
|
||||
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
|
||||
item_seqlens = [len(f) for f in feats]
|
||||
max_seqlen = max(item_seqlens)
|
||||
bsz = len(feats)
|
||||
|
||||
# Pad token
|
||||
feats_cat = torch.cat(feats, dim=0)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
|
||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
||||
|
||||
# RoPE
|
||||
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
|
||||
|
||||
# Pad to batch
|
||||
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
|
||||
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if noise_mask is not None:
|
||||
noise_mask_tensor = pad_sequence(
|
||||
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)[:, : feats.shape[1]]
|
||||
|
||||
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
|
||||
|
||||
def _build_unified_sequence(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_freqs: torch.Tensor,
|
||||
x_seqlens: List[int],
|
||||
x_noise_mask: Optional[List[List[int]]],
|
||||
cap: torch.Tensor,
|
||||
cap_freqs: torch.Tensor,
|
||||
cap_seqlens: List[int],
|
||||
cap_noise_mask: Optional[List[List[int]]],
|
||||
siglip: Optional[torch.Tensor],
|
||||
siglip_freqs: Optional[torch.Tensor],
|
||||
siglip_seqlens: Optional[List[int]],
|
||||
siglip_noise_mask: Optional[List[List[int]]],
|
||||
omni_mode: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Build unified sequence: x, cap, and optionally siglip.
|
||||
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
|
||||
"""
|
||||
bsz = len(x_seqlens)
|
||||
unified = []
|
||||
unified_freqs = []
|
||||
unified_noise_mask = []
|
||||
|
||||
for i in range(bsz):
|
||||
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
|
||||
|
||||
if omni_mode:
|
||||
# Omni: [cap, x, siglip]
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
sig_len = siglip_seqlens[i]
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
|
||||
unified_freqs.append(
|
||||
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
|
||||
)
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(
|
||||
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
|
||||
)
|
||||
)
|
||||
else:
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
|
||||
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
|
||||
)
|
||||
else:
|
||||
# Basic: [x, cap]
|
||||
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
|
||||
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
|
||||
|
||||
# Compute unified seqlens
|
||||
if omni_mode:
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
|
||||
else:
|
||||
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
|
||||
else:
|
||||
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
|
||||
|
||||
max_seqlen = max(unified_seqlens)
|
||||
|
||||
# Pad to batch
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if omni_mode:
|
||||
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
|
||||
:, : unified.shape[1]
|
||||
]
|
||||
|
||||
return unified, unified_freqs, attn_mask, noise_mask_tensor
|
||||
|
||||
def _pad_with_ids(
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
pos_grid_size: Tuple,
|
||||
pos_start: Tuple,
|
||||
device: torch.device,
|
||||
noise_mask_val: Optional[int] = None,
|
||||
):
|
||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
||||
ori_len = len(feat)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
total_len = ori_len + pad_len
|
||||
|
||||
# Pos IDs
|
||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
||||
if pad_len > 0:
|
||||
pad_pos_ids = (
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(pad_len, 1)
|
||||
)
|
||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
||||
pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
||||
]
|
||||
)
|
||||
else:
|
||||
pos_ids = ori_pos_ids
|
||||
padded_feat = feat
|
||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
||||
|
||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
|
||||
|
||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
|
||||
|
||||
def patchify_and_embed_omni(
|
||||
self,
|
||||
all_x: List[List[torch.Tensor]],
|
||||
all_cap_feats: List[List[torch.Tensor]],
|
||||
all_siglip_feats: List[List[torch.Tensor]],
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
images_noise_mask: List[List[int]] = None,
|
||||
):
|
||||
"""Patchify for omni mode: multiple images per batch item with noise masks."""
|
||||
bsz = len(all_x)
|
||||
device = all_x[0][-1].device
|
||||
dtype = all_x[0][-1].dtype
|
||||
|
||||
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
|
||||
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
|
||||
|
||||
for i in range(bsz):
|
||||
num_images = len(all_x[i])
|
||||
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
|
||||
cap_end_pos = []
|
||||
cap_cu_len = 1
|
||||
|
||||
# Process captions
|
||||
for j, cap_item in enumerate(all_cap_feats[i]):
|
||||
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
|
||||
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
|
||||
cap_item,
|
||||
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
|
||||
(cap_cu_len, 0, 0),
|
||||
device,
|
||||
noise_val,
|
||||
)
|
||||
cap_feats_list.append(cap_out)
|
||||
cap_pos_list.append(cap_pos)
|
||||
cap_mask_list.append(cap_mask)
|
||||
cap_lens.append(cap_len)
|
||||
cap_noise.extend(cap_nm)
|
||||
cap_cu_len += len(cap_item)
|
||||
cap_end_pos.append(cap_cu_len)
|
||||
cap_cu_len += 2 # for image vae and siglip tokens
|
||||
|
||||
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
|
||||
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
|
||||
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
|
||||
all_cap_len.append(cap_lens)
|
||||
all_cap_noise_mask.append(cap_noise)
|
||||
|
||||
# Process images
|
||||
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
|
||||
for j, x_item in enumerate(all_x[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if x_item is not None:
|
||||
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
|
||||
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
|
||||
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
|
||||
)
|
||||
x_size.append(size)
|
||||
else:
|
||||
x_len = SEQ_MULTI_OF
|
||||
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
|
||||
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
|
||||
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
|
||||
x_nm = [noise_val] * x_len
|
||||
x_size.append(None)
|
||||
x_feats_list.append(x_out)
|
||||
x_pos_list.append(x_pos)
|
||||
x_mask_list.append(x_mask)
|
||||
x_lens.append(x_len)
|
||||
x_noise.extend(x_nm)
|
||||
|
||||
all_x_out.append(torch.cat(x_feats_list, dim=0))
|
||||
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
|
||||
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
|
||||
all_x_size.append(x_size)
|
||||
all_x_len.append(x_lens)
|
||||
all_x_noise_mask.append(x_noise)
|
||||
|
||||
# Process siglip
|
||||
if all_siglip_feats[i] is None:
|
||||
all_sig_len.append([0] * num_images)
|
||||
all_sig_out.append(None)
|
||||
else:
|
||||
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
|
||||
for j, sig_item in enumerate(all_siglip_feats[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if sig_item is not None:
|
||||
sig_H, sig_W, sig_C = sig_item.size()
|
||||
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
|
||||
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
|
||||
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
|
||||
)
|
||||
# Scale position IDs to match x resolution
|
||||
if x_size[j] is not None:
|
||||
sig_pos = sig_pos.float()
|
||||
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
|
||||
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
|
||||
sig_pos = sig_pos.to(torch.int32)
|
||||
else:
|
||||
sig_len = SEQ_MULTI_OF
|
||||
sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device)
|
||||
sig_pos = (
|
||||
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
|
||||
)
|
||||
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
|
||||
sig_nm = [noise_val] * sig_len
|
||||
sig_feats_list.append(sig_out)
|
||||
sig_pos_list.append(sig_pos)
|
||||
sig_mask_list.append(sig_mask)
|
||||
sig_lens.append(sig_len)
|
||||
sig_noise.extend(sig_nm)
|
||||
|
||||
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
|
||||
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
|
||||
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
|
||||
all_sig_len.append(sig_lens)
|
||||
all_sig_noise_mask.append(sig_noise)
|
||||
|
||||
# Compute x position offsets
|
||||
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
|
||||
|
||||
return (
|
||||
all_x_out,
|
||||
all_cap_out,
|
||||
all_sig_out,
|
||||
all_x_size,
|
||||
all_x_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_sig_pos_ids,
|
||||
all_x_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
all_sig_pad_mask,
|
||||
all_x_pos_offsets,
|
||||
all_x_noise_mask,
|
||||
all_cap_noise_mask,
|
||||
all_sig_noise_mask,
|
||||
)
|
||||
return all_x_out, all_cap_out, all_sig_out, {
|
||||
"x_size": x_size,
|
||||
"x_pos_ids": all_x_pos_ids,
|
||||
"cap_pos_ids": all_cap_pos_ids,
|
||||
"sig_pos_ids": all_sig_pos_ids,
|
||||
"x_pad_mask": all_x_pad_mask,
|
||||
"cap_pad_mask": all_cap_pad_mask,
|
||||
"sig_pad_mask": all_sig_pad_mask,
|
||||
"x_pos_offsets": all_x_pos_offsets,
|
||||
"x_noise_mask": all_x_noise_mask,
|
||||
"cap_noise_mask": all_cap_noise_mask,
|
||||
"sig_noise_mask": all_sig_noise_mask,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
siglip_feats = None,
|
||||
image_noise_mask = None,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
|
||||
omni_mode = isinstance(x[0], list)
|
||||
device = x[0][-1].device if omni_mode else x[0].device
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
if omni_mode:
|
||||
# Dual embeddings: noisy (t) and clean (t=1)
|
||||
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
|
||||
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
|
||||
adaln_input = None
|
||||
else:
|
||||
# Single embedding for all tokens
|
||||
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
|
||||
t_noisy = t_clean = None
|
||||
|
||||
adaln_input = t
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
# Patchify
|
||||
if omni_mode:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
siglip_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
siglip_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
siglip_pad_mask,
|
||||
x_pos_offsets,
|
||||
x_noise_mask,
|
||||
cap_noise_mask,
|
||||
siglip_noise_mask,
|
||||
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
|
||||
else:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
x_seqlens = [len(xi) for xi in x]
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
|
||||
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
|
||||
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
|
||||
)
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=x,
|
||||
attn_mask=x_attn_mask,
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean,
|
||||
)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
# Cap embed & refine
|
||||
cap_seqlens = [len(ci) for ci in cap_feats]
|
||||
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
|
||||
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
|
||||
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = gradient_checkpoint_forward(
|
||||
@@ -581,41 +1081,68 @@ class ZImageDiT(nn.Module):
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=cap_feats,
|
||||
attn_mask=cap_attn_mask,
|
||||
freqs_cis=cap_freqs_cis,
|
||||
attn_mask=cap_mask,
|
||||
freqs_cis=cap_freqs,
|
||||
)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
# Siglip embed & refine
|
||||
siglip_seqlens = siglip_freqs = None
|
||||
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
|
||||
siglip_seqlens = [len(si) for si in siglip_feats]
|
||||
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
|
||||
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
|
||||
list(siglip_feats.split(siglip_seqlens, dim=0)),
|
||||
siglip_pos_ids,
|
||||
siglip_pad_mask,
|
||||
self.siglip_pad_token,
|
||||
None,
|
||||
device,
|
||||
)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs,
|
||||
)
|
||||
|
||||
for layer in self.layers:
|
||||
# Unified sequence
|
||||
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
|
||||
x,
|
||||
x_freqs,
|
||||
x_seqlens,
|
||||
x_noise_mask,
|
||||
cap_feats,
|
||||
cap_freqs,
|
||||
cap_seqlens,
|
||||
cap_noise_mask,
|
||||
siglip_feats,
|
||||
siglip_freqs,
|
||||
siglip_seqlens,
|
||||
siglip_noise_mask,
|
||||
omni_mode,
|
||||
device,
|
||||
)
|
||||
|
||||
# Main transformer layers
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=unified,
|
||||
attn_mask=unified_attn_mask,
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean
|
||||
)
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
unified = (
|
||||
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
|
||||
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
|
||||
)
|
||||
if omni_mode
|
||||
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
|
||||
)
|
||||
|
||||
return x, {}
|
||||
# Unpatchify
|
||||
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
|
||||
|
||||
return x
|
||||
|
||||
189
diffsynth/models/z_image_image2lora.py
Normal file
189
diffsynth/models/z_image_image2lora.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import torch
|
||||
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
|
||||
|
||||
|
||||
class LoRATrainerBlock(torch.nn.Module):
|
||||
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
self.lora_patterns = lora_patterns
|
||||
self.block_id = block_id
|
||||
self.layers = []
|
||||
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
|
||||
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
|
||||
self.layers = torch.nn.ModuleList(self.layers)
|
||||
if use_residual:
|
||||
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
|
||||
else:
|
||||
self.proj_residual = None
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
if self.proj_residual is not None: residual = self.proj_residual(residual)
|
||||
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
|
||||
name = lora_pattern[0]
|
||||
lora_a, lora_b = layer(x, residual=residual)
|
||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
|
||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
|
||||
return lora
|
||||
|
||||
|
||||
class ZImageImage2LoRAComponent(torch.nn.Module):
|
||||
def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||
super().__init__()
|
||||
self.lora_patterns = lora_patterns
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks = []
|
||||
for lora_patterns in self.lora_patterns:
|
||||
for block_id in range(self.num_blocks):
|
||||
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))
|
||||
self.blocks = torch.nn.ModuleList(self.blocks)
|
||||
self.residual_scale = 0.05
|
||||
self.use_residual = use_residual
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
if residual is not None:
|
||||
if self.use_residual:
|
||||
residual = residual * self.residual_scale
|
||||
else:
|
||||
residual = None
|
||||
lora = {}
|
||||
for block in self.blocks:
|
||||
lora.update(block(x, residual))
|
||||
return lora
|
||||
|
||||
|
||||
class ZImageImage2LoRAModel(torch.nn.Module):
|
||||
def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||
super().__init__()
|
||||
lora_patterns = [
|
||||
[
|
||||
("attention.to_q", 3840, 3840),
|
||||
("attention.to_k", 3840, 3840),
|
||||
("attention.to_v", 3840, 3840),
|
||||
("attention.to_out.0", 3840, 3840),
|
||||
],
|
||||
[
|
||||
("feed_forward.w1", 3840, 10240),
|
||||
("feed_forward.w2", 10240, 3840),
|
||||
("feed_forward.w3", 3840, 10240),
|
||||
],
|
||||
]
|
||||
config = {
|
||||
"lora_patterns": lora_patterns,
|
||||
"use_residual": use_residual,
|
||||
"compress_dim": compress_dim,
|
||||
"rank": rank,
|
||||
"residual_length": residual_length,
|
||||
"residual_mid_dim": residual_mid_dim,
|
||||
}
|
||||
self.layers_lora = ZImageImage2LoRAComponent(
|
||||
prefix="layers",
|
||||
num_blocks=30,
|
||||
**config,
|
||||
)
|
||||
self.context_refiner_lora = ZImageImage2LoRAComponent(
|
||||
prefix="context_refiner",
|
||||
num_blocks=2,
|
||||
**config,
|
||||
)
|
||||
self.noise_refiner_lora = ZImageImage2LoRAComponent(
|
||||
prefix="noise_refiner",
|
||||
num_blocks=2,
|
||||
**config,
|
||||
)
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
lora.update(self.layers_lora(x, residual=residual))
|
||||
lora.update(self.context_refiner_lora(x, residual=residual))
|
||||
lora.update(self.noise_refiner_lora(x, residual=residual))
|
||||
return lora
|
||||
|
||||
def initialize_weights(self):
|
||||
state_dict = self.state_dict()
|
||||
for name in state_dict:
|
||||
if ".proj_a." in name:
|
||||
state_dict[name] = state_dict[name] * 0.3
|
||||
elif ".proj_b.proj_out." in name:
|
||||
state_dict[name] = state_dict[name] * 0
|
||||
elif ".proj_residual.proj_out." in name:
|
||||
state_dict[name] = state_dict[name] * 0.3
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
|
||||
class ImageEmb2LoRAWeightCompressed(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, emb_dim, rank):
|
||||
super().__init__()
|
||||
self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim)))
|
||||
self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank)))
|
||||
self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True)
|
||||
self.rank = rank
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x).view(self.rank, self.rank)
|
||||
lora_a = x @ self.lora_a
|
||||
lora_b = self.lora_b
|
||||
return lora_a, lora_b
|
||||
|
||||
|
||||
class ZImageImage2LoRAModelCompressed(torch.nn.Module):
|
||||
def __init__(self, emb_dim=1536+4096, rank=32):
|
||||
super().__init__()
|
||||
target_layers = [
|
||||
("attention.to_q", 3840, 3840),
|
||||
("attention.to_k", 3840, 3840),
|
||||
("attention.to_v", 3840, 3840),
|
||||
("attention.to_out.0", 3840, 3840),
|
||||
("feed_forward.w1", 3840, 10240),
|
||||
("feed_forward.w2", 10240, 3840),
|
||||
("feed_forward.w3", 3840, 10240),
|
||||
]
|
||||
self.lora_patterns = [
|
||||
{
|
||||
"prefix": "layers",
|
||||
"num_layers": 30,
|
||||
"target_layers": target_layers,
|
||||
},
|
||||
{
|
||||
"prefix": "context_refiner",
|
||||
"num_layers": 2,
|
||||
"target_layers": target_layers,
|
||||
},
|
||||
{
|
||||
"prefix": "noise_refiner",
|
||||
"num_layers": 2,
|
||||
"target_layers": target_layers,
|
||||
},
|
||||
]
|
||||
module_dict = {}
|
||||
for lora_pattern in self.lora_patterns:
|
||||
prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"]
|
||||
for layer_id in range(num_layers):
|
||||
for layer_name, in_dim, out_dim in target_layers:
|
||||
name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___")
|
||||
model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank)
|
||||
module_dict[name] = model
|
||||
self.module_dict = torch.nn.ModuleDict(module_dict)
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
for name, module in self.module_dict.items():
|
||||
name = name.replace("___", ".")
|
||||
name_a, name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight"
|
||||
lora_a, lora_b = module(x)
|
||||
lora[name_a] = lora_a
|
||||
lora[name_b] = lora_b
|
||||
return lora
|
||||
|
||||
def initialize_weights(self):
|
||||
state_dict = self.state_dict()
|
||||
for name in state_dict:
|
||||
if "lora_b" in name:
|
||||
state_dict[name] = state_dict[name] * 0
|
||||
elif "lora_a" in name:
|
||||
state_dict[name] = state_dict[name] * 0.2
|
||||
elif "proj.weight" in name:
|
||||
print(name)
|
||||
state_dict[name] = state_dict[name] * 0.2
|
||||
self.load_state_dict(state_dict)
|
||||
@@ -4,16 +4,23 @@ from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from typing import Union, List, Optional, Tuple
|
||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
||||
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..core.data.operators import ImageCropAndResize
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
from ..utils.lora import merge_lora
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
from ..models.z_image_dit import ZImageDiT
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
|
||||
from ..models.z_image_controlnet import ZImageControlNet
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
|
||||
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
|
||||
from ..models.z_image_image2lora import ZImageImage2LoRAModel
|
||||
|
||||
|
||||
class ZImagePipeline(BasePipeline):
|
||||
@@ -28,13 +35,22 @@ class ZImagePipeline(BasePipeline):
|
||||
self.dit: ZImageDiT = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.image_encoder: Siglip2ImageEncoder428M = None
|
||||
self.controlnet: ZImageControlNet = None
|
||||
self.siglip2_image_encoder: Siglip2ImageEncoder = None
|
||||
self.dinov3_image_encoder: DINOv3ImageEncoder = None
|
||||
self.image2lora_style: ZImageImage2LoRAModel = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.units = [
|
||||
ZImageUnit_ShapeChecker(),
|
||||
ZImageUnit_PromptEmbedder(),
|
||||
ZImageUnit_NoiseInitializer(),
|
||||
ZImageUnit_InputImageEmbedder(),
|
||||
ZImageUnit_EditImageAutoResize(),
|
||||
ZImageUnit_EditImageEmbedderVAE(),
|
||||
ZImageUnit_EditImageEmbedderSiglip(),
|
||||
ZImageUnit_PAIControlNet(),
|
||||
]
|
||||
self.model_fn = model_fn_z_image
|
||||
|
||||
@@ -56,6 +72,11 @@ class ZImagePipeline(BasePipeline):
|
||||
pipe.dit = model_pool.fetch_model("z_image_dit")
|
||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
|
||||
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
|
||||
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
|
||||
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
|
||||
pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
@@ -75,6 +96,9 @@ class ZImagePipeline(BasePipeline):
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Edit
|
||||
edit_image: Image.Image = None,
|
||||
edit_image_auto_resize: bool = True,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
@@ -83,11 +107,17 @@ class ZImagePipeline(BasePipeline):
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
sigma_shift: float = None,
|
||||
# ControlNet
|
||||
controlnet_inputs: List[ControlNetInput] = None,
|
||||
# Image to LoRA
|
||||
image2lora_images: List[Image.Image] = None,
|
||||
positive_only_lora: Dict[str, torch.Tensor] = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
@@ -102,6 +132,9 @@ class ZImagePipeline(BasePipeline):
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
"image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -143,12 +176,13 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params=("edit_image",),
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
pipe,
|
||||
@@ -194,10 +228,81 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def encode_prompt_omni(
|
||||
self,
|
||||
pipe,
|
||||
prompt: Union[str, List[str]],
|
||||
edit_image=None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
def process(self, pipe: ZImagePipeline, prompt):
|
||||
if edit_image is None:
|
||||
num_condition_images = 0
|
||||
elif isinstance(edit_image, list):
|
||||
num_condition_images = len(edit_image)
|
||||
else:
|
||||
num_condition_images = 1
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
if num_condition_images == 0:
|
||||
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
|
||||
elif num_condition_images > 0:
|
||||
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
|
||||
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|im_end|>"]
|
||||
prompt[i] = prompt_list
|
||||
|
||||
flattened_prompt = []
|
||||
prompt_list_lengths = []
|
||||
|
||||
for i in range(len(prompt)):
|
||||
prompt_list_lengths.append(len(prompt[i]))
|
||||
flattened_prompt.extend(prompt[i])
|
||||
|
||||
text_inputs = pipe.tokenizer(
|
||||
flattened_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = pipe.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
start_idx = 0
|
||||
for i in range(len(prompt_list_lengths)):
|
||||
batch_embeddings = []
|
||||
end_idx = start_idx + prompt_list_lengths[i]
|
||||
for j in range(start_idx, end_idx):
|
||||
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
|
||||
embeddings_list.append(batch_embeddings)
|
||||
start_idx = end_idx
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def process(self, pipe: ZImagePipeline, prompt, edit_image):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
|
||||
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
|
||||
# We determine which encoding method to use based on the model architecture.
|
||||
# If you are using two-stage split training,
|
||||
# please use `--offload_models` instead of skipping the DiT model loading.
|
||||
prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
|
||||
else:
|
||||
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
@@ -234,24 +339,330 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
class ZImageUnit_EditImageAutoResize(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "edit_image_auto_resize"),
|
||||
output_params=("edit_image",),
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
if edit_image_auto_resize is None or not edit_image_auto_resize:
|
||||
return {}
|
||||
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
|
||||
if not isinstance(edit_image, list):
|
||||
edit_image = [edit_image]
|
||||
edit_image = [operator(i) for i in edit_image]
|
||||
return {"edit_image": edit_image}
|
||||
|
||||
|
||||
class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image",),
|
||||
output_params=("image_embeds",),
|
||||
onload_model_names=("image_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, edit_image):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if not isinstance(edit_image, list):
|
||||
edit_image = [edit_image]
|
||||
image_emb = []
|
||||
for image_ in edit_image:
|
||||
image_emb.append(pipe.image_encoder(image_, device=pipe.device))
|
||||
return {"image_embeds": image_emb}
|
||||
|
||||
|
||||
class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image",),
|
||||
output_params=("image_latents",),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, edit_image):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if not isinstance(edit_image, list):
|
||||
edit_image = [edit_image]
|
||||
image_latents = []
|
||||
for image_ in edit_image:
|
||||
image_ = pipe.preprocess_image(image_)
|
||||
image_latents.append(pipe.vae_encoder(image_))
|
||||
return {"image_latents": image_latents}
|
||||
|
||||
|
||||
class ZImageUnit_PAIControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("controlnet_inputs", "height", "width"),
|
||||
output_params=("control_context", "control_scale"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
|
||||
if controlnet_inputs is None:
|
||||
return {}
|
||||
if len(controlnet_inputs) != 1:
|
||||
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
|
||||
controlnet_input = controlnet_inputs[0]
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
|
||||
control_image = controlnet_input.image
|
||||
if control_image is not None:
|
||||
control_image = pipe.preprocess_image(control_image)
|
||||
control_latents = pipe.vae_encoder(control_image)
|
||||
else:
|
||||
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
|
||||
|
||||
inpaint_mask = controlnet_input.inpaint_mask
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
|
||||
inpaint_image = controlnet_input.inpaint_image
|
||||
inpaint_image = pipe.preprocess_image(inpaint_image)
|
||||
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
|
||||
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
|
||||
else:
|
||||
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inpaint_latent = pipe.vae_encoder(inpaint_image)
|
||||
|
||||
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
|
||||
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
|
||||
return {"control_context": control_context, "control_scale": controlnet_input.scale}
|
||||
|
||||
|
||||
def model_fn_z_image(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
image_latents=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
# Due to the complex and verbose codebase of Z-Image,
|
||||
# we are temporarily using this inelegant structure.
|
||||
# We will refactor this part in the future (if time permits).
|
||||
if dit.siglip_embedder is None:
|
||||
return model_fn_z_image_turbo(
|
||||
dit,
|
||||
controlnet=controlnet,
|
||||
latents=latents,
|
||||
timestep=timestep,
|
||||
prompt_embeds=prompt_embeds,
|
||||
image_embeds=image_embeds,
|
||||
image_latents=image_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
**kwargs,
|
||||
)
|
||||
latents = [rearrange(latents, "B C H W -> C B H W")]
|
||||
if dit.siglip_embedder is not None:
|
||||
if image_latents is not None:
|
||||
image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
|
||||
latents = [image_latents + latents]
|
||||
image_noise_mask = [[0] * len(image_latents) + [1]]
|
||||
else:
|
||||
latents = [latents]
|
||||
image_noise_mask = [[1]]
|
||||
image_embeds = [image_embeds]
|
||||
else:
|
||||
image_noise_mask = None
|
||||
timestep = (1000 - timestep) / 1000
|
||||
model_output = dit(
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
siglip_feats=image_embeds,
|
||||
image_noise_mask=image_noise_mask,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)[0][0]
|
||||
)[0]
|
||||
model_output = -model_output
|
||||
model_output = rearrange(model_output, "C B H W -> B C H W")
|
||||
return model_output
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("image2lora_images",),
|
||||
output_params=("image2lora_x",),
|
||||
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
|
||||
)
|
||||
from ..core.data.operators import ImageCropAndResize
|
||||
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
|
||||
|
||||
def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
pipe.load_models_to_device(["siglip2_image_encoder"])
|
||||
embs = []
|
||||
for image in images:
|
||||
image = self.processor_highres(image)
|
||||
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
|
||||
embs = torch.stack(embs)
|
||||
return embs
|
||||
|
||||
def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
pipe.load_models_to_device(["dinov3_image_encoder"])
|
||||
embs = []
|
||||
for image in images:
|
||||
image = self.processor_highres(image)
|
||||
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
|
||||
embs = torch.stack(embs)
|
||||
return embs
|
||||
|
||||
def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
if images is None:
|
||||
return {}
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
|
||||
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
|
||||
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
|
||||
return x
|
||||
|
||||
def process(self, pipe: ZImagePipeline, image2lora_images):
|
||||
if image2lora_images is None:
|
||||
return {}
|
||||
x = self.encode_images(pipe, image2lora_images)
|
||||
return {"image2lora_x": x}
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRADecode(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("image2lora_x",),
|
||||
output_params=("lora",),
|
||||
onload_model_names=("image2lora_style",),
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, image2lora_x):
|
||||
if image2lora_x is None:
|
||||
return {}
|
||||
loras = []
|
||||
if pipe.image2lora_style is not None:
|
||||
pipe.load_models_to_device(["image2lora_style"])
|
||||
for x in image2lora_x:
|
||||
loras.append(pipe.image2lora_style(x=x, residual=None))
|
||||
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
|
||||
return {"lora": lora}
|
||||
|
||||
|
||||
def model_fn_z_image_turbo(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
image_latents=None,
|
||||
control_context=None,
|
||||
control_scale=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
while isinstance(prompt_embeds, list):
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
while isinstance(latents, list):
|
||||
latents = latents[0]
|
||||
while isinstance(image_embeds, list):
|
||||
image_embeds = image_embeds[0]
|
||||
|
||||
# Timestep
|
||||
timestep = 1000 - timestep
|
||||
t_noisy = dit.t_embedder(timestep)
|
||||
t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)
|
||||
|
||||
# Patchify
|
||||
latents = rearrange(latents, "B C H W -> C B H W")
|
||||
x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
|
||||
x = x[0]
|
||||
cap_feats = cap_feats[0]
|
||||
|
||||
# Noise refine
|
||||
x = dit.all_x_embedder["2-1"](x)
|
||||
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
|
||||
x = rearrange(x, "L C -> 1 L C")
|
||||
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
|
||||
|
||||
if control_context is not None:
|
||||
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
|
||||
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
|
||||
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
for layer_id, layer in enumerate(dit.noise_refiner):
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=x,
|
||||
attn_mask=None,
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
if control_context is not None:
|
||||
x = x + refiner_hints[layer_id] * control_scale
|
||||
|
||||
# Prompt refine
|
||||
cap_feats = dit.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
|
||||
cap_feats = rearrange(cap_feats, "L C -> 1 L C")
|
||||
cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
|
||||
|
||||
for layer in dit.context_refiner:
|
||||
cap_feats = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=cap_feats,
|
||||
attn_mask=None,
|
||||
freqs_cis=cap_freqs_cis,
|
||||
)
|
||||
|
||||
# Unified
|
||||
unified = torch.cat([x, cap_feats], dim=1)
|
||||
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
|
||||
|
||||
if control_context is not None:
|
||||
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
|
||||
hints = controlnet.forward_layers(
|
||||
unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
for layer_id, layer in enumerate(dit.layers):
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=unified,
|
||||
attn_mask=None,
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
if control_context is not None:
|
||||
if layer_id in controlnet.control_layers_mapping:
|
||||
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
|
||||
|
||||
# Output
|
||||
unified = dit.all_final_layer["2-1"](unified, t_noisy)
|
||||
x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
|
||||
x = rearrange(x, "C B H W -> B C H W")
|
||||
x = -x
|
||||
return x
|
||||
|
||||
@@ -9,5 +9,6 @@ class ControlNetInput:
|
||||
start: float = 1.0
|
||||
end: float = 0.0
|
||||
image: Image.Image = None
|
||||
inpaint_image: Image.Image = None
|
||||
inpaint_mask: Image.Image = None
|
||||
processor_id: str = None
|
||||
|
||||
Reference in New Issue
Block a user