Merge pull request #264 from modelscope/dev

Dev
This commit is contained in:
Zhongjie Duan
2024-11-13 09:53:49 +08:00
committed by GitHub
22 changed files with 1742 additions and 610 deletions

View File

@@ -33,13 +33,15 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
from ..models.omnigen import OmniGenTransformer
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
@@ -73,7 +75,7 @@ model_loader_configs = [
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
@@ -81,10 +83,14 @@ model_loader_configs = [
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
# (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai")
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -267,6 +273,13 @@ preset_models_on_huggingface = {
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
# Stable Diffusion 3.5
"StableDiffusion3.5-large": [
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
}
preset_models_on_modelscope = {
# Hunyuan DiT
@@ -536,6 +549,21 @@ preset_models_on_modelscope = {
"RIFE": [
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
],
# Omnigen
"OmniGen-v1": {
"file_list": [
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
],
"load_path": [
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
"models/OmniGen/OmniGen-v1/model.safetensors",
]
},
# CogVideo
"CogVideoX-5B": {
"file_list": [
@@ -555,6 +583,13 @@ preset_models_on_modelscope = {
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
],
},
# Stable Diffusion 3.5
"StableDiffusion3.5-large": [
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
@@ -600,10 +635,12 @@ Preset_model_id: TypeAlias = Literal[
"OmostPrompt",
"ESRGAN_x4",
"RIFE",
"OmniGen-v1",
"CogVideoX-5B",
"Annotators:Depth",
"Annotators:Softedge",
"Annotators:Lineart",
"Annotators:Normal",
"Annotators:Openpose",
"StableDiffusion3.5-large",
]

View File

@@ -1,5 +1,5 @@
import torch
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from einops import rearrange
from .tiler import TileWorker
from .utils import init_weights_on_device
@@ -37,21 +37,6 @@ class RoPEEmbedding(torch.nn.Module):
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight
return hidden_states
class FluxJointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
super().__init__()

View File

@@ -3,26 +3,6 @@ from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
class FluxTextEncoder1(SDTextEncoder):
def __init__(self, vocab_size=49408):
super().__init__(vocab_size=vocab_size)
def forward(self, input_ids, clip_skip=2):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
hidden_states = embeds
embeds = self.final_layer_norm(embeds)
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return embeds, pooled_embeds
@staticmethod
def state_dict_converter():
return FluxTextEncoder1StateDictConverter()
class FluxTextEncoder2(T5EncoderModel):
def __init__(self, config):
@@ -40,47 +20,6 @@ class FluxTextEncoder2(T5EncoderModel):
class FluxTextEncoder1StateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class FluxTextEncoder2StateDictConverter():
def __init__(self):
pass

View File

@@ -37,7 +37,7 @@ from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5Tex
from .hunyuan_dit import HunyuanDiT
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from .flux_text_encoder import FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .cog_vae import CogVAEEncoder, CogVAEDecoder

803
diffsynth/models/omnigen.py Normal file
View File

@@ -0,0 +1,803 @@
# The code is revised from DiT
import os
import torch
import torch.nn as nn
import numpy as np
import math
from safetensors.torch import load_file
from typing import List, Optional, Tuple, Union
import torch.utils.checkpoint
from huggingface_hub import snapshot_download
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers import Phi3Config, Phi3Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Phi3Transformer(Phi3Model):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
We only modified the attention mask
Args:
config: Phi3Config
"""
def prefetch_layer(self, layer_idx: int, device: torch.device):
"Starts prefetching the next layer cache"
with torch.cuda.stream(self.prefetch_stream):
# Prefetch next layer tensors to GPU
for name, param in self.layers[layer_idx].named_parameters():
param.data = param.data.to(device, non_blocking=True)
def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
prev_layer_idx = layer_idx - 1
for name, param in self.layers[prev_layer_idx].named_parameters():
param.data = param.data.to("cpu", non_blocking=True)
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
# init stream
if not hasattr(self, "prefetch_stream"):
self.prefetch_stream = torch.cuda.Stream()
# delete previous layer
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# make sure the current layer is ready
torch.cuda.synchronize(self.prefetch_stream)
# load next layer
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
offload_model: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# if cache_position is None:
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# cache_position = torch.arange(
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
# )
# if position_ids is None:
# position_ids = cache_position.unsqueeze(0)
if attention_mask is not None and attention_mask.dim() == 3:
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = (1 - attention_mask) * min_dtype
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
else:
raise Exception("attention_mask parameter was unavailable or invalid")
# causal_mask = self._update_causal_mask(
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
# )
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
layer_idx = -1
for decoder_layer in self.layers:
layer_idx += 1
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
if offload_model and not self.training:
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
print('************')
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype=torch.float32):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class PatchEmbedMR(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(
self,
patch_size: int = 2,
in_chans: int = 4,
embed_dim: int = 768,
bias: bool = True,
):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
return x
class OmniGenOriginalModel(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
transformer_config: Phi3Config,
patch_size=2,
in_channels=4,
pe_interpolation: float = 1.0,
pos_embed_max_size: int = 192,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.pos_embed_max_size = pos_embed_max_size
hidden_size = transformer_config.hidden_size
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
self.time_token = TimestepEmbedder(hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.pe_interpolation = pe_interpolation
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
self.llm = Phi3Transformer(config=transformer_config)
self.llm.config.use_cache = False
@classmethod
def from_pretrained(cls, model_name):
if not os.path.exists(model_name):
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
config = Phi3Config.from_pretrained(model_name)
model = cls(config)
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
print("Loading safetensors")
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
else:
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
model.load_state_dict(ckpt)
return model
def initialize_weights(self):
assert not hasattr(self, "llama")
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
w = self.input_x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h, w))
return imgs
def cropped_pos_embed(self, height, width):
"""Crops positional embeddings for SD3 compatibility."""
if self.pos_embed_max_size is None:
raise ValueError("`pos_embed_max_size` must be set for cropping.")
height = height // self.patch_size
width = width // self.patch_size
if height > self.pos_embed_max_size:
raise ValueError(
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
)
if width > self.pos_embed_max_size:
raise ValueError(
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
)
top = (self.pos_embed_max_size - height) // 2
left = (self.pos_embed_max_size - width) // 2
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
# print(top, top + height, left, left + width, spatial_pos_embed.size())
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
return spatial_pos_embed
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
if isinstance(latents, list):
return_list = False
if padding_latent is None:
padding_latent = [None] * len(latents)
return_list = True
patched_latents, num_tokens, shapes = [], [], []
for latent, padding in zip(latents, padding_latent):
height, width = latent.shape[-2:]
if is_input_images:
latent = self.input_x_embedder(latent)
else:
latent = self.x_embedder(latent)
pos_embed = self.cropped_pos_embed(height, width)
latent = latent + pos_embed
if padding is not None:
latent = torch.cat([latent, padding], dim=-2)
patched_latents.append(latent)
num_tokens.append(pos_embed.size(1))
shapes.append([height, width])
if not return_list:
latents = torch.cat(patched_latents, dim=0)
else:
latents = patched_latents
else:
height, width = latents.shape[-2:]
if is_input_images:
latents = self.input_x_embedder(latents)
else:
latents = self.x_embedder(latents)
pos_embed = self.cropped_pos_embed(height, width)
latents = latents + pos_embed
num_tokens = latents.size(1)
shapes = [height, width]
return latents, num_tokens, shapes
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
"""
"""
input_is_list = isinstance(x, list)
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
if input_img_latents is not None:
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
if input_ids is not None:
condition_embeds = self.llm.embed_tokens(input_ids).clone()
input_img_inx = 0
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
input_img_inx += 1
if input_img_latents is not None:
assert input_img_inx == len(input_latents)
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
else:
input_emb = torch.cat([time_token, x], dim=1)
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
output, past_key_values = output.last_hidden_state, output.past_key_values
if input_is_list:
image_embedding = output[:, -max(num_tokens):]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
x = self.final_layer(image_embedding, time_emb)
latents = []
for i in range(x.size(0)):
latent = x[i:i+1, :num_tokens[i]]
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
latents.append(latent)
else:
image_embedding = output[:, -num_tokens:]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
x = self.final_layer(image_embedding, time_emb)
latents = self.unpatchify(x, shapes[0], shapes[1])
if return_past_key_values:
return latents, past_key_values
return latents
@torch.no_grad()
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
self.llm.config.use_cache = use_kv_cache
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
if use_img_cfg:
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
model_out = [cond, cond, cond]
else:
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
cond = uncond + cfg_scale * (cond - uncond)
model_out = [cond, cond]
return torch.cat(model_out, dim=0), past_key_values
@torch.no_grad()
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
self.llm.config.use_cache = use_kv_cache
if past_key_values is None:
past_key_values = [None] * len(attention_mask)
x = torch.split(x, len(x) // len(attention_mask), dim=0)
timestep = timestep.to(x[0].dtype)
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
model_out, pask_key_values = [], []
for i in range(len(input_ids)):
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
model_out.append(temp_out)
pask_key_values.append(temp_pask_key_values)
if len(model_out) == 3:
cond, uncond, img_cond = model_out
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
model_out = [cond, cond, cond]
elif len(model_out) == 2:
cond, uncond = model_out
cond = uncond + cfg_scale * (cond - uncond)
model_out = [cond, cond]
else:
return model_out[0]
return torch.cat(model_out, dim=0), pask_key_values
class OmniGenTransformer(OmniGenOriginalModel):
def __init__(self):
config = {
"_name_or_path": "Phi-3-vision-128k-instruct",
"architectures": [
"Phi3ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"model_type": "phi3",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"original_max_position_embeddings": 4096,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"long_factor": [
1.0299999713897705,
1.0499999523162842,
1.0499999523162842,
1.0799999237060547,
1.2299998998641968,
1.2299998998641968,
1.2999999523162842,
1.4499999284744263,
1.5999999046325684,
1.6499998569488525,
1.8999998569488525,
2.859999895095825,
3.68999981880188,
5.419999599456787,
5.489999771118164,
5.489999771118164,
9.09000015258789,
11.579999923706055,
15.65999984741211,
15.769999504089355,
15.789999961853027,
18.360000610351562,
21.989999771118164,
23.079999923706055,
30.009998321533203,
32.35000228881836,
32.590003967285156,
35.56000518798828,
39.95000457763672,
53.840003967285156,
56.20000457763672,
57.95000457763672,
59.29000473022461,
59.77000427246094,
59.920005798339844,
61.190006256103516,
61.96000671386719,
62.50000762939453,
63.3700065612793,
63.48000717163086,
63.48000717163086,
63.66000747680664,
63.850006103515625,
64.08000946044922,
64.760009765625,
64.80001068115234,
64.81001281738281,
64.81001281738281
],
"short_factor": [
1.05,
1.05,
1.05,
1.1,
1.1,
1.1,
1.2500000000000002,
1.2500000000000002,
1.4000000000000004,
1.4500000000000004,
1.5500000000000005,
1.8500000000000008,
1.9000000000000008,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.1000000000000005,
2.1000000000000005,
2.2,
2.3499999999999996,
2.3499999999999996,
2.3499999999999996,
2.3499999999999996,
2.3999999999999995,
2.3999999999999995,
2.6499999999999986,
2.6999999999999984,
2.8999999999999977,
2.9499999999999975,
3.049999999999997,
3.049999999999997,
3.049999999999997
],
"type": "su"
},
"rope_theta": 10000.0,
"sliding_window": 131072,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.38.1",
"use_cache": True,
"vocab_size": 32064,
"_attn_implementation": "sdpa"
}
config = Phi3Config(**config)
super().__init__(config)
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
input_is_list = isinstance(x, list)
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
if input_img_latents is not None:
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
if input_ids is not None:
condition_embeds = self.llm.embed_tokens(input_ids).clone()
input_img_inx = 0
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
input_img_inx += 1
if input_img_latents is not None:
assert input_img_inx == len(input_latents)
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
else:
input_emb = torch.cat([time_token, x], dim=1)
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
output, past_key_values = output.last_hidden_state, output.past_key_values
if input_is_list:
image_embedding = output[:, -max(num_tokens):]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
x = self.final_layer(image_embedding, time_emb)
latents = []
for i in range(x.size(0)):
latent = x[i:i+1, :num_tokens[i]]
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
latents.append(latent)
else:
image_embedding = output[:, -num_tokens:]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
x = self.final_layer(image_embedding, time_emb)
latents = self.unpatchify(x, shapes[0], shapes[1])
if return_past_key_values:
return latents, past_key_values
return latents
@torch.no_grad()
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
self.llm.config.use_cache = use_kv_cache
if past_key_values is None:
past_key_values = [None] * len(attention_mask)
x = torch.split(x, len(x) // len(attention_mask), dim=0)
timestep = timestep.to(x[0].dtype)
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
model_out, pask_key_values = [], []
for i in range(len(input_ids)):
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
model_out.append(temp_out)
pask_key_values.append(temp_pask_key_values)
if len(model_out) == 3:
cond, uncond, img_cond = model_out
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
model_out = [cond, cond, cond]
elif len(model_out) == 2:
cond, uncond = model_out
cond = uncond + cfg_scale * (cond - uncond)
model_out = [cond, cond]
else:
return model_out[0]
return torch.cat(model_out, dim=0), pask_key_values
@staticmethod
def state_dict_converter():
return OmniGenTransformerStateDictConverter()
class OmniGenTransformerStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -5,6 +5,21 @@ from .tiler import TileWorker
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight
return hidden_states
class PatchEmbed(torch.nn.Module):
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
super().__init__()
@@ -12,7 +27,7 @@ class PatchEmbed(torch.nn.Module):
self.patch_size = patch_size
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536))
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))
def cropped_pos_embed(self, height, width):
height = height // self.patch_size
@@ -67,7 +82,7 @@ class AdaLayerNorm(torch.nn.Module):
class JointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False, use_rms_norm=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
@@ -80,12 +95,38 @@ class JointAttention(torch.nn.Module):
if not only_out_a:
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
if use_rms_norm:
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
else:
self.norm_q_a = None
self.norm_k_a = None
self.norm_q_b = None
self.norm_k_b = None
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
batch_size = hidden_states.shape[0]
qkv = to_qkv(hidden_states)
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
if norm_q is not None:
q = norm_q(q)
if norm_k is not None:
k = norm_k(k)
return q, k, v
def forward(self, hidden_states_a, hidden_states_b):
batch_size = hidden_states_a.shape[0]
qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
qa, ka, va = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
qb, kb, vb = self.process_qkv(hidden_states_b, self.b_to_qkv, self.norm_q_b, self.norm_k_b)
q = torch.concat([qa, qb], dim=2)
k = torch.concat([ka, kb], dim=2)
v = torch.concat([va, vb], dim=2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -101,12 +142,12 @@ class JointAttention(torch.nn.Module):
class JointTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
@@ -145,12 +186,12 @@ class JointTransformerBlock(torch.nn.Module):
class JointTransformerFinalBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim, single=True)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, use_rms_norm=use_rms_norm)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
@@ -177,15 +218,16 @@ class JointTransformerFinalBlock(torch.nn.Module):
class SD3DiT(torch.nn.Module):
def __init__(self):
def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False):
super().__init__()
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
self.time_embedder = TimestepEmbeddings(256, 1536)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
self.context_embedder = torch.nn.Linear(4096, 1536)
self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
self.norm_out = AdaLayerNorm(1536, single=True)
self.proj_out = torch.nn.Linear(1536, 64)
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=192)
self.time_embedder = TimestepEmbeddings(256, embed_dim)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim))
self.context_embedder = torch.nn.Linear(4096, embed_dim)
self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1)]
+ [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)])
self.norm_out = AdaLayerNorm(embed_dim, single=True)
self.proj_out = torch.nn.Linear(embed_dim, 64)
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
@@ -238,6 +280,14 @@ class SD3DiTStateDictConverter:
def __init__(self):
pass
def infer_architecture(self, state_dict):
embed_dim = state_dict["blocks.0.ff_a.0.weight"].shape[1]
num_layers = 100
while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
num_layers -= 1
use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
return {"embed_dim": embed_dim, "num_layers": num_layers, "use_rms_norm": use_rms_norm}
def from_diffusers(self, state_dict):
rename_dict = {
"context_embedder": "context_embedder",
@@ -264,12 +314,17 @@ class SD3DiTStateDictConverter:
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
if name == "pos_embed.pos_embed":
param = param.reshape((1, 192, 192, 1536))
param = param.reshape((1, 192, 192, param.shape[-1]))
state_dict_[rename_dict[name]] = param
elif name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
@@ -283,7 +338,19 @@ class SD3DiTStateDictConverter:
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
return state_dict_
merged_keys = [name for name in state_dict_ if ".a_to_q." in name or ".b_to_q." in name]
for key in merged_keys:
param = torch.concat([
state_dict_[key.replace("to_q", "to_q")],
state_dict_[key.replace("to_q", "to_k")],
state_dict_[key.replace("to_q", "to_v")],
], dim=0)
name = key.replace("to_q", "to_qkv")
state_dict_.pop(key.replace("to_q", "to_q"))
state_dict_.pop(key.replace("to_q", "to_k"))
state_dict_.pop(key.replace("to_q", "to_v"))
state_dict_[name] = param
return state_dict_, self.infer_architecture(state_dict_)
def from_civitai(self, state_dict):
rename_dict = {
@@ -291,478 +358,7 @@ class SD3DiTStateDictConverter:
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
@@ -780,19 +376,51 @@ class SD3DiTStateDictConverter:
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
}
for i in range(40):
rename_dict.update({
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_b.linear.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_b.linear.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.bias": f"blocks.{i}.attn.b_to_out.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.weight": f"blocks.{i}.attn.b_to_out.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.bias": [f'blocks.{i}.attn.b_to_q.bias', f'blocks.{i}.attn.b_to_k.bias', f'blocks.{i}.attn.b_to_v.bias'],
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.weight": [f'blocks.{i}.attn.b_to_q.weight', f'blocks.{i}.attn.b_to_k.weight', f'blocks.{i}.attn.b_to_v.weight'],
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.bias": f"blocks.{i}.ff_b.0.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.weight": f"blocks.{i}.ff_b.0.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.bias": f"blocks.{i}.ff_b.2.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.weight": f"blocks.{i}.ff_b.2.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_a.linear.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_a.linear.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.bias": f"blocks.{i}.attn.a_to_out.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.weight": f"blocks.{i}.attn.a_to_out.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.bias": [f'blocks.{i}.attn.a_to_q.bias', f'blocks.{i}.attn.a_to_k.bias', f'blocks.{i}.attn.a_to_v.bias'],
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.weight": [f'blocks.{i}.attn.a_to_q.weight', f'blocks.{i}.attn.a_to_k.weight', f'blocks.{i}.attn.a_to_v.weight'],
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.bias": f"blocks.{i}.ff_a.0.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.weight": f"blocks.{i}.ff_a.0.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.bias": f"blocks.{i}.ff_a.2.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.weight": f"blocks.{i}.ff_a.2.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_a.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",
})
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name == "model.diffusion_model.pos_embed":
param = param.reshape((1, 192, 192, 1536))
if name == "model.diffusion_model.pos_embed":
param = param.reshape((1, 192, 192, param.shape[-1]))
if isinstance(rename_dict[name], str):
state_dict_[rename_dict[name]] = param
else:
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
state_dict_[name_] = param
return state_dict_
extra_kwargs = self.infer_architecture(state_dict_)
num_layers = extra_kwargs["num_layers"]
for name in [
f"blocks.{num_layers-1}.norm1_b.linear.weight", f"blocks.{num_layers-1}.norm1_b.linear.bias", "norm_out.linear.weight", "norm_out.linear.bias",
]:
param = state_dict_[name]
dim = param.shape[0] // 2
param = torch.concat([param[dim:], param[:dim]], axis=0)
state_dict_[name] = param
return state_dict_, self.infer_architecture(state_dict_)

View File

@@ -322,6 +322,11 @@ class SD3TextEncoder1StateDictConverter:
if name == "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif ("text_encoders.clip_l.transformer." + name) in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict["text_encoders.clip_l.transformer." + name]] = param
return state_dict_
@@ -860,6 +865,11 @@ class SD3TextEncoder2StateDictConverter(SDXLTextEncoder2StateDictConverter):
if name == "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif ("text_encoders.clip_g.transformer." + name) in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict["text_encoders.clip_g.transformer." + name]] = param
return state_dict_

View File

@@ -7,5 +7,6 @@ from .hunyuan_image import HunyuanDiTImagePipeline
from .svd_video import SVDVideoPipeline
from .flux_image import FluxImagePipeline
from .cog_video import CogVideoPipeline
from .omnigen_image import OmnigenImagePipeline
from .pipeline_runner import SDVideoPipelineRunner
KolorsImagePipeline = SDXLImagePipeline

View File

@@ -1,19 +1,32 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur
class BasePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.cpu_offload = False
self.model_names = []
def check_resize_height_width(self, height, width):
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
return height, width
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
@@ -35,15 +48,18 @@ class BasePipeline(torch.nn.Module):
return video
def merge_latents(self, value, latents, masks, scales):
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1)
value[mask] += latent[mask] * scale
weight[mask] += scale
value /= weight
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
if len(latents) > 0:
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
mask = blur(mask)
value += latent * mask * scale
weight += mask * scale
value /= weight
return value

View File

@@ -1,4 +1,4 @@
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
@@ -15,11 +15,11 @@ from ..models.tiler import FastTileWorker
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = FlowMatchScheduler()
self.prompter = FluxPrompter()
# models
self.text_encoder_1: FluxTextEncoder1 = None
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
@@ -33,7 +33,7 @@ class FluxImagePipeline(BasePipeline):
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit")
self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")

View File

@@ -125,7 +125,7 @@ class ImageSizeManager:
class HunyuanDiTImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
self.prompter = HunyuanDiTPrompter()
self.image_size_manager = ImageSizeManager()

View File

@@ -0,0 +1,287 @@
from ..models.omnigen import OmniGenTransformer
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.model_manager import ModelManager
from ..prompters.omnigen_prompter import OmniGenPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
from typing import Optional, Dict, Any, Tuple, List
from transformers.cache_utils import DynamicCache
import torch, os
from tqdm import tqdm
class OmniGenCache(DynamicCache):
def __init__(self,
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
if not torch.cuda.is_available():
print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
offload_kv_cache = False
raise RuntimeError("OffloadedCache can only be used with a GPU")
super().__init__()
self.original_device = []
self.prefetch_stream = torch.cuda.Stream()
self.num_tokens_for_img = num_tokens_for_img
self.offload_kv_cache = offload_kv_cache
def prefetch_layer(self, layer_idx: int):
"Starts prefetching the next layer cache"
if layer_idx < len(self):
with torch.cuda.stream(self.prefetch_stream):
# Prefetch next layer tensors to GPU
device = self.original_device[layer_idx]
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
if len(self) > 2:
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
if layer_idx == 0:
prev_layer_idx = -1
else:
prev_layer_idx = (layer_idx - 1) % len(self)
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
if self.offload_kv_cache:
# Evict the previous layer if necessary
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# Load current layer cache to its original device if not already there
original_device = self.original_device[layer_idx]
# self.prefetch_stream.synchronize(original_device)
torch.cuda.synchronize(self.prefetch_stream)
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]
# Prefetch the next layer
self.prefetch_layer((layer_idx + 1) % len(self))
else:
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]
return (key_tensor, value_tensor)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the cache
if len(self.key_cache) < layer_idx:
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
# only cache the states for condition tokens
key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
self.key_cache.append(key_states)
self.value_cache.append(value_states)
self.original_device.append(key_states.device)
if self.offload_kv_cache:
self.evict_previous_layer(layer_idx)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
else:
# only cache the states for condition tokens
key_tensor, value_tensor = self[layer_idx]
k = torch.cat([key_tensor, key_states], dim=-2)
v = torch.cat([value_tensor, value_states], dim=-2)
return k, v
class OmnigenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(num_train_timesteps=1, shift=1, inverse_timesteps=True, sigma_min=0, sigma_max=1)
# models
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.transformer: OmniGenTransformer = None
self.prompter: OmniGenPrompter = None
self.model_names = ['transformer', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
return self.transformer
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
# Main models
self.transformer, model_path = model_manager.fetch_model("omnigen_transformer", require_model_path=True)
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter = OmniGenPrompter.from_pretrained(os.path.dirname(model_path))
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = OmnigenImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes=[])
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def encode_images(self, images, tiled=False, tile_size=64, tile_stride=32):
latents = [self.encode_image(image.to(device=self.device), tiled, tile_size, tile_stride).to(self.torch_dtype) for image in images]
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
return {"encoder_hidden_states": prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
if isinstance(position_ids, list):
for i in range(len(position_ids)):
position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
else:
position_ids = position_ids[:, -(num_tokens_for_img+1):]
return position_ids
def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
if isinstance(attention_mask, list):
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
return attention_mask[..., -(num_tokens_for_img+1):, :]
@torch.no_grad()
def __call__(
self,
prompt,
reference_images=[],
cfg_scale=2.0,
image_cfg_scale=2.0,
use_kv_cache=True,
offload_kv_cache=True,
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = latents.repeat(3, 1, 1, 1)
# Encode prompts
input_data = self.prompter(prompt, reference_images, height=height, width=width, use_img_cfg=True, separate_cfg_input=True, use_input_image_size_as_output=False)
# Encode images
reference_latents = [self.encode_images(images, **tiler_kwargs) for images in input_data['input_pixel_values']]
# Pack all parameters
model_kwargs = dict(input_ids=[input_ids.to(self.device) for input_ids in input_data['input_ids']],
input_img_latents=reference_latents,
input_image_sizes=input_data['input_image_sizes'],
attention_mask=[attention_mask.to(self.device) for attention_mask in input_data["attention_mask"]],
position_ids=[position_ids.to(self.device) for position_ids in input_data["position_ids"]],
cfg_scale=cfg_scale,
img_cfg_scale=image_cfg_scale,
use_img_cfg=True,
use_kv_cache=use_kv_cache,
offload_model=False,
)
# Denoise
self.load_models_to_device(['transformer'])
cache = [OmniGenCache(latents.size(-1)*latents.size(-2) // 4, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).repeat(latents.shape[0]).to(self.device)
# Forward
noise_pred, cache = self.transformer.forward_with_separate_cfg(latents, timestep, past_key_values=cache, **model_kwargs)
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Update KV cache
if progress_id == 0 and use_kv_cache:
num_tokens_for_img = latents.size(-1)*latents.size(-2) // 4
if isinstance(cache, list):
model_kwargs['input_ids'] = [None] * len(cache)
else:
model_kwargs['input_ids'] = None
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
del cache
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm
class SD3ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = FlowMatchScheduler()
self.prompter = SD3Prompter()
# models
@@ -59,9 +59,9 @@ class SD3ImagePipeline(BasePipeline):
return image
def encode_prompt(self, prompt, positive=True):
def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
@@ -84,6 +84,7 @@ class SD3ImagePipeline(BasePipeline):
height=1024,
width=1024,
num_inference_steps=20,
t5_sequence_length=77,
tiled=False,
tile_size=128,
tile_stride=64,
@@ -109,9 +110,9 @@ class SD3ImagePipeline(BasePipeline):
# Encode prompts
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
# Denoise
self.load_models_to_device(['dit'])

View File

@@ -1,5 +1,6 @@
from .base_prompter import BasePrompter
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.sd3_text_encoder import SD3TextEncoder1
from transformers import CLIPTokenizer, T5TokenizerFast
import os, torch
@@ -19,11 +20,11 @@ class FluxPrompter(BasePrompter):
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
self.text_encoder_1: FluxTextEncoder1 = None
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
def fetch_models(self, text_encoder_1: FluxTextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
@@ -36,7 +37,7 @@ class FluxPrompter(BasePrompter):
max_length=max_length,
truncation=True
).input_ids.to(device)
_, pooled_prompt_emb = text_encoder(input_ids)
pooled_prompt_emb, _ = text_encoder(input_ids)
return pooled_prompt_emb

View File

@@ -0,0 +1,356 @@
import os
import re
from typing import Dict, List
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download
import numpy as np
def crop_arr(pil_image, max_image_size):
while min(*pil_image.size) >= 2 * max_image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
if max(*pil_image.size) > max_image_size:
scale = max_image_size / max(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
if min(*pil_image.size) < 16:
scale = 16 / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y1 = (arr.shape[0] % 16) // 2
crop_y2 = arr.shape[0] % 16 - crop_y1
crop_x1 = (arr.shape[1] % 16) // 2
crop_x2 = arr.shape[1] % 16 - crop_x1
arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
return Image.fromarray(arr)
class OmniGenPrompter:
def __init__(self,
text_tokenizer,
max_image_size: int=1024):
self.text_tokenizer = text_tokenizer
self.max_image_size = max_image_size
self.image_transform = transforms.Compose([
transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
self.collator = OmniGenCollator()
self.separate_collator = OmniGenSeparateCollator()
@classmethod
def from_pretrained(cls, model_name):
if not os.path.exists(model_name):
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
allow_patterns="*.json")
text_tokenizer = AutoTokenizer.from_pretrained(model_name)
return cls(text_tokenizer)
def process_image(self, image):
return self.image_transform(image)
def process_multi_modal_prompt(self, text, input_images):
text = self.add_prefix_instruction(text)
if input_images is None or len(input_images) == 0:
model_inputs = self.text_tokenizer(text)
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
pattern = r"<\|image_\d+\|>"
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
for i in range(1, len(prompt_chunks)):
if prompt_chunks[i][0] == 1:
prompt_chunks[i] = prompt_chunks[i][1:]
image_tags = re.findall(pattern, text)
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(list(set(image_ids)))
assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
# total images must be the same as the number of image tags
assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
input_images = [input_images[x-1] for x in image_ids]
all_input_ids = []
img_inx = []
idx = 0
for i in range(len(prompt_chunks)):
all_input_ids.extend(prompt_chunks[i])
if i != len(prompt_chunks) -1:
start_inx = len(all_input_ids)
size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
img_inx.append([start_inx, start_inx+size])
all_input_ids.extend([0]*size)
return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
def add_prefix_instruction(self, prompt):
user_prompt = '<|user|>\n'
generation_prompt = 'Generate an image according to the following instructions\n'
assistant_prompt = '<|assistant|>\n<|diffusion|>'
prompt_suffix = "<|end|>\n"
prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
return prompt
def __call__(self,
instructions: List[str],
input_images: List[List[str]] = None,
height: int = 1024,
width: int = 1024,
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
use_img_cfg: bool = True,
separate_cfg_input: bool = False,
use_input_image_size_as_output: bool=False,
) -> Dict:
if input_images is None:
use_img_cfg = False
if isinstance(instructions, str):
instructions = [instructions]
input_images = [input_images]
input_data = []
for i in range(len(instructions)):
cur_instruction = instructions[i]
cur_input_images = None if input_images is None else input_images[i]
if cur_input_images is not None and len(cur_input_images) > 0:
cur_input_images = [self.process_image(x) for x in cur_input_images]
else:
cur_input_images = None
assert "<img><|image_1|></img>" not in cur_instruction
mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
neg_mllm_input, img_cfg_mllm_input = None, None
neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
if use_img_cfg:
if cur_input_images is not None and len(cur_input_images) >= 1:
img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
else:
img_cfg_mllm_input = neg_mllm_input
if use_input_image_size_as_output:
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
else:
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
if separate_cfg_input:
return self.separate_collator(input_data)
return self.collator(input_data)
class OmniGenCollator:
def __init__(self, pad_token_id=2, hidden_size=3072):
self.pad_token_id = pad_token_id
self.hidden_size = hidden_size
def create_position(self, attention_mask, num_tokens_for_output_images):
position_ids = []
text_length = attention_mask.size(-1)
img_length = max(num_tokens_for_output_images)
for mask in attention_mask:
temp_l = torch.sum(mask)
temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
position_ids.append(temp_position)
return torch.LongTensor(position_ids)
def create_mask(self, attention_mask, num_tokens_for_output_images):
extended_mask = []
padding_images = []
text_length = attention_mask.size(-1)
img_length = max(num_tokens_for_output_images)
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
inx = 0
for mask in attention_mask:
temp_l = torch.sum(mask)
pad_l = text_length - temp_l
temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
image_mask = torch.zeros(size=(temp_l+1, img_length))
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
if pad_l > 0:
pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
pad_mask = torch.ones(size=(pad_l, seq_len))
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
true_img_length = num_tokens_for_output_images[inx]
pad_img_length = img_length - true_img_length
if pad_img_length > 0:
temp_mask[:, -pad_img_length:] = 0
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
else:
temp_padding_imgs = None
extended_mask.append(temp_mask.unsqueeze(0))
padding_images.append(temp_padding_imgs)
inx += 1
return torch.cat(extended_mask, dim=0), padding_images
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
for b_inx in image_sizes.keys():
for start_inx, end_inx in image_sizes[b_inx]:
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
return attention_mask
def pad_input_ids(self, input_ids, image_sizes):
max_l = max([len(x) for x in input_ids])
padded_ids = []
attention_mask = []
new_image_sizes = []
for i in range(len(input_ids)):
temp_ids = input_ids[i]
temp_l = len(temp_ids)
pad_l = max_l - temp_l
if pad_l == 0:
attention_mask.append([1]*max_l)
padded_ids.append(temp_ids)
else:
attention_mask.append([0]*pad_l+[1]*temp_l)
padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
if i in image_sizes:
new_inx = []
for old_inx in image_sizes[i]:
new_inx.append([x+pad_l for x in old_inx])
image_sizes[i] = new_inx
return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
def process_mllm_input(self, mllm_inputs, target_img_size):
num_tokens_for_output_images = []
for img_size in target_img_size:
num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
pixel_values, image_sizes = [], {}
b_inx = 0
for x in mllm_inputs:
if x['pixel_values'] is not None:
pixel_values.extend(x['pixel_values'])
for size in x['image_sizes']:
if b_inx not in image_sizes:
image_sizes[b_inx] = [size]
else:
image_sizes[b_inx].append(size)
b_inx += 1
pixel_values = [x.unsqueeze(0) for x in pixel_values]
input_ids = [x['input_ids'] for x in mllm_inputs]
padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
def __call__(self, features):
mllm_inputs = [f[0] for f in features]
cfg_mllm_inputs = [f[1] for f in features]
img_cfg_mllm_input = [f[2] for f in features]
target_img_size = [f[3] for f in features]
if img_cfg_mllm_input[0] is not None:
mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
target_img_size = target_img_size + target_img_size + target_img_size
else:
mllm_inputs = mllm_inputs + cfg_mllm_inputs
target_img_size = target_img_size + target_img_size
all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
data = {"input_ids": all_padded_input_ids,
"attention_mask": all_attention_mask,
"position_ids": all_position_ids,
"input_pixel_values": all_pixel_values,
"input_image_sizes": all_image_sizes,
"padding_images": all_padding_images,
}
return data
class OmniGenSeparateCollator(OmniGenCollator):
def __call__(self, features):
mllm_inputs = [f[0] for f in features]
cfg_mllm_inputs = [f[1] for f in features]
img_cfg_mllm_input = [f[2] for f in features]
target_img_size = [f[3] for f in features]
all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
all_padded_input_ids.append(padded_input_ids)
all_attention_mask.append(attention_mask)
all_position_ids.append(position_ids)
all_pixel_values.append(pixel_values)
all_image_sizes.append(image_sizes)
all_padding_images.append(padding_images)
if cfg_mllm_inputs[0] is not None:
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
all_padded_input_ids.append(padded_input_ids)
all_attention_mask.append(attention_mask)
all_position_ids.append(position_ids)
all_pixel_values.append(pixel_values)
all_image_sizes.append(image_sizes)
all_padding_images.append(padding_images)
if img_cfg_mllm_input[0] is not None:
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
all_padded_input_ids.append(padded_input_ids)
all_attention_mask.append(attention_mask)
all_position_ids.append(position_ids)
all_pixel_values.append(pixel_values)
all_image_sizes.append(image_sizes)
all_padding_images.append(padding_images)
data = {"input_ids": all_padded_input_ids,
"attention_mask": all_attention_mask,
"position_ids": all_position_ids,
"input_pixel_values": all_pixel_values,
"input_image_sizes": all_image_sizes,
"padding_images": all_padding_images,
}
return data

View File

@@ -67,7 +67,8 @@ class SD3Prompter(BasePrompter):
self,
prompt,
positive=True,
device="cuda"
device="cuda",
t5_sequence_length=77,
):
prompt = self.process_prompt(prompt, positive=positive)
@@ -77,9 +78,9 @@ class SD3Prompter(BasePrompter):
# T5
if self.text_encoder_3 is None:
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], t5_sequence_length, 4096), dtype=prompt_emb_1.dtype, device=device)
else:
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, t5_sequence_length, device)
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
# Merge

View File

@@ -4,17 +4,20 @@ import torch
class FlowMatchScheduler():
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002):
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
@@ -31,7 +34,7 @@ class FlowMatchScheduler():
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 0
sigma_ = 1 if self.inverse_timesteps else 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)

View File

@@ -2,6 +2,14 @@
Image synthesis is the base feature of DiffSynth Studio. We can generate images with very high resolution.
### OmniGen
OmniGen is a text-image-to-image model, you can synthesize an image according to several given reference images.
|Reference image 1|Reference image 2|Synthesized image|
|-|-|-|
|![image_man](https://github.com/user-attachments/assets/35d00493-625b-45d1-ad2b-b558ea09fe36)|![image_woman](https://github.com/user-attachments/assets/abebf69b-3563-4b3b-91c2-c48ff74a29ea)|![image_merged](https://github.com/user-attachments/assets/2979d5a9-f355-4bec-a91d-1e824d9fc8f6)|
### Example: FLUX
Example script: [`flux_text_to_image.py`](./flux_text_to_image.py) and [`flux_text_to_image_low_vram.py`](./flux_text_to_image_low_vram.py)(low VRAM).

View File

@@ -0,0 +1,25 @@
import torch
from diffsynth import ModelManager, OmnigenImagePipeline
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["OmniGen-v1"])
pipe = OmnigenImagePipeline.from_model_manager(model_manager)
image_man = pipe(
prompt="A portrait of a man.",
cfg_scale=2.5, num_inference_steps=50, seed=0
)
image_man.save("image_man.jpg")
image_woman = pipe(
prompt="A portrait of an Asian woman with a white t-shirt.",
cfg_scale=2.5, num_inference_steps=50, seed=1
)
image_woman.save("image_woman.jpg")
image_merged = pipe(
prompt="a man and a woman. The man is the man in <img><|image_1|></img>. The woman is the woman in <img><|image_2|></img>.",
reference_images=[image_man, image_woman],
cfg_scale=2.5, image_cfg_scale=2.5, num_inference_steps=50, seed=2
)
image_merged.save("image_merged.jpg")

View File

@@ -0,0 +1,28 @@
from diffsynth import ModelManager, SD3ImagePipeline
import torch
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["StableDiffusion3.5-large"])
pipe = SD3ImagePipeline.from_model_manager(model_manager)
prompt = "a full body photo of a beautiful Asian girl. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
torch.manual_seed(1)
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=5,
num_inference_steps=100, width=1024, height=1024,
)
image.save("image_1024.jpg")
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=5,
input_image=image.resize((2048, 2048)), denoising_strength=0.5,
num_inference_steps=50, width=2048, height=2048,
tiled=True
)
image.save("image_2048.jpg")

View File

@@ -22,7 +22,9 @@ class LightningModel(LightningModelForT2ILoRA):
model_manager.load_models(pretrained_weights[1:])
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
if preset_lora_path is not None:
model_manager.load_lora(preset_lora_path)
preset_lora_path = preset_lora_path.split(",")
for path in preset_lora_path:
model_manager.load_lora(path)
self.pipe = FluxImagePipeline.from_model_manager(model_manager)

View File

@@ -1,6 +1,7 @@
torch>=2.0.0
torchvision
cupy-cuda12x
transformers==4.44.1
transformers==4.46.2
controlnet-aux==0.0.7
imageio
imageio[ffmpeg]