mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
@@ -33,6 +33,11 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
|
||||
from ..models.hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
from ..models.flux_dit import FluxDiT
|
||||
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -62,13 +67,18 @@ model_loader_configs = [
|
||||
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
||||
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
||||
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
||||
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
|
||||
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
||||
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
||||
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai")
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
|
||||
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
|
||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
|
||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
|
||||
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
||||
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -133,6 +143,9 @@ preset_models_on_modelscope = {
|
||||
"StableDiffusionXL_Turbo": [
|
||||
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
||||
],
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
||||
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
|
||||
],
|
||||
# Stable Diffusion 3
|
||||
"StableDiffusion3": [
|
||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
||||
@@ -157,6 +170,10 @@ preset_models_on_modelscope = {
|
||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_union_sdxl_promax": [
|
||||
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
# AnimateDiff
|
||||
"AnimateDiff_v2": [
|
||||
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
||||
@@ -214,6 +231,16 @@ preset_models_on_modelscope = {
|
||||
"SDXL-vae-fp16-fix": [
|
||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
||||
],
|
||||
# FLUX
|
||||
"FLUX.1-dev": [
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
]
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
@@ -242,4 +269,7 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"StableDiffusion3_without_T5",
|
||||
"Kolors",
|
||||
"SDXL-vae-fp16-fix",
|
||||
"ControlNet_union_sdxl_promax",
|
||||
"FLUX.1-dev",
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||
]
|
||||
521
diffsynth/models/flux_dit.py
Normal file
521
diffsynth/models/flux_dit.py
Normal file
@@ -0,0 +1,521 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
|
||||
class RoPEEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, theta, axes_dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
|
||||
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
|
||||
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
||||
return out.float()
|
||||
|
||||
|
||||
def forward(self, ids):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class FluxJointAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.only_out_a = only_out_a
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
||||
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||
if not only_out_a:
|
||||
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
||||
|
||||
|
||||
def apply_rope(self, xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
# Part A
|
||||
qkv_a = self.a_to_qkv(hidden_states_a)
|
||||
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
|
||||
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
||||
|
||||
# Part B
|
||||
qkv_b = self.b_to_qkv(hidden_states_b)
|
||||
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
|
||||
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
|
||||
|
||||
q = torch.concat([q_b, q_a], dim=2)
|
||||
k = torch.concat([k_b, k_a], dim=2)
|
||||
v = torch.concat([v_b, v_a], dim=2)
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class FluxJointTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
# Part B
|
||||
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class FluxSingleAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
|
||||
def apply_rope(self, xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def forward(self, hidden_states, image_rotary_emb):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv_a = self.a_to_qkv(hidden_states)
|
||||
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_a, k_a, v = qkv_a.chunk(3, dim=1)
|
||||
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
||||
|
||||
q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormSingle(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.silu = torch.nn.SiLU()
|
||||
self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa
|
||||
|
||||
|
||||
|
||||
class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads):
|
||||
super().__init__()
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = dim // num_attention_heads
|
||||
self.dim = dim
|
||||
|
||||
self.norm = AdaLayerNormSingle(dim)
|
||||
# self.proj_in = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), torch.nn.GELU(approximate="tanh"))
|
||||
# self.attn = FluxSingleAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
||||
self.linear = torch.nn.Linear(dim, dim * (3 + 4))
|
||||
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
|
||||
|
||||
self.proj_out = torch.nn.Linear(dim * 5, dim)
|
||||
|
||||
|
||||
def apply_rope(self, xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
q, k = self.norm_q_a(q), self.norm_k_a(k)
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
residual = hidden_states_a
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||
hidden_states_a = self.linear(norm_hidden_states)
|
||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb)
|
||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
||||
|
||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
|
||||
hidden_states_a = residual + hidden_states_a
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormContinuous(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.silu = torch.nn.SiLU()
|
||||
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
|
||||
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, conditioning):
|
||||
emb = self.linear(self.silu(conditioning))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(3072)
|
||||
self.proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids, **kwargs):
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
|
||||
+ self.guidance_embedder(guidance, hidden_states.dtype)\
|
||||
+ self.pooled_text_embedder(pooled_prompt_emb)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block in self.single_blocks:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class FluxDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "norm_out.linear",
|
||||
"proj_out": "proj_out",
|
||||
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in rename_dict:
|
||||
state_dict_[rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
print(name)
|
||||
else:
|
||||
print(name)
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".linear.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||
state_dict_[name],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||
state_dict_.pop(name)
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"txt_in.bias": "context_embedder.bias",
|
||||
"txt_in.weight": "context_embedder.weight",
|
||||
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
|
||||
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
|
||||
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
|
||||
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
|
||||
"final_layer.linear.bias": "proj_out.bias",
|
||||
"final_layer.linear.weight": "proj_out.weight",
|
||||
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
|
||||
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
|
||||
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
|
||||
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
|
||||
"img_in.bias": "x_embedder.bias",
|
||||
"img_in.weight": "x_embedder.weight",
|
||||
"final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
||||
"final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
|
||||
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
|
||||
"img_attn.proj.bias": "attn.a_to_out.bias",
|
||||
"img_attn.proj.weight": "attn.a_to_out.weight",
|
||||
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
|
||||
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
|
||||
"img_mlp.0.bias": "ff_a.0.bias",
|
||||
"img_mlp.0.weight": "ff_a.0.weight",
|
||||
"img_mlp.2.bias": "ff_a.2.bias",
|
||||
"img_mlp.2.weight": "ff_a.2.weight",
|
||||
"img_mod.lin.bias": "norm1_a.linear.bias",
|
||||
"img_mod.lin.weight": "norm1_a.linear.weight",
|
||||
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
|
||||
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
|
||||
"txt_attn.proj.bias": "attn.b_to_out.bias",
|
||||
"txt_attn.proj.weight": "attn.b_to_out.weight",
|
||||
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
|
||||
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
|
||||
"txt_mlp.0.bias": "ff_b.0.bias",
|
||||
"txt_mlp.0.weight": "ff_b.0.weight",
|
||||
"txt_mlp.2.bias": "ff_b.2.bias",
|
||||
"txt_mlp.2.weight": "ff_b.2.weight",
|
||||
"txt_mod.lin.bias": "norm1_b.linear.bias",
|
||||
"txt_mod.lin.weight": "norm1_b.linear.weight",
|
||||
|
||||
"linear1.bias": "linear.bias",
|
||||
"linear1.weight": "linear.weight",
|
||||
"linear2.bias": "proj_out.bias",
|
||||
"linear2.weight": "proj_out.weight",
|
||||
"modulation.lin.bias": "norm.linear.bias",
|
||||
"modulation.lin.weight": "norm.linear.weight",
|
||||
"norm.key_norm.scale": "norm_k_a.weight",
|
||||
"norm.query_norm.scale": "norm_q_a.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
if name in rename_dict:
|
||||
rename = rename_dict[name]
|
||||
if name.startswith("final_layer.adaLN_modulation.1."):
|
||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "double_blocks":
|
||||
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "single_blocks":
|
||||
if ".".join(names[2:]) in suffix_rename_dict:
|
||||
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
else:
|
||||
print(name)
|
||||
return state_dict_
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
|
||||
|
||||
class FLUXTextEncoder1(SDTextEncoder):
|
||||
class FluxTextEncoder1(SDTextEncoder):
|
||||
def __init__(self, vocab_size=49408):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
@@ -20,40 +20,12 @@ class FLUXTextEncoder1(SDTextEncoder):
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FLUXTextEncoder1StateDictConverter()
|
||||
return FluxTextEncoder1StateDictConverter()
|
||||
|
||||
class FLUXTextEncoder2(T5EncoderModel):
|
||||
def __init__(self):
|
||||
config = T5Config(
|
||||
_name_or_path = ".",
|
||||
architectures = ["T5EncoderModel"],
|
||||
classifier_dropout = 0.0,
|
||||
d_ff = 10240,
|
||||
d_kv = 64,
|
||||
d_model = 4096,
|
||||
decoder_start_token_id = 0,
|
||||
dense_act_fn = "gelu_new",
|
||||
dropout_rate = 0.1,
|
||||
eos_token_id = 1,
|
||||
feed_forward_proj = "gated-gelu",
|
||||
initializer_factor = 1.0,
|
||||
is_encoder_decoder = True,
|
||||
is_gated_act = True,
|
||||
layer_norm_epsilon = 1e-06,
|
||||
model_type = "t5",
|
||||
num_decoder_layers = 24,
|
||||
num_heads = 64,
|
||||
num_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
relative_attention_max_distance = 128,
|
||||
relative_attention_num_buckets = 32,
|
||||
tie_word_embeddings = False,
|
||||
torch_dtype = "bfloat16",
|
||||
transformers_version = "4.43.3",
|
||||
use_cache = True,
|
||||
vocab_size = 32128
|
||||
)
|
||||
|
||||
|
||||
class FluxTextEncoder2(T5EncoderModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
@@ -64,10 +36,11 @@ class FLUXTextEncoder2(T5EncoderModel):
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FLUXTextEncoder2StateDictConverter()
|
||||
return FluxTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
class FLUXTextEncoder1StateDictConverter:
|
||||
|
||||
class FluxTextEncoder1StateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -106,7 +79,9 @@ class FLUXTextEncoder1StateDictConverter:
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
class FLUXTextEncoder2StateDictConverter():
|
||||
|
||||
|
||||
class FluxTextEncoder2StateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
303
diffsynth/models/flux_vae.py
Normal file
303
diffsynth/models/flux_vae.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
|
||||
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class FluxVAEEncoder(SD3VAEEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxVAEDecoder(SD3VAEDecoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"encoder.conv_in.bias": "conv_in.bias",
|
||||
"encoder.conv_in.weight": "conv_in.weight",
|
||||
"encoder.conv_out.bias": "conv_out.bias",
|
||||
"encoder.conv_out.weight": "conv_out.weight",
|
||||
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
||||
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
||||
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
||||
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
||||
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
||||
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
||||
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
||||
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
||||
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
||||
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
||||
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
||||
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
||||
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
||||
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
||||
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
||||
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
||||
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
||||
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
||||
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
||||
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
||||
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
||||
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
||||
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
||||
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
||||
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
||||
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
||||
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
||||
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
||||
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
||||
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
||||
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
||||
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
||||
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
||||
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
||||
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
||||
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
||||
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
||||
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
||||
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
||||
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
||||
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
||||
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
||||
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
||||
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
||||
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
||||
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
||||
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
||||
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
||||
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
||||
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
||||
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
||||
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
||||
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
||||
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
||||
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
||||
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
||||
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
||||
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
||||
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
||||
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
||||
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
||||
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
||||
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
||||
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
||||
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
||||
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
||||
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
||||
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
||||
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
||||
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
||||
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
||||
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
||||
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
||||
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
||||
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
||||
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
||||
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"encoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"encoder.norm_out.weight": "conv_norm_out.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
|
||||
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"decoder.conv_in.bias": "conv_in.bias",
|
||||
"decoder.conv_in.weight": "conv_in.weight",
|
||||
"decoder.conv_out.bias": "conv_out.bias",
|
||||
"decoder.conv_out.weight": "conv_out.weight",
|
||||
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
|
||||
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
|
||||
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
|
||||
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
|
||||
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
|
||||
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
|
||||
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
|
||||
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
|
||||
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
|
||||
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
|
||||
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
||||
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
||||
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
||||
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
||||
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
||||
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
||||
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
||||
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
||||
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
|
||||
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
|
||||
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
|
||||
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
|
||||
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
|
||||
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
|
||||
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
|
||||
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
|
||||
"decoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"decoder.norm_out.weight": "conv_norm_out.weight",
|
||||
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
|
||||
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
|
||||
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
|
||||
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
|
||||
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
|
||||
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
|
||||
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
|
||||
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
|
||||
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
|
||||
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
|
||||
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
|
||||
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
|
||||
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
|
||||
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
|
||||
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
|
||||
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
|
||||
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
|
||||
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
|
||||
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
|
||||
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
|
||||
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
|
||||
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
|
||||
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
|
||||
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
|
||||
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
|
||||
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
|
||||
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
|
||||
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
|
||||
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
|
||||
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
|
||||
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
|
||||
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
|
||||
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
|
||||
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
|
||||
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
|
||||
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
|
||||
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
|
||||
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
|
||||
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
|
||||
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
|
||||
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
|
||||
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
|
||||
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
|
||||
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
|
||||
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
|
||||
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
|
||||
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
|
||||
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
|
||||
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
|
||||
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
|
||||
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
|
||||
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
|
||||
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
|
||||
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
|
||||
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
|
||||
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
|
||||
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
|
||||
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
|
||||
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
|
||||
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
|
||||
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
|
||||
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
|
||||
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
|
||||
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
|
||||
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
|
||||
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
|
||||
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
|
||||
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
|
||||
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
|
||||
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
|
||||
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
|
||||
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
|
||||
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
|
||||
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
|
||||
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
|
||||
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
|
||||
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
|
||||
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
|
||||
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
|
||||
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
|
||||
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
|
||||
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
@@ -39,6 +39,10 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
from .flux_dit import FluxDiT
|
||||
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
|
||||
|
||||
@@ -83,10 +87,10 @@ def search_parameter(param, state_dict):
|
||||
for name, param_ in state_dict.items():
|
||||
if param.numel() == param_.numel():
|
||||
if param.shape == param_.shape:
|
||||
if torch.dist(param, param_) < 1e-6:
|
||||
if torch.dist(param, param_) < 1e-3:
|
||||
return name
|
||||
else:
|
||||
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
|
||||
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
||||
return name
|
||||
return None
|
||||
|
||||
@@ -340,8 +344,8 @@ class ModelDetectorFromHuggingfaceFolder:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, architecture, huggingface_lib, model_name):
|
||||
self.architecture_dict[architecture] = (huggingface_lib, model_name)
|
||||
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
||||
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
@@ -362,7 +366,9 @@ class ModelDetectorFromHuggingfaceFolder:
|
||||
config = json.load(f)
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for architecture in config["architectures"]:
|
||||
huggingface_lib, model_name = self.architecture_dict[architecture]
|
||||
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
||||
if redirected_architecture is not None:
|
||||
architecture = redirected_architecture
|
||||
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
||||
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
|
||||
@@ -5,5 +5,6 @@ from .sdxl_video import SDXLVideoPipeline
|
||||
from .sd3_image import SD3ImagePipeline
|
||||
from .hunyuan_image import HunyuanDiTImagePipeline
|
||||
from .svd_video import SVDVideoPipeline
|
||||
from .flux_image import FluxImagePipeline
|
||||
from .pipeline_runner import SDVideoPipelineRunner
|
||||
KolorsImagePipeline = SDXLImagePipeline
|
||||
|
||||
@@ -22,7 +22,7 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output):
|
||||
image = vae_output[0].cpu().permute(1, 2, 0).numpy()
|
||||
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@@ -136,6 +136,10 @@ def lets_dance_xl(
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 0. Text embedding alignment (only for video processing)
|
||||
if encoder_hidden_states.shape[0] != sample.shape[0]:
|
||||
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
|
||||
|
||||
# 1. ControlNet
|
||||
controlnet_insert_block_id = 22
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
|
||||
145
diffsynth/pipelines/flux_image.py
Normal file
145
diffsynth/pipelines/flux_image.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
|
||||
from ..prompters import FluxPrompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.prompter = FluxPrompter()
|
||||
# models
|
||||
self.text_encoder_1: FluxTextEncoder1 = None
|
||||
self.text_encoder_2: FluxTextEncoder2 = None
|
||||
self.dit: FluxDiT = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("flux_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
pipe = FluxImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=0.0):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
guidance = torch.Tensor([guidance] * batch_size).to(device=latents.device, dtype=latents.dtype)
|
||||
return {"image_ids": latent_image_ids, "guidance": guidance}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
cfg_scale=0.0,
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=30,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=cfg_scale)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Inference (FLUX doesn't support classifier-free guidance)
|
||||
inference_callback = lambda prompt_emb: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input
|
||||
)
|
||||
noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
|
||||
# DDIM
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
@@ -4,3 +4,4 @@ from .sdxl_prompter import SDXLPrompter
|
||||
from .sd3_prompter import SD3Prompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
from .kolors_prompter import KolorsPrompter
|
||||
from .flux_prompter import FluxPrompter
|
||||
|
||||
74
diffsynth/prompters/flux_prompter.py
Normal file
74
diffsynth/prompters/flux_prompter.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
import os, torch
|
||||
|
||||
|
||||
class FluxPrompter(BasePrompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_1_path=None,
|
||||
tokenizer_2_path=None
|
||||
):
|
||||
if tokenizer_1_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_1")
|
||||
if tokenizer_2_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_2")
|
||||
super().__init__()
|
||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
|
||||
self.text_encoder_1: FluxTextEncoder1 = None
|
||||
self.text_encoder_2: FluxTextEncoder2 = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder_1: FluxTextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids.to(device)
|
||||
_, pooled_prompt_emb = text_encoder(input_ids)
|
||||
return pooled_prompt_emb
|
||||
|
||||
|
||||
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
).input_ids.to(device)
|
||||
prompt_emb = text_encoder(input_ids)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# CLIP
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
|
||||
|
||||
# T5
|
||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, 256, device)
|
||||
|
||||
# text_ids
|
||||
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb, text_ids
|
||||
48895
diffsynth/tokenizer_configs/flux/tokenizer_1/merges.txt
Normal file
48895
diffsynth/tokenizer_configs/flux/tokenizer_1/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
49410
diffsynth/tokenizer_configs/flux/tokenizer_1/vocab.json
Normal file
49410
diffsynth/tokenizer_configs/flux/tokenizer_1/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,125 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
BIN
diffsynth/tokenizer_configs/flux/tokenizer_2/spiece.model
Normal file
BIN
diffsynth/tokenizer_configs/flux/tokenizer_2/spiece.model
Normal file
Binary file not shown.
129428
diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer.json
Normal file
129428
diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer.json
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,940 @@
|
||||
{
|
||||
"add_prefix_space": true,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32000": {
|
||||
"content": "<extra_id_99>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32001": {
|
||||
"content": "<extra_id_98>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32002": {
|
||||
"content": "<extra_id_97>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32003": {
|
||||
"content": "<extra_id_96>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32004": {
|
||||
"content": "<extra_id_95>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32005": {
|
||||
"content": "<extra_id_94>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32006": {
|
||||
"content": "<extra_id_93>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32007": {
|
||||
"content": "<extra_id_92>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32008": {
|
||||
"content": "<extra_id_91>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32009": {
|
||||
"content": "<extra_id_90>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32010": {
|
||||
"content": "<extra_id_89>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32011": {
|
||||
"content": "<extra_id_88>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32012": {
|
||||
"content": "<extra_id_87>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32013": {
|
||||
"content": "<extra_id_86>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32014": {
|
||||
"content": "<extra_id_85>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32015": {
|
||||
"content": "<extra_id_84>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32016": {
|
||||
"content": "<extra_id_83>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32017": {
|
||||
"content": "<extra_id_82>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32018": {
|
||||
"content": "<extra_id_81>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32019": {
|
||||
"content": "<extra_id_80>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32020": {
|
||||
"content": "<extra_id_79>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32021": {
|
||||
"content": "<extra_id_78>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32022": {
|
||||
"content": "<extra_id_77>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32023": {
|
||||
"content": "<extra_id_76>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32024": {
|
||||
"content": "<extra_id_75>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32025": {
|
||||
"content": "<extra_id_74>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32026": {
|
||||
"content": "<extra_id_73>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32027": {
|
||||
"content": "<extra_id_72>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32028": {
|
||||
"content": "<extra_id_71>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32029": {
|
||||
"content": "<extra_id_70>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32030": {
|
||||
"content": "<extra_id_69>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32031": {
|
||||
"content": "<extra_id_68>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32032": {
|
||||
"content": "<extra_id_67>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32033": {
|
||||
"content": "<extra_id_66>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32034": {
|
||||
"content": "<extra_id_65>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32035": {
|
||||
"content": "<extra_id_64>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32036": {
|
||||
"content": "<extra_id_63>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32037": {
|
||||
"content": "<extra_id_62>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32038": {
|
||||
"content": "<extra_id_61>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32039": {
|
||||
"content": "<extra_id_60>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32040": {
|
||||
"content": "<extra_id_59>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32041": {
|
||||
"content": "<extra_id_58>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32042": {
|
||||
"content": "<extra_id_57>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32043": {
|
||||
"content": "<extra_id_56>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32044": {
|
||||
"content": "<extra_id_55>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32045": {
|
||||
"content": "<extra_id_54>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32046": {
|
||||
"content": "<extra_id_53>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32047": {
|
||||
"content": "<extra_id_52>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32048": {
|
||||
"content": "<extra_id_51>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32049": {
|
||||
"content": "<extra_id_50>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32050": {
|
||||
"content": "<extra_id_49>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32051": {
|
||||
"content": "<extra_id_48>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32052": {
|
||||
"content": "<extra_id_47>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32053": {
|
||||
"content": "<extra_id_46>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32054": {
|
||||
"content": "<extra_id_45>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32055": {
|
||||
"content": "<extra_id_44>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32056": {
|
||||
"content": "<extra_id_43>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32057": {
|
||||
"content": "<extra_id_42>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32058": {
|
||||
"content": "<extra_id_41>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32059": {
|
||||
"content": "<extra_id_40>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32060": {
|
||||
"content": "<extra_id_39>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32061": {
|
||||
"content": "<extra_id_38>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32062": {
|
||||
"content": "<extra_id_37>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32063": {
|
||||
"content": "<extra_id_36>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32064": {
|
||||
"content": "<extra_id_35>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32065": {
|
||||
"content": "<extra_id_34>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32066": {
|
||||
"content": "<extra_id_33>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32067": {
|
||||
"content": "<extra_id_32>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32068": {
|
||||
"content": "<extra_id_31>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32069": {
|
||||
"content": "<extra_id_30>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32070": {
|
||||
"content": "<extra_id_29>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32071": {
|
||||
"content": "<extra_id_28>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32072": {
|
||||
"content": "<extra_id_27>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32073": {
|
||||
"content": "<extra_id_26>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32074": {
|
||||
"content": "<extra_id_25>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32075": {
|
||||
"content": "<extra_id_24>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32076": {
|
||||
"content": "<extra_id_23>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32077": {
|
||||
"content": "<extra_id_22>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32078": {
|
||||
"content": "<extra_id_21>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32079": {
|
||||
"content": "<extra_id_20>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32080": {
|
||||
"content": "<extra_id_19>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32081": {
|
||||
"content": "<extra_id_18>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32082": {
|
||||
"content": "<extra_id_17>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32083": {
|
||||
"content": "<extra_id_16>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32084": {
|
||||
"content": "<extra_id_15>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32085": {
|
||||
"content": "<extra_id_14>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32086": {
|
||||
"content": "<extra_id_13>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32087": {
|
||||
"content": "<extra_id_12>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32088": {
|
||||
"content": "<extra_id_11>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32089": {
|
||||
"content": "<extra_id_10>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32090": {
|
||||
"content": "<extra_id_9>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32091": {
|
||||
"content": "<extra_id_8>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32092": {
|
||||
"content": "<extra_id_7>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32093": {
|
||||
"content": "<extra_id_6>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32094": {
|
||||
"content": "<extra_id_5>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32095": {
|
||||
"content": "<extra_id_4>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32096": {
|
||||
"content": "<extra_id_3>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32097": {
|
||||
"content": "<extra_id_2>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32098": {
|
||||
"content": "<extra_id_1>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32099": {
|
||||
"content": "<extra_id_0>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"eos_token": "</s>",
|
||||
"extra_ids": 100,
|
||||
"legacy": true,
|
||||
"model_max_length": 512,
|
||||
"pad_token": "<pad>",
|
||||
"sp_model_kwargs": {},
|
||||
"tokenizer_class": "T5Tokenizer",
|
||||
"unk_token": "<unk>"
|
||||
}
|
||||
@@ -38,6 +38,20 @@ LoRA Training: [`../train/kolors/`](../train/kolors/)
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
Kolors also support the models trained for SD-XL. For example, ControlNets and LoRAs. See [`kolors_with_sdxl_models.py`](./kolors_with_sdxl_models.py)
|
||||
|
||||
LoRA: https://civitai.com/models/73305/zyd232s-ink-style
|
||||
|
||||
|Base model|with LoRA (alpha=0.5)|with LoRA (alpha=1.0)|with LoRA (alpha=1.5)|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
ControlNet: https://huggingface.co/xinsir/controlnet-union-sdxl-1.0
|
||||
|
||||
|Reference image|Depth image|with ControlNet|with ControlNet|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
### Example: Hunyuan-DiT
|
||||
|
||||
Example script: [`hunyuan_dit_text_to_image.py`](./hunyuan_dit_text_to_image.py)
|
||||
|
||||
20
examples/image_synthesis/flux_text_to_image.py
Normal file
20
examples/image_synthesis/flux_text_to_image.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_models
|
||||
|
||||
|
||||
download_models(["FLUX.1-dev"])
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
torch.manual_seed(6)
|
||||
image = pipe(
|
||||
"A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin.",
|
||||
num_inference_steps=30
|
||||
)
|
||||
image.save("image_1024.jpg")
|
||||
68
examples/image_synthesis/kolors_with_sdxl_models.py
Normal file
68
examples/image_synthesis/kolors_with_sdxl_models.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from diffsynth import ModelManager, SDXLImagePipeline, download_models, ControlNetConfigUnit
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def run_kolors_with_controlnet():
|
||||
download_models(["Kolors", "ControlNet_union_sdxl_promax"])
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
|
||||
file_path_list=[
|
||||
"models/kolors/Kolors/text_encoder",
|
||||
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
||||
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/controlnet_union/diffusion_pytorch_model_promax.safetensors",
|
||||
])
|
||||
pipe = SDXLImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit("depth", "models/ControlNet/controlnet_union/diffusion_pytorch_model_promax.safetensors", 0.6)
|
||||
])
|
||||
negative_prompt = "半身,苍白的肤色,蜡黄的肤色,尸体,错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,错误的手指,口红,腮红"
|
||||
|
||||
prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、蓝色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女漂浮在水下,面向镜头,周围是光彩的气泡,和煦的阳光透过水面折射进水下"
|
||||
torch.manual_seed(7)
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, cfg_scale=4,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
|
||||
prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、黑色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女,面向镜头,周围是绚烂的火焰"
|
||||
torch.manual_seed(0)
|
||||
image_controlnet = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, cfg_scale=4,
|
||||
controlnet_image=image,
|
||||
)
|
||||
image_controlnet.save("image_depth_1.jpg")
|
||||
|
||||
prompt = "一幅充满诗意美感的全身画,画中一位皮肤白皙、黑色长发、黑色眼睛、身穿金色吊带连衣裙的少女,周围是闪电,画面明亮"
|
||||
torch.manual_seed(1)
|
||||
image_controlnet = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, cfg_scale=4,
|
||||
controlnet_image=image,
|
||||
)
|
||||
image_controlnet.save("image_depth_2.jpg")
|
||||
|
||||
|
||||
|
||||
def run_kolors_with_lora():
|
||||
download_models(["Kolors", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0"])
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
|
||||
file_path_list=[
|
||||
"models/kolors/Kolors/text_encoder",
|
||||
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
||||
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
|
||||
])
|
||||
model_manager.load_lora("models/lora/zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", lora_alpha=1.5)
|
||||
pipe = SDXLImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、蓝色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女漂浮在水下,面向镜头,周围是光彩的气泡,和煦的阳光透过水面折射进水下"
|
||||
negative_prompt = "半身,苍白的肤色,蜡黄的肤色,尸体,错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,错误的手指,口红,腮红"
|
||||
|
||||
torch.manual_seed(7)
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, cfg_scale=4,
|
||||
)
|
||||
image.save("image_lora.jpg")
|
||||
|
||||
|
||||
|
||||
run_kolors_with_controlnet()
|
||||
run_kolors_with_lora()
|
||||
Reference in New Issue
Block a user