diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index f70f3cb..6bb9350 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -62,6 +62,8 @@ from ..models.wan_video_vae import WanVideoVAE from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_vace import VaceWanModel +from ..models.step1x_connector import Qwen2Connector + model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -136,6 +138,7 @@ model_loader_configs = [ (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"), + (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -151,6 +154,7 @@ huggingface_model_loader_configs = [ ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"), ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"), ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"), + ("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"), ] patch_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/models/qwenvl.py b/diffsynth/models/qwenvl.py new file mode 100644 index 0000000..9677488 --- /dev/null +++ b/diffsynth/models/qwenvl.py @@ -0,0 +1,168 @@ +import torch + + +class Qwen25VL_7b_Embedder(torch.nn.Module): + def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"): + super(Qwen25VL_7b_Embedder, self).__init__() + self.max_length = max_length + self.dtype = dtype + self.device = device + + from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=dtype, + ).to(torch.cuda.current_device()) + + self.model.requires_grad_(False) + self.processor = AutoProcessor.from_pretrained( + model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28 + ) + + Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. +- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. +- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:''' + + self.prefix = Qwen25VL_7b_PREFIX + + @staticmethod + def from_pretrained(path, torch_dtype=torch.bfloat16, device="cuda"): + return Qwen25VL_7b_Embedder(path, dtype=torch_dtype, device=device) + + def forward(self, caption, ref_images): + text_list = caption + embs = torch.zeros( + len(text_list), + self.max_length, + self.model.config.hidden_size, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + ) + hidden_states = torch.zeros( + len(text_list), + self.max_length, + self.model.config.hidden_size, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + ) + masks = torch.zeros( + len(text_list), + self.max_length, + dtype=torch.long, + device=torch.cuda.current_device(), + ) + input_ids_list = [] + attention_mask_list = [] + emb_list = [] + + def split_string(s): + s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes + result = [] + in_quotes = False + temp = "" + + for idx,char in enumerate(s): + if char == '"' and idx>155: + temp += char + if not in_quotes: + result.append(temp) + temp = "" + + in_quotes = not in_quotes + continue + if in_quotes: + if char.isspace(): + pass # have space token + + result.append("“" + char + "”") + else: + temp += char + + if temp: + result.append(temp) + + return result + + for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)): + + messages = [{"role": "user", "content": []}] + + messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"}) + + messages[0]["content"].append({"type": "image", "image": imgs}) + + # 再添加 text + messages[0]["content"].append({"type": "text", "text": f"{txt}"}) + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + + image_inputs = [imgs] + + inputs = self.processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + old_inputs_ids = inputs.input_ids + text_split_list = split_string(text) + + token_list = [] + for text_each in text_split_list: + txt_inputs = self.processor( + text=text_each, + images=None, + videos=None, + padding=True, + return_tensors="pt", + ) + token_each = txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:, 1:-1] + token_list.append(token_each) + else: + token_list.append(token_each) + + new_txt_ids = torch.cat(token_list, dim=1).to("cuda") + + new_txt_ids = new_txt_ids.to(old_inputs_ids.device) + + idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + inputs.input_ids = ( + torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) + .unsqueeze(0) + .to("cuda") + ) + inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda") + outputs = self.model( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + pixel_values=inputs.pixel_values.to("cuda"), + image_grid_thw=inputs.image_grid_thw.to("cuda"), + output_hidden_states=True, + ) + + emb = outputs["hidden_states"][-1] + + embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][ + : self.max_length + ] + + masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( + (min(self.max_length, emb.shape[1] - 217)), + dtype=torch.long, + device=torch.cuda.current_device(), + ) + + return embs, masks \ No newline at end of file diff --git a/diffsynth/models/step1x_connector.py b/diffsynth/models/step1x_connector.py new file mode 100644 index 0000000..b4abe40 --- /dev/null +++ b/diffsynth/models/step1x_connector.py @@ -0,0 +1,683 @@ +from typing import Optional + +import torch, math +import torch.nn +from einops import rearrange +from torch import nn +from functools import partial +from einops import rearrange + + + +def attention(q, k, v, attn_mask, mode="torch"): + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "b n s d -> b s (n d)") + return x + + + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = (bias, bias) + drop_probs = (drop, drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer( + in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype + ) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_channels, device=device, dtype=dtype) + if norm_layer is not None + else nn.Identity() + ) + self.fc2 = linear_layer( + hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class TextProjection(nn.Module): + """ + Projects text embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.linear_1 = nn.Linear( + in_features=in_channels, + out_features=hidden_size, + bias=True, + **factory_kwargs, + ) + self.act_1 = act_layer() + self.linear_2 = nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias=True, + **factory_kwargs, + ) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, hidden_size, bias=True, **factory_kwargs + ), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore + nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding( + t, self.frequency_embedding_size, self.max_period + ).type(self.mlp[0].weight.dtype) # type: ignore + t_emb = self.mlp(t_freq) + return t_emb + + +def apply_gate(x, gate=None, tanh=False): + """AI is creating summary for apply_gate + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +def get_activation_layer(act_type): + """get activation layer + + Args: + act_type (str): the activation type + + Returns: + torch.nn.functional: the activation layer + """ + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") + +class IndividualTokenRefinerBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.need_CA = need_CA + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + self.mlp = MLP( + in_channels=hidden_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop_rate, + **factory_kwargs, + ) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + + if self.need_CA: + self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs,) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor = None, + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + if self.need_CA: + x = self.cross_attnblock(x, c, attn_mask, y) + + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + + return x + + + + +class CrossAttnBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.norm1_2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_q = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + self.self_attn_kv = nn.Linear( + hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor=None, + + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + norm_y = self.norm1_2(y) + q = self.self_attn_q(norm_x) + q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num) + kv = self.self_attn_kv(norm_y) + k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + return x + + + +class IndividualTokenRefiner(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA:bool=False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.need_CA = need_CA + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + need_CA=self.need_CA, + **factory_kwargs, + ) + for _ in range(depth) + ] + ) + + + def forward( + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + y:torch.Tensor=None, + ): + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] + seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( + 1, 1, seq_len, 1 + ) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + self_attn_mask[:, :, :, 0] = True + + + for block in self.blocks: + x = block(x, c, self_attn_mask,y) + + return x + + +class SingleTokenRefiner(torch.nn.Module): + """ + A single token refiner block for llm text embedding refine. + """ + def __init__( + self, + in_channels, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA:bool=False, + attn_mode: str = "torch", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + self.need_CA = need_CA + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." + + self.input_embedder = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + if self.need_CA: + self.input_embedder_CA = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + + act_layer = get_activation_layer(act_type) + # Build timestep embedding layer + self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + # Build context embedding layer + self.c_embedder = TextProjection( + in_channels, hidden_size, act_layer, **factory_kwargs + ) + + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + need_CA=need_CA, + **factory_kwargs, + ) + + def forward( + self, + x: torch.Tensor, + t: torch.LongTensor, + mask: Optional[torch.LongTensor] = None, + y: torch.LongTensor=None, + ): + timestep_aware_representations = self.t_embedder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.unsqueeze(-1) # [b, s1, 1] + context_aware_representations = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + if self.need_CA: + y = self.input_embedder_CA(y) + x = self.individual_token_refiner(x, c, mask, y) + else: + x = self.individual_token_refiner(x, c, mask) + + return x + + +class Qwen2Connector(torch.nn.Module): + def __init__( + self, + # biclip_dim=1024, + in_channels=3584, + hidden_size=4096, + heads_num=32, + depth=2, + need_CA=False, + device=None, + dtype=torch.bfloat16, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype":dtype} + + self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs) + self.global_proj_out=nn.Linear(in_channels,768) + + self.scale_factor = nn.Parameter(torch.zeros(1)) + with torch.no_grad(): + self.scale_factor.data += -(1 - 0.09) + + def forward(self, x,t,mask): + mask_float = mask.unsqueeze(-1) # [b, s1, 1] + x_mean = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) * (1 + self.scale_factor) + + global_out=self.global_proj_out(x_mean) + encoder_hidden_states = self.S(x,t,mask) + return encoder_hidden_states,global_out + + @staticmethod + def state_dict_converter(): + return Qwen2ConnectorStateDictConverter() + + +class Qwen2ConnectorStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith("connector."): + name_ = name[len("connector."):] + state_dict_[name_] = param + return state_dict_ diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index c17e182..0abe26f 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -1,4 +1,5 @@ from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter +from ..models.step1x_connector import Qwen2Connector from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..prompters import FluxPrompter from ..schedulers import FlowMatchScheduler @@ -32,7 +33,9 @@ class FluxImagePipeline(BasePipeline): self.ipadapter: FluxIpAdapter = None self.ipadapter_image_encoder: SiglipVisionModel = None self.infinityou_processor: InfinitYou = None - self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder'] + self.qwenvl = None + self.step1x_connector: Qwen2Connector = None + self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'step1x_connector'] def enable_vram_management(self, num_persistent_param_in_dit=None): @@ -167,6 +170,10 @@ class FluxImagePipeline(BasePipeline): self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector") if self.image_proj_model is not None: self.infinityou_processor = InfinitYou(device=self.device) + + # Step1x + self.qwenvl = model_manager.fetch_model("qwenvl") + self.step1x_connector = model_manager.fetch_model("step1x_connector") @staticmethod @@ -191,10 +198,13 @@ class FluxImagePipeline(BasePipeline): def encode_prompt(self, prompt, positive=True, t5_sequence_length=512): - prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt( - prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length - ) - return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} + if self.text_encoder_1 is not None and self.text_encoder_2 is not None: + prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt( + prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length + ) + return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} + else: + return {} def prepare_extra_input(self, latents=None, guidance=1.0): @@ -388,6 +398,17 @@ class FluxImagePipeline(BasePipeline): else: flex_kwargs = {} return flex_kwargs + + + def prepare_step1x_kwargs(self, prompt, negative_prompt, image): + if image is None: + return {}, {} + captions = [prompt, negative_prompt] + ref_images = [image, image] + embs, masks = self.qwenvl(captions, ref_images) + image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) + image = self.encode_image(image) + return {"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}, {"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image} @torch.no_grad() @@ -432,6 +453,8 @@ class FluxImagePipeline(BasePipeline): flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, + # Step1x + step1x_reference_image=None, # TeaCache tea_cache_l1_thresh=None, # Tile @@ -472,7 +495,10 @@ class FluxImagePipeline(BasePipeline): controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) # Flex - flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs) + flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength=flex_control_strength, flex_control_stop=flex_control_stop, **tiler_kwargs) + + # Step1x + step1x_kwargs_posi, step1x_kwargs_nega = self.prepare_step1x_kwargs(prompt, negative_prompt, image=step1x_reference_image) # TeaCache tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} @@ -484,9 +510,9 @@ class FluxImagePipeline(BasePipeline): # Positive side inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( - dit=self.dit, controlnet=self.controlnet, + dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector, hidden_states=latents, timestep=timestep, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_posi, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, @@ -501,9 +527,9 @@ class FluxImagePipeline(BasePipeline): if cfg_scale != 1.0: # Negative side noise_pred_nega = lets_dance_flux( - dit=self.dit, controlnet=self.controlnet, + dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector, hidden_states=latents, timestep=timestep, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -623,6 +649,7 @@ class TeaCache: def lets_dance_flux( dit: FluxDiT, controlnet: FluxMultiControlNetManager = None, + step1x_connector: Qwen2Connector = None, hidden_states=None, timestep=None, prompt_emb=None, @@ -642,6 +669,9 @@ def lets_dance_flux( flex_condition=None, flex_uncondition=None, flex_control_stop_timestep=None, + step1x_llm_embedding=None, + step1x_mask=None, + step1x_reference_latents=None, tea_cache: TeaCache = None, **kwargs ): @@ -699,6 +729,11 @@ def lets_dance_flux( hidden_states = torch.concat([hidden_states, flex_condition], dim=1) else: hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1) + + # Step1x + if step1x_llm_embedding is not None: + prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask) + text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device) if image_ids is None: image_ids = dit.prepare_image_ids(hidden_states) @@ -710,6 +745,14 @@ def lets_dance_flux( height, width = hidden_states.shape[-2:] hidden_states = dit.patchify(hidden_states) + + # Step1x + if step1x_reference_latents is not None: + step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents) + step1x_reference_latents = dit.patchify(step1x_reference_latents) + image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1) + hidden_states = dit.x_embedder(hidden_states) if entity_prompt_emb is not None and entity_masks is not None: @@ -764,6 +807,11 @@ def lets_dance_flux( hidden_states = dit.final_norm_out(hidden_states, conditioning) hidden_states = dit.final_proj_out(hidden_states) + + # Step1x + if step1x_reference_latents is not None: + hidden_states = hidden_states[:, :hidden_states.shape[1] // 2] + hidden_states = dit.unpatchify(hidden_states, height, width) return hidden_states diff --git a/examples/step1x/step1x.py b/examples/step1x/step1x.py new file mode 100644 index 0000000..6ca974d --- /dev/null +++ b/examples/step1x/step1x.py @@ -0,0 +1,34 @@ +import torch +from diffsynth import FluxImagePipeline, ModelManager +from modelscope import snapshot_download +from PIL import Image +import numpy as np + + +snapshot_download("Qwen/Qwen2.5-VL-7B-Instruct", cache_dir="./models") +snapshot_download("stepfun-ai/Step1X-Edit", cache_dir="./models") + +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "models/Qwen/Qwen2.5-VL-7B-Instruct", + "models/stepfun-ai/Step1X-Edit/step1x-edit-i1258.safetensors", + "models/stepfun-ai/Step1X-Edit/vae.safetensors", +]) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255) +image = pipe( + prompt="draw red flowers in Chinese ink painting style", + step1x_reference_image=image, + width=832, height=1248, cfg_scale=6, + seed=1, +) +image.save("image_1.jpg") + +image = pipe( + prompt="add more flowers in Chinese ink painting style", + step1x_reference_image=image, + width=832, height=1248, cfg_scale=6, + seed=2, +) +image.save("image_2.jpg")