mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
support step1x
This commit is contained in:
@@ -62,6 +62,8 @@ from ..models.wan_video_vae import WanVideoVAE
|
|||||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||||
from ..models.wan_video_vace import VaceWanModel
|
from ..models.wan_video_vace import VaceWanModel
|
||||||
|
|
||||||
|
from ..models.step1x_connector import Qwen2Connector
|
||||||
|
|
||||||
|
|
||||||
model_loader_configs = [
|
model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -136,6 +138,7 @@ model_loader_configs = [
|
|||||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
||||||
|
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
||||||
]
|
]
|
||||||
huggingface_model_loader_configs = [
|
huggingface_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -151,6 +154,7 @@ huggingface_model_loader_configs = [
|
|||||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
||||||
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
|
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
|
||||||
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
||||||
|
("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
|
||||||
]
|
]
|
||||||
patch_model_loader_configs = [
|
patch_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
|
|||||||
168
diffsynth/models/qwenvl.py
Normal file
168
diffsynth/models/qwenvl.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen25VL_7b_Embedder(torch.nn.Module):
    """Prompt embedder backed by a frozen Qwen2.5-VL-7B model.

    Encodes (caption, reference image) pairs with
    ``Qwen2_5_VLForConditionalGeneration`` and returns fixed-length
    last-hidden-state embeddings together with their validity masks.

    NOTE(review): despite the ``device`` argument, the model and output
    buffers are placed on ``torch.cuda.current_device()`` — CUDA is
    required; confirm before running on other devices.
    """

    def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
        super(Qwen25VL_7b_Embedder, self).__init__()
        self.max_length = max_length
        self.dtype = dtype
        self.device = device

        # Imported lazily so that importing this module does not require
        # the transformers package unless the embedder is instantiated.
        from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
        ).to(torch.cuda.current_device())

        # Frozen feature extractor: gradients are never needed.
        self.model.requires_grad_(False)
        self.processor = AutoProcessor.from_pretrained(
            model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
        )

        Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''

        self.prefix = Qwen25VL_7b_PREFIX

    @staticmethod
    def from_pretrained(path, torch_dtype=torch.bfloat16, device="cuda"):
        """Alternate constructor matching the project's loader convention."""
        return Qwen25VL_7b_Embedder(path, dtype=torch_dtype, device=device)

    def forward(self, caption, ref_images):
        """Embed each (caption, image) pair.

        Args:
            caption: list of prompt strings.
            ref_images: list of reference images (one per caption),
                in a format accepted by the Qwen2.5-VL processor.

        Returns:
            (embs, masks): ``embs`` of shape
            (len(caption), max_length, hidden_size) in bfloat16, and
            ``masks`` of shape (len(caption), max_length) with 1 for
            valid positions.
        """
        text_list = caption
        embs = torch.zeros(
            len(text_list),
            self.max_length,
            self.model.config.hidden_size,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
        masks = torch.zeros(
            len(text_list),
            self.max_length,
            dtype=torch.long,
            device=torch.cuda.current_device(),
        )

        def split_string(s):
            # Normalize curly quotes and apostrophes to plain double quotes,
            # then split the chat-templated text so that quoted regions are
            # re-tokenized character by character.
            s = s.replace("“", '"').replace("”", '"').replace("'", '''"''')  # use english quotes
            result = []
            in_quotes = False
            temp = ""

            # NOTE(review): the idx > 155 guard skips quotes inside the fixed
            # prefix (which itself contains double quotes) — confirm the
            # threshold matches the rendered prefix length.
            for idx, char in enumerate(s):
                if char == '"' and idx > 155:
                    temp += char
                    if not in_quotes:
                        result.append(temp)
                        temp = ""

                    in_quotes = not in_quotes
                    continue
                if in_quotes:
                    if char.isspace():
                        pass  # have space token

                    result.append("“" + char + "”")
                else:
                    temp += char

            if temp:
                result.append(temp)

            return result

        for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):

            messages = [{"role": "user", "content": []}]

            messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})

            messages[0]["content"].append({"type": "image", "image": imgs})

            # Then append the user text.
            messages[0]["content"].append({"type": "text", "text": f"{txt}"})

            # Preparation for inference
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
            )

            image_inputs = [imgs]

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt",
            )

            old_inputs_ids = inputs.input_ids
            text_split_list = split_string(text)

            token_list = []
            for text_each in text_split_list:
                txt_inputs = self.processor(
                    text=text_each,
                    images=None,
                    videos=None,
                    padding=True,
                    return_tensors="pt",
                )
                token_each = txt_inputs.input_ids
                # Strip the wrapping quote tokens produced by the character
                # re-quoting above (2073/854 are the “ / ” token ids —
                # TODO confirm against the tokenizer vocabulary).
                if token_each[0][0] == 2073 and token_each[0][-1] == 854:
                    token_each = token_each[:, 1:-1]
                    token_list.append(token_each)
                else:
                    token_list.append(token_each)

            new_txt_ids = torch.cat(token_list, dim=1).to("cuda")

            new_txt_ids = new_txt_ids.to(old_inputs_ids.device)

            # 151653 is the vision-end token id: splice the re-tokenized text
            # after the original image-token span.
            idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
            idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
            inputs.input_ids = (
                torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
                .unsqueeze(0)
                .to("cuda")
            )
            inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
            outputs = self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pixel_values=inputs.pixel_values.to("cuda"),
                image_grid_thw=inputs.image_grid_thw.to("cuda"),
                output_hidden_states=True,
            )

            emb = outputs["hidden_states"][-1]

            # The first 217 positions cover the prefix/image tokens and are
            # dropped; keep at most max_length text positions.
            embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
                : self.max_length
            ]

            masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
                (min(self.max_length, emb.shape[1] - 217)),
                dtype=torch.long,
                device=torch.cuda.current_device(),
            )

        return embs, masks
