Merge pull request #160 from modelscope/Artiprocher-dev

support FLUX
This commit is contained in:
Zhongjie Duan
2024-08-17 17:52:59 +08:00
committed by GitHub
22 changed files with 230068 additions and 48 deletions

View File

@@ -33,6 +33,11 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -62,13 +67,18 @@ model_loader_configs = [
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai")
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -133,6 +143,9 @@ preset_models_on_modelscope = {
"StableDiffusionXL_Turbo": [
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
@@ -157,6 +170,10 @@ preset_models_on_modelscope = {
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
],
"ControlNet_union_sdxl_promax": [
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
# AnimateDiff
"AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
@@ -214,6 +231,16 @@ preset_models_on_modelscope = {
"SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
# FLUX
"FLUX.1-dev": [
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
]
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
@@ -242,4 +269,7 @@ Preset_model_id: TypeAlias = Literal[
"StableDiffusion3_without_T5",
"Kolors",
"SDXL-vae-fp16-fix",
"ControlNet_union_sdxl_promax",
"FLUX.1-dev",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
]

View File

@@ -0,0 +1,521 @@
import torch
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
from einops import rearrange
class RoPEEmbedding(torch.nn.Module):
def __init__(self, dim, theta, axes_dim):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight
return hidden_states
class FluxJointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.only_out_a = only_out_a
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
if not only_out_a:
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
batch_size = hidden_states_a.shape[0]
# Part A
qkv_a = self.a_to_qkv(hidden_states_a)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
# Part B
qkv_b = self.b_to_qkv(hidden_states_b)
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
q = torch.concat([q_b, q_a], dim=2)
k = torch.concat([k_b, k_a], dim=2)
v = torch.concat([v_b, v_a], dim=2)
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxJointTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim)
self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_b = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
# Part B
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states, image_rotary_emb):
batch_size = hidden_states.shape[0]
qkv_a = self.a_to_qkv(hidden_states)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
class AdaLayerNormSingle(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = torch.nn.SiLU()
self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class FluxSingleTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.num_heads = num_attention_heads
self.head_dim = dim // num_attention_heads
self.dim = dim
self.norm = AdaLayerNormSingle(dim)
# self.proj_in = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), torch.nn.GELU(approximate="tanh"))
# self.attn = FluxSingleAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.linear = torch.nn.Linear(dim, dim * (3 + 4))
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
self.proj_out = torch.nn.Linear(dim * 5, dim)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
q, k = self.norm_q_a(q), self.norm_k_a(k)
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.linear(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
hidden_states_a = residual + hidden_states_a
return hidden_states_a, hidden_states_b
class AdaLayerNormContinuous(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = torch.nn.SiLU()
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
def forward(self, x, conditioning):
emb = self.linear(self.silu(conditioning))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
return x
class FluxDiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.norm_out = AdaLayerNormContinuous(3072)
self.proj_out = torch.nn.Linear(3072, 64)
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids, **kwargs):
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
+ self.guidance_embedder(guidance, hidden_states.dtype)\
+ self.pooled_text_embedder(pooled_prompt_emb)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
for block in self.blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.norm_out(hidden_states, conditioning)
hidden_states = self.proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@staticmethod
def state_dict_converter():
return FluxDiTStateDictConverter()
class FluxDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "norm_out.linear",
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
elif name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in rename_dict:
state_dict_[rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
print(name)
else:
print(name)
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".linear.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
"txt_in.bias": "context_embedder.bias",
"txt_in.weight": "context_embedder.weight",
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
"final_layer.linear.bias": "proj_out.bias",
"final_layer.linear.weight": "proj_out.weight",
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
"img_in.bias": "x_embedder.bias",
"img_in.weight": "x_embedder.weight",
"final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
}
suffix_rename_dict = {
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
"img_attn.proj.bias": "attn.a_to_out.bias",
"img_attn.proj.weight": "attn.a_to_out.weight",
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
"img_mlp.0.bias": "ff_a.0.bias",
"img_mlp.0.weight": "ff_a.0.weight",
"img_mlp.2.bias": "ff_a.2.bias",
"img_mlp.2.weight": "ff_a.2.weight",
"img_mod.lin.bias": "norm1_a.linear.bias",
"img_mod.lin.weight": "norm1_a.linear.weight",
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
"txt_attn.proj.bias": "attn.b_to_out.bias",
"txt_attn.proj.weight": "attn.b_to_out.weight",
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
"txt_mlp.0.bias": "ff_b.0.bias",
"txt_mlp.0.weight": "ff_b.0.weight",
"txt_mlp.2.bias": "ff_b.2.bias",
"txt_mlp.2.weight": "ff_b.2.weight",
"txt_mod.lin.bias": "norm1_b.linear.bias",
"txt_mod.lin.weight": "norm1_b.linear.weight",
"linear1.bias": "linear.bias",
"linear1.weight": "linear.weight",
"linear2.bias": "proj_out.bias",
"linear2.weight": "proj_out.weight",
"modulation.lin.bias": "norm.linear.bias",
"modulation.lin.weight": "norm.linear.weight",
"norm.key_norm.scale": "norm_k_a.weight",
"norm.query_norm.scale": "norm_q_a.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
if name in rename_dict:
rename = rename_dict[name]
if name.startswith("final_layer.adaLN_modulation.1."):
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[rename] = param
elif names[0] == "double_blocks":
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
elif names[0] == "single_blocks":
if ".".join(names[2:]) in suffix_rename_dict:
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
else:
print(name)
return state_dict_

View File

@@ -1,9 +1,9 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
class FLUXTextEncoder1(SDTextEncoder):
class FluxTextEncoder1(SDTextEncoder):
def __init__(self, vocab_size=49408):
super().__init__(vocab_size=vocab_size)
@@ -20,40 +20,12 @@ class FLUXTextEncoder1(SDTextEncoder):
@staticmethod
def state_dict_converter():
return FLUXTextEncoder1StateDictConverter()
return FluxTextEncoder1StateDictConverter()
class FLUXTextEncoder2(T5EncoderModel):
def __init__(self):
config = T5Config(
_name_or_path = ".",
architectures = ["T5EncoderModel"],
classifier_dropout = 0.0,
d_ff = 10240,
d_kv = 64,
d_model = 4096,
decoder_start_token_id = 0,
dense_act_fn = "gelu_new",
dropout_rate = 0.1,
eos_token_id = 1,
feed_forward_proj = "gated-gelu",
initializer_factor = 1.0,
is_encoder_decoder = True,
is_gated_act = True,
layer_norm_epsilon = 1e-06,
model_type = "t5",
num_decoder_layers = 24,
num_heads = 64,
num_layers = 24,
output_past = True,
pad_token_id = 0,
relative_attention_max_distance = 128,
relative_attention_num_buckets = 32,
tie_word_embeddings = False,
torch_dtype = "bfloat16",
transformers_version = "4.43.3",
use_cache = True,
vocab_size = 32128
)
class FluxTextEncoder2(T5EncoderModel):
def __init__(self, config):
super().__init__(config)
self.eval()
@@ -64,10 +36,11 @@ class FLUXTextEncoder2(T5EncoderModel):
@staticmethod
def state_dict_converter():
return FLUXTextEncoder2StateDictConverter()
return FluxTextEncoder2StateDictConverter()
class FLUXTextEncoder1StateDictConverter:
class FluxTextEncoder1StateDictConverter:
def __init__(self):
pass
@@ -106,7 +79,9 @@ class FLUXTextEncoder1StateDictConverter:
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class FLUXTextEncoder2StateDictConverter():
class FluxTextEncoder2StateDictConverter():
def __init__(self):
pass

View File

@@ -0,0 +1,303 @@
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
class FluxVAEEncoder(SD3VAEEncoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEEncoderStateDictConverter()
class FluxVAEDecoder(SD3VAEDecoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEDecoderStateDictConverter()
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
def __init__(self):
pass
def from_civitai(self, state_dict):
rename_dict = {
"encoder.conv_in.bias": "conv_in.bias",
"encoder.conv_in.weight": "conv_in.weight",
"encoder.conv_out.bias": "conv_out.bias",
"encoder.conv_out.weight": "conv_out.weight",
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
"encoder.norm_out.bias": "conv_norm_out.bias",
"encoder.norm_out.weight": "conv_norm_out.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
def __init__(self):
pass
def from_civitai(self, state_dict):
rename_dict = {
"decoder.conv_in.bias": "conv_in.bias",
"decoder.conv_in.weight": "conv_in.weight",
"decoder.conv_out.bias": "conv_out.bias",
"decoder.conv_out.weight": "conv_out.weight",
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
"decoder.norm_out.bias": "conv_norm_out.bias",
"decoder.norm_out.weight": "conv_norm_out.weight",
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_

View File

@@ -39,6 +39,10 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
@@ -83,10 +87,10 @@ def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
@@ -340,8 +344,8 @@ class ModelDetectorFromHuggingfaceFolder:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name):
self.architecture_dict[architecture] = (huggingface_lib, model_name)
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}):
@@ -362,7 +366,9 @@ class ModelDetectorFromHuggingfaceFolder:
config = json.load(f)
loaded_model_names, loaded_models = [], []
for architecture in config["architectures"]:
huggingface_lib, model_name = self.architecture_dict[architecture]
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_

View File

@@ -5,5 +5,6 @@ from .sdxl_video import SDXLVideoPipeline
from .sd3_image import SD3ImagePipeline
from .hunyuan_image import HunyuanDiTImagePipeline
from .svd_video import SVDVideoPipeline
from .flux_image import FluxImagePipeline
from .pipeline_runner import SDVideoPipelineRunner
KolorsImagePipeline = SDXLImagePipeline

View File

@@ -22,7 +22,7 @@ class BasePipeline(torch.nn.Module):
def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().permute(1, 2, 0).numpy()
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image

View File

@@ -136,6 +136,10 @@ def lets_dance_xl(
device = "cuda",
vram_limit_level = 0,
):
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
# 1. ControlNet
controlnet_insert_block_id = 22
if controlnet is not None and controlnet_frames is not None:

View File

@@ -0,0 +1,145 @@
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler()
self.prompter = FluxPrompter()
# models
self.text_encoder_1: FluxTextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit")
self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
pipe = FluxImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, positive=True):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
def prepare_extra_input(self, latents=None, guidance=0.0):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
guidance = torch.Tensor([guidance] * batch_size).to(device=latents.device, dtype=latents.dtype)
return {"image_ids": latent_image_ids, "guidance": guidance}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
cfg_scale=0.0,
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=30,
tiled=False,
tile_size=128,
tile_stride=64,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb = self.encode_prompt(prompt, positive=True)
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=cfg_scale)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Inference (FLUX doesn't support classifier-free guidance)
inference_callback = lambda prompt_emb: self.dit(
latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input
)
noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback)
# DDIM
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -4,3 +4,4 @@ from .sdxl_prompter import SDXLPrompter
from .sd3_prompter import SD3Prompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter
from .kolors_prompter import KolorsPrompter
from .flux_prompter import FluxPrompter

View File

@@ -0,0 +1,74 @@
from .base_prompter import BasePrompter
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from transformers import CLIPTokenizer, T5TokenizerFast
import os, torch
class FluxPrompter(BasePrompter):
def __init__(
self,
tokenizer_1_path=None,
tokenizer_2_path=None
):
if tokenizer_1_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_1")
if tokenizer_2_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_2")
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
self.text_encoder_1: FluxTextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
def fetch_models(self, text_encoder_1: FluxTextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids.to(device)
_, pooled_prompt_emb = text_encoder(input_ids)
return pooled_prompt_emb
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
).input_ids.to(device)
prompt_emb = text_encoder(input_ids)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def encode_prompt(
self,
prompt,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
# T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, 256, device)
# text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
return prompt_emb, pooled_prompt_emb, text_ids

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

View File

@@ -0,0 +1,30 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "<|endoftext|>",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,125 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,940 @@
{
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<extra_id_99>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32001": {
"content": "<extra_id_98>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32002": {
"content": "<extra_id_97>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32003": {
"content": "<extra_id_96>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32004": {
"content": "<extra_id_95>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32005": {
"content": "<extra_id_94>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32006": {
"content": "<extra_id_93>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32007": {
"content": "<extra_id_92>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32008": {
"content": "<extra_id_91>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32009": {
"content": "<extra_id_90>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32010": {
"content": "<extra_id_89>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32011": {
"content": "<extra_id_88>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32012": {
"content": "<extra_id_87>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32013": {
"content": "<extra_id_86>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32014": {
"content": "<extra_id_85>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32015": {
"content": "<extra_id_84>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32016": {
"content": "<extra_id_83>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32017": {
"content": "<extra_id_82>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32018": {
"content": "<extra_id_81>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32019": {
"content": "<extra_id_80>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32020": {
"content": "<extra_id_79>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32021": {
"content": "<extra_id_78>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32022": {
"content": "<extra_id_77>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32023": {
"content": "<extra_id_76>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32024": {
"content": "<extra_id_75>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32025": {
"content": "<extra_id_74>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32026": {
"content": "<extra_id_73>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32027": {
"content": "<extra_id_72>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32028": {
"content": "<extra_id_71>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32029": {
"content": "<extra_id_70>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32030": {
"content": "<extra_id_69>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32031": {
"content": "<extra_id_68>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32032": {
"content": "<extra_id_67>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32033": {
"content": "<extra_id_66>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32034": {
"content": "<extra_id_65>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32035": {
"content": "<extra_id_64>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32036": {
"content": "<extra_id_63>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32037": {
"content": "<extra_id_62>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32038": {
"content": "<extra_id_61>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32039": {
"content": "<extra_id_60>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32040": {
"content": "<extra_id_59>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32041": {
"content": "<extra_id_58>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32042": {
"content": "<extra_id_57>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32043": {
"content": "<extra_id_56>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32044": {
"content": "<extra_id_55>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32045": {
"content": "<extra_id_54>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32046": {
"content": "<extra_id_53>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32047": {
"content": "<extra_id_52>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32048": {
"content": "<extra_id_51>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32049": {
"content": "<extra_id_50>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32050": {
"content": "<extra_id_49>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32051": {
"content": "<extra_id_48>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32052": {
"content": "<extra_id_47>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32053": {
"content": "<extra_id_46>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32054": {
"content": "<extra_id_45>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32055": {
"content": "<extra_id_44>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32056": {
"content": "<extra_id_43>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32057": {
"content": "<extra_id_42>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32058": {
"content": "<extra_id_41>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32059": {
"content": "<extra_id_40>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32060": {
"content": "<extra_id_39>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32061": {
"content": "<extra_id_38>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32062": {
"content": "<extra_id_37>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32063": {
"content": "<extra_id_36>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32064": {
"content": "<extra_id_35>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32065": {
"content": "<extra_id_34>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32066": {
"content": "<extra_id_33>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32067": {
"content": "<extra_id_32>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32068": {
"content": "<extra_id_31>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32069": {
"content": "<extra_id_30>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32070": {
"content": "<extra_id_29>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32071": {
"content": "<extra_id_28>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32072": {
"content": "<extra_id_27>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32073": {
"content": "<extra_id_26>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32074": {
"content": "<extra_id_25>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32075": {
"content": "<extra_id_24>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32076": {
"content": "<extra_id_23>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32077": {
"content": "<extra_id_22>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32078": {
"content": "<extra_id_21>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32079": {
"content": "<extra_id_20>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32080": {
"content": "<extra_id_19>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32081": {
"content": "<extra_id_18>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32082": {
"content": "<extra_id_17>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32083": {
"content": "<extra_id_16>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32084": {
"content": "<extra_id_15>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32085": {
"content": "<extra_id_14>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32086": {
"content": "<extra_id_13>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32087": {
"content": "<extra_id_12>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32088": {
"content": "<extra_id_11>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32089": {
"content": "<extra_id_10>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32090": {
"content": "<extra_id_9>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32091": {
"content": "<extra_id_8>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32092": {
"content": "<extra_id_7>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32093": {
"content": "<extra_id_6>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32094": {
"content": "<extra_id_5>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32095": {
"content": "<extra_id_4>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32096": {
"content": "<extra_id_3>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32097": {
"content": "<extra_id_2>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32098": {
"content": "<extra_id_1>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32099": {
"content": "<extra_id_0>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"clean_up_tokenization_spaces": true,
"eos_token": "</s>",
"extra_ids": 100,
"legacy": true,
"model_max_length": 512,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"tokenizer_class": "T5Tokenizer",
"unk_token": "<unk>"
}

View File

@@ -38,6 +38,20 @@ LoRA Training: [`../train/kolors/`](../train/kolors/)
|-|-|
|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/66bb7a75-fe31-44e5-90eb-d3140ee4686d)|
Kolors also support the models trained for SD-XL. For example, ControlNets and LoRAs. See [`kolors_with_sdxl_models.py`](./kolors_with_sdxl_models.py)
LoRA: https://civitai.com/models/73305/zyd232s-ink-style
|Base model|with LoRA (alpha=0.5)|with LoRA (alpha=1.0)|with LoRA (alpha=1.5)|
|-|-|-|-|
|![image_0 0](https://github.com/user-attachments/assets/a222eae3-6e0a-4ea6-b301-99e74e2bc11a)|![image_0 5](https://github.com/user-attachments/assets/e429c501-530c-43f6-a30b-9f97996c91a2)|![image_1 0](https://github.com/user-attachments/assets/0ddeed4b-250d-4b5c-a4fa-2db50f63bf1c)|![image_1 5](https://github.com/user-attachments/assets/db35a89d-6325-4422-921e-14fb6ad66c92)|
ControlNet: https://huggingface.co/xinsir/controlnet-union-sdxl-1.0
|Reference image|Depth image|with ControlNet|with ControlNet|
|-|-|-|-|
|![image_0 0](https://github.com/user-attachments/assets/a222eae3-6e0a-4ea6-b301-99e74e2bc11a)|![controlnet_input](https://github.com/user-attachments/assets/d16b2785-bc1f-4184-b170-ae90f1d704c1)|![image_depth_1](https://github.com/user-attachments/assets/90a94780-7b56-4786-8a25-aae118eda171)|![image_depth_2](https://github.com/user-attachments/assets/05eb1309-9c98-49e7-a8ee-f376ceedf18e)|
### Example: Hunyuan-DiT
Example script: [`hunyuan_dit_text_to_image.py`](./hunyuan_dit_text_to_image.py)

View File

@@ -0,0 +1,20 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models
download_models(["FLUX.1-dev"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
torch.manual_seed(6)
image = pipe(
"A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin.",
num_inference_steps=30
)
image.save("image_1024.jpg")

View File

@@ -0,0 +1,68 @@
from diffsynth import ModelManager, SDXLImagePipeline, download_models, ControlNetConfigUnit
import torch
def run_kolors_with_controlnet():
download_models(["Kolors", "ControlNet_union_sdxl_promax"])
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
file_path_list=[
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
"models/ControlNet/controlnet_union/diffusion_pytorch_model_promax.safetensors",
])
pipe = SDXLImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit("depth", "models/ControlNet/controlnet_union/diffusion_pytorch_model_promax.safetensors", 0.6)
])
negative_prompt = "半身,苍白的肤色,蜡黄的肤色,尸体,错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,错误的手指,口红,腮红"
prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、蓝色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女漂浮在水下,面向镜头,周围是光彩的气泡,和煦的阳光透过水面折射进水下"
torch.manual_seed(7)
image = pipe(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, cfg_scale=4,
)
image.save("image.jpg")
prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、黑色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女,面向镜头,周围是绚烂的火焰"
torch.manual_seed(0)
image_controlnet = pipe(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, cfg_scale=4,
controlnet_image=image,
)
image_controlnet.save("image_depth_1.jpg")
prompt = "一幅充满诗意美感的全身画,画中一位皮肤白皙、黑色长发、黑色眼睛、身穿金色吊带连衣裙的少女,周围是闪电,画面明亮"
torch.manual_seed(1)
image_controlnet = pipe(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, cfg_scale=4,
controlnet_image=image,
)
image_controlnet.save("image_depth_2.jpg")
def run_kolors_with_lora():
download_models(["Kolors", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0"])
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
file_path_list=[
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
])
model_manager.load_lora("models/lora/zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", lora_alpha=1.5)
pipe = SDXLImagePipeline.from_model_manager(model_manager)
prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、蓝色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女漂浮在水下,面向镜头,周围是光彩的气泡,和煦的阳光透过水面折射进水下"
negative_prompt = "半身,苍白的肤色,蜡黄的肤色,尸体,错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,错误的手指,口红,腮红"
torch.manual_seed(7)
image = pipe(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, cfg_scale=4,
)
image.save("image_lora.jpg")
run_kolors_with_controlnet()
run_kolors_with_lora()