|
||||||
683
diffsynth/models/step1x_connector.py
Normal file
683
diffsynth/models/step1x_connector.py
Normal file
@@ -0,0 +1,683 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch, math
|
||||||
|
import torch.nn
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
from functools import partial
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def attention(q, k, v, attn_mask, mode="torch"):
    """Scaled dot-product attention over (batch, seq, heads, head_dim) tensors.

    Args:
        q, k, v: tensors of shape (B, S, H, D).
        attn_mask: optional boolean/additive mask broadcastable to
            (B, H, S_q, S_k), or None.
        mode: kept for interface compatibility; only the PyTorch SDPA
            backend is implemented, so the value is ignored.

    Returns:
        Tensor of shape (B, S, H * D) with the heads merged back into
        the channel dimension.
    """
    # SDPA expects (B, H, S, D): move the head axis ahead of the sequence axis.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # Equivalent of einops "b n s d -> b s (n d)", without the dependency.
    b, n, s, d = x.shape
    return x.transpose(1, 2).reshape(b, s, n * d)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Pipeline: fc1 -> activation -> dropout -> (optional) norm -> fc2 -> dropout.
    With ``use_conv=True`` the linear layers become 1x1 convolutions so the
    block can operate on channel-first feature maps.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Hidden and output widths default to the input width.
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = (bias, bias)
        drop_probs = (drop, drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(
            in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
        )
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_channels, device=device, dtype=dtype)
            if norm_layer is not None
            else nn.Identity()
        )
        self.fc2 = linear_layer(
            hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
        )
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class TextProjection(nn.Module):
    """
    Projects text embeddings. Also handles dropout for classifier-free guidance.

    Two-layer MLP: linear -> activation -> linear, mapping ``in_channels``
    to ``hidden_size``.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.linear_1 = nn.Linear(
            in_features=in_channels,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Sinusoidal frequency features followed by a two-layer MLP.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(
                frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
            ),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        nn.init.normal_(self.mlp[0].weight, std=0.02)  # type: ignore
        nn.init.normal_(self.mlp[2].weight, std=0.02)  # type: ignore

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
            dim (int): the dimension of the output.
            max_period (int): controls the minimum frequency of the embeddings.

        Returns:
            embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

        .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output width: pad with a zero column.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        # Cast the frequency features to the MLP's parameter dtype before projecting.
        t_freq = self.timestep_embedding(
            t, self.frequency_embedding_size, self.max_period
        ).type(self.mlp[0].weight.dtype)  # type: ignore
        t_emb = self.mlp(t_freq)
        return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
def apply_gate(x, gate=None, tanh=False):
    """Scale ``x`` by a per-sample gate, broadcast over the sequence axis.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to use tanh function. Defaults to False.

    Returns:
        torch.Tensor: the output tensor after apply gate. Returns ``x``
        unchanged when ``gate`` is None.
    """
    if gate is None:
        return x
    # unsqueeze(1) inserts the sequence axis so the gate broadcasts over it.
    if tanh:
        return x * gate.unsqueeze(1).tanh()
    else:
        return x * gate.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool): Whether to learn a per-channel scale.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter (only present
                when ``elementwise_affine`` is True).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        # Normalize in float32 for stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output
|
||||||
|
|
||||||
|
|
||||||
|
def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer ("layer" or "rms").

    Returns:
        norm_layer (nn.Module): The normalization layer class (not an instance).

    Raises:
        NotImplementedError: for unknown ``norm_layer`` values.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation_layer(act_type):
    """get activation layer

    Args:
        act_type (str): the activation type ("gelu", "gelu_tanh", "relu", "silu")

    Returns:
        torch.nn.functional: a zero-argument callable producing the activation module

    Raises:
        ValueError: for unknown ``act_type`` values.
    """
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        # GELU with the tanh approximation.
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")
|
||||||
|
|
||||||
|
class IndividualTokenRefinerBlock(torch.nn.Module):
    """One transformer block of the token refiner.

    Pre-norm self-attention followed by an MLP, each residual branch gated
    by an adaLN-style modulation computed from the conditioning vector ``c``.
    Optionally inserts a cross-attention block (``need_CA``) between the
    self-attention and the MLP.
    """

    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: str = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )

        if self.need_CA:
            self.cross_attnblock = CrossAttnBlock(
                hidden_size=hidden_size,
                heads_num=heads_num,
                mlp_width_ratio=mlp_width_ratio,
                mlp_drop_rate=mlp_drop_rate,
                act_type=act_type,
                qk_norm=qk_norm,
                qk_norm_type=qk_norm_type,
                qkv_bias=qkv_bias,
                **factory_kwargs,
            )
        # Zero-initialize the modulation so both gates start at zero and the
        # block is an identity mapping at initialization.
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
        y: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        # (B, L, 3*H*D) -> three (B, L, H, D) tensors; same layout as
        # einops' "B L (K H D) -> K B L H D" split, without the dependency.
        B, L, _ = qkv.shape
        q, k, v = qkv.view(B, L, 3, self.heads_num, -1).unbind(2)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        if self.need_CA:
            x = self.cross_attnblock(x, c, attn_mask, y)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttnBlock(torch.nn.Module):
    """Cross-attention block: queries from ``x``, keys/values from ``y``.

    The residual branch is gated by an adaLN-style modulation computed from
    the conditioning vector ``c``; the gate is zero-initialized so the block
    starts as an identity mapping.

    NOTE(review): ``gate_mlp``, ``norm2`` and the attention mask are unused
    here — the mask is accepted but not applied to the cross-attention.
    """

    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: str = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.norm1_2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_q = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )
        self.self_attn_kv = nn.Linear(
            hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
        y: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        norm_y = self.norm1_2(y)
        q = self.self_attn_q(norm_x)
        # (B, Lx, H*D) -> (B, Lx, H, D)
        B, Lx, _ = q.shape
        q = q.view(B, Lx, self.heads_num, -1)
        kv = self.self_attn_kv(norm_y)
        # (B, Ly, 2*H*D) -> two (B, Ly, H, D) tensors.
        Ly = kv.shape[1]
        k, v = kv.view(B, Ly, 2, self.heads_num, -1).unbind(2)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Cross-attention over the reference sequence y.
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualTokenRefiner(torch.nn.Module):
    """Stack of ``depth`` IndividualTokenRefinerBlock layers.

    Builds a pairwise self-attention mask from the per-token validity mask
    and applies each block in sequence.
    """

    def __init__(
        self,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    need_CA=self.need_CA,
                    **factory_kwargs,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
        y: torch.Tensor = None,
    ):
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
                1, 1, seq_len, 1
            )
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num;
            # a (query, key) pair attends only when both tokens are valid.
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # avoids self-attention weight being NaN for padding tokens
            self_attn_mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, self_attn_mask, y)

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class SingleTokenRefiner(torch.nn.Module):
    """
    A single token refiner block for llm text embedding refine.

    Projects the LLM hidden states into ``hidden_size``, builds a
    conditioning vector from the timestep embedding plus a masked mean of
    the input, and runs the stacked IndividualTokenRefiner.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        attn_mode: str = "torch",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        self.need_CA = need_CA
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        self.input_embedder = nn.Linear(
            in_channels, hidden_size, bias=True, **factory_kwargs
        )
        if self.need_CA:
            # Separate projection for the cross-attention reference stream.
            self.input_embedder_CA = nn.Linear(
                in_channels, hidden_size, bias=True, **factory_kwargs
            )

        act_layer = get_activation_layer(act_type)
        # Build timestep embedding layer
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
        # Build context embedding layer
        self.c_embedder = TextProjection(
            in_channels, hidden_size, act_layer, **factory_kwargs
        )

        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            need_CA=need_CA,
            **factory_kwargs,
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
        y: torch.LongTensor = None,
    ):
        timestep_aware_representations = self.t_embedder(t)

        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            # Mean over valid (mask==1) positions only.
            mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
            context_aware_representations = (x * mask_float).sum(
                dim=1
            ) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)
        if self.need_CA:
            y = self.input_embedder_CA(y)
            x = self.individual_token_refiner(x, c, mask, y)
        else:
            x = self.individual_token_refiner(x, c, mask)

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2Connector(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# biclip_dim=1024,
|
||||||
|
in_channels=3584,
|
||||||
|
hidden_size=4096,
|
||||||
|
heads_num=32,
|
||||||
|
depth=2,
|
||||||
|
need_CA=False,
|
||||||
|
device=None,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
factory_kwargs = {"device": device, "dtype":dtype}
|
||||||
|
|
||||||
|
self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
|
||||||
|
self.global_proj_out=nn.Linear(in_channels,768)
|
||||||
|
|
||||||
|
self.scale_factor = nn.Parameter(torch.zeros(1))
|
||||||
|
with torch.no_grad():
|
||||||
|
self.scale_factor.data += -(1 - 0.09)
|
||||||
|
|
||||||
|
def forward(self, x,t,mask):
|
||||||
|
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
||||||
|
x_mean = (x * mask_float).sum(
|
||||||
|
dim=1
|
||||||
|
) / mask_float.sum(dim=1) * (1 + self.scale_factor)
|
||||||
|
|
||||||
|
global_out=self.global_proj_out(x_mean)
|
||||||
|
encoder_hidden_states = self.S(x,t,mask)
|
||||||
|
return encoder_hidden_states,global_out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return Qwen2ConnectorStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2ConnectorStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith("connector."):
|
||||||
|
name_ = name[len("connector."):]
|
||||||
|
state_dict_[name_] = param
|
||||||
|
return state_dict_
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
||||||
|
from ..models.step1x_connector import Qwen2Connector
|
||||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||||
from ..prompters import FluxPrompter
|
from ..prompters import FluxPrompter
|
||||||
from ..schedulers import FlowMatchScheduler
|
from ..schedulers import FlowMatchScheduler
|
||||||
@@ -32,7 +33,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.ipadapter: FluxIpAdapter = None
|
self.ipadapter: FluxIpAdapter = None
|
||||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||||
self.infinityou_processor: InfinitYou = None
|
self.infinityou_processor: InfinitYou = None
|
||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
self.qwenvl = None
|
||||||
|
self.step1x_connector: Qwen2Connector = None
|
||||||
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'step1x_connector']
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
@@ -167,6 +170,10 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
||||||
if self.image_proj_model is not None:
|
if self.image_proj_model is not None:
|
||||||
self.infinityou_processor = InfinitYou(device=self.device)
|
self.infinityou_processor = InfinitYou(device=self.device)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
self.qwenvl = model_manager.fetch_model("qwenvl")
|
||||||
|
self.step1x_connector = model_manager.fetch_model("step1x_connector")
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -191,10 +198,13 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
|
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
|
||||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
if self.text_encoder_1 is not None and self.text_encoder_2 is not None:
|
||||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||||
)
|
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
)
|
||||||
|
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None, guidance=1.0):
|
def prepare_extra_input(self, latents=None, guidance=1.0):
|
||||||
@@ -388,6 +398,17 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
else:
|
else:
|
||||||
flex_kwargs = {}
|
flex_kwargs = {}
|
||||||
return flex_kwargs
|
return flex_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
|
||||||
|
if image is None:
|
||||||
|
return {}, {}
|
||||||
|
captions = [prompt, negative_prompt]
|
||||||
|
ref_images = [image, image]
|
||||||
|
embs, masks = self.qwenvl(captions, ref_images)
|
||||||
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
image = self.encode_image(image)
|
||||||
|
return {"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}, {"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -432,6 +453,8 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
flex_control_image=None,
|
flex_control_image=None,
|
||||||
flex_control_strength=0.5,
|
flex_control_strength=0.5,
|
||||||
flex_control_stop=0.5,
|
flex_control_stop=0.5,
|
||||||
|
# Step1x
|
||||||
|
step1x_reference_image=None,
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_l1_thresh=None,
|
tea_cache_l1_thresh=None,
|
||||||
# Tile
|
# Tile
|
||||||
@@ -472,7 +495,10 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||||
|
|
||||||
# Flex
|
# Flex
|
||||||
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs)
|
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength=flex_control_strength, flex_control_stop=flex_control_stop, **tiler_kwargs)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
step1x_kwargs_posi, step1x_kwargs_nega = self.prepare_step1x_kwargs(prompt, negative_prompt, image=step1x_reference_image)
|
||||||
|
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||||
@@ -484,9 +510,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
|
|
||||||
# Positive side
|
# Positive side
|
||||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs,
|
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_posi,
|
||||||
)
|
)
|
||||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||||
@@ -501,9 +527,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
# Negative side
|
# Negative side
|
||||||
noise_pred_nega = lets_dance_flux(
|
noise_pred_nega = lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs,
|
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_nega,
|
||||||
)
|
)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
@@ -623,6 +649,7 @@ class TeaCache:
|
|||||||
def lets_dance_flux(
|
def lets_dance_flux(
|
||||||
dit: FluxDiT,
|
dit: FluxDiT,
|
||||||
controlnet: FluxMultiControlNetManager = None,
|
controlnet: FluxMultiControlNetManager = None,
|
||||||
|
step1x_connector: Qwen2Connector = None,
|
||||||
hidden_states=None,
|
hidden_states=None,
|
||||||
timestep=None,
|
timestep=None,
|
||||||
prompt_emb=None,
|
prompt_emb=None,
|
||||||
@@ -642,6 +669,9 @@ def lets_dance_flux(
|
|||||||
flex_condition=None,
|
flex_condition=None,
|
||||||
flex_uncondition=None,
|
flex_uncondition=None,
|
||||||
flex_control_stop_timestep=None,
|
flex_control_stop_timestep=None,
|
||||||
|
step1x_llm_embedding=None,
|
||||||
|
step1x_mask=None,
|
||||||
|
step1x_reference_latents=None,
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -699,6 +729,11 @@ def lets_dance_flux(
|
|||||||
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||||
else:
|
else:
|
||||||
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
|
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
if step1x_llm_embedding is not None:
|
||||||
|
prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)
|
||||||
|
text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)
|
||||||
|
|
||||||
if image_ids is None:
|
if image_ids is None:
|
||||||
image_ids = dit.prepare_image_ids(hidden_states)
|
image_ids = dit.prepare_image_ids(hidden_states)
|
||||||
@@ -710,6 +745,14 @@ def lets_dance_flux(
|
|||||||
|
|
||||||
height, width = hidden_states.shape[-2:]
|
height, width = hidden_states.shape[-2:]
|
||||||
hidden_states = dit.patchify(hidden_states)
|
hidden_states = dit.patchify(hidden_states)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
if step1x_reference_latents is not None:
|
||||||
|
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
|
||||||
|
step1x_reference_latents = dit.patchify(step1x_reference_latents)
|
||||||
|
image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)
|
||||||
|
hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)
|
||||||
|
|
||||||
hidden_states = dit.x_embedder(hidden_states)
|
hidden_states = dit.x_embedder(hidden_states)
|
||||||
|
|
||||||
if entity_prompt_emb is not None and entity_masks is not None:
|
if entity_prompt_emb is not None and entity_masks is not None:
|
||||||
@@ -764,6 +807,11 @@ def lets_dance_flux(
|
|||||||
|
|
||||||
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
||||||
hidden_states = dit.final_proj_out(hidden_states)
|
hidden_states = dit.final_proj_out(hidden_states)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
if step1x_reference_latents is not None:
|
||||||
|
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
|
||||||
|
|
||||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
34
examples/step1x/step1x.py
Normal file
34
examples/step1x/step1x.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import FluxImagePipeline, ModelManager
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
snapshot_download("Qwen/Qwen2.5-VL-7B-Instruct", cache_dir="./models")
|
||||||
|
snapshot_download("stepfun-ai/Step1X-Edit", cache_dir="./models")
|
||||||
|
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
|
||||||
|
model_manager.load_models([
|
||||||
|
"models/Qwen/Qwen2.5-VL-7B-Instruct",
|
||||||
|
"models/stepfun-ai/Step1X-Edit/step1x-edit-i1258.safetensors",
|
||||||
|
"models/stepfun-ai/Step1X-Edit/vae.safetensors",
|
||||||
|
])
|
||||||
|
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
|
||||||
|
image = pipe(
|
||||||
|
prompt="draw red flowers in Chinese ink painting style",
|
||||||
|
step1x_reference_image=image,
|
||||||
|
width=832, height=1248, cfg_scale=6,
|
||||||
|
seed=1,
|
||||||
|
)
|
||||||
|
image.save("image_1.jpg")
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="add more flowers in Chinese ink painting style",
|
||||||
|
step1x_reference_image=image,
|
||||||
|
width=832, height=1248, cfg_scale=6,
|
||||||
|
seed=2,
|
||||||
|
)
|
||||||
|
image.save("image_2.jpg")
|
||||||
Reference in New Issue
Block a user