mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
nexus-gen
This commit is contained in:
@@ -202,10 +202,10 @@ class FluxImagePipeline(BasePipeline):
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
|
||||
if self.text_encoder_1 is not None and self.text_encoder_2 is not None:
|
||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, image_emb=None):
|
||||
if (self.text_encoder_1 is not None and self.text_encoder_2 is not None) or (image_emb is not None):
|
||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, image_emb=image_emb
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||
else:
|
||||
@@ -358,13 +358,13 @@ class FluxImagePipeline(BasePipeline):
|
||||
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
|
||||
|
||||
|
||||
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
|
||||
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb=None):
|
||||
# Extend prompt
|
||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
|
||||
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
|
||||
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length, image_emb=image_emb)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||
@@ -432,6 +432,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=None,
|
||||
image_emb=None,
|
||||
# Steps
|
||||
num_inference_steps=30,
|
||||
# local prompts
|
||||
@@ -483,7 +484,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
|
||||
|
||||
# Prompt
|
||||
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
|
||||
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb)
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
@@ -59,6 +59,7 @@ class FluxPrompter(BasePrompter):
|
||||
positive=True,
|
||||
device="cuda",
|
||||
t5_sequence_length=512,
|
||||
image_emb=None,
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
@@ -66,7 +67,10 @@ class FluxPrompter(BasePrompter):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
|
||||
|
||||
# T5
|
||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
|
||||
if image_emb is not None:
|
||||
prompt_emb = image_emb
|
||||
else:
|
||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
|
||||
|
||||
# text_ids
|
||||
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
|
||||
|
||||
0
modeling/ar/__init__.py
Normal file
0
modeling/ar/__init__.py
Normal file
258
modeling/ar/configuration_qwen2_5_vl.py
Normal file
258
modeling/ar/configuration_qwen2_5_vl.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_qwen2_5_vl.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen2_5_vl"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=32,
|
||||
hidden_size=3584,
|
||||
hidden_act="silu",
|
||||
intermediate_size=3420,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=14,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
tokens_per_second=4,
|
||||
window_size=112,
|
||||
out_hidden_size=3584,
|
||||
fullatt_block_indexes=[7, 15, 23, 31],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.tokens_per_second = tokens_per_second
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = fullatt_block_indexes
|
||||
self.out_hidden_size = out_hidden_size
|
||||
|
||||
|
||||
class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
|
||||
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 152064):
|
||||
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen2_5_VLModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 8192):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 29568):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 80):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 64):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 80):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
vision_config (`Dict`, *optional*):
|
||||
The config for the visual encoder initialization.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
|
||||
|
||||
>>> # Initializing a Qwen2_5_VL style configuration
|
||||
>>> configuration = Qwen2_5_VLConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen2-VL-7B style configuration
|
||||
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_5_vl"
|
||||
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `Qwen2_5_VL`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=152064,
|
||||
hidden_size=8192,
|
||||
intermediate_size=29568,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
attention_dropout=0.0,
|
||||
vision_config=None,
|
||||
rope_scaling=None,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = self.sub_configs["vision_config"]()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
|
||||
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
||||
# TODO: @raushan update config in the hub
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
if self.rope_scaling["type"] == "mrope":
|
||||
self.rope_scaling["type"] = "default"
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self, ignore_keys={"mrope_section"})
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["Qwen2_5_VLConfig"]
|
||||
2354
modeling/ar/modeling_qwen2_5_vl.py
Normal file
2354
modeling/ar/modeling_qwen2_5_vl.py
Normal file
File diff suppressed because it is too large
Load Diff
235
modeling/ar/processing_qwen2_5_vl.py
Normal file
235
modeling/ar/processing_qwen2_5_vl.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput, VideoInput
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
|
||||
|
||||
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
|
||||
fps: Union[List[float], float]
|
||||
|
||||
|
||||
class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"videos_kwargs": {"fps": 2.0},
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5_VLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
|
||||
[`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
|
||||
[`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
|
||||
Args:
|
||||
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
|
||||
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
|
||||
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
|
||||
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
|
||||
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Qwen2_5_VLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grid_thw = None
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
|
||||
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
|
||||
if isinstance(fps, (int, float)):
|
||||
second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)
|
||||
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
|
||||
second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
|
||||
)
|
||||
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
|
||||
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grid_thw = None
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
if image_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while self.image_token in text[i]:
|
||||
text[i] = text[i].replace(
|
||||
self.image_token,
|
||||
"<|placeholder|>" * (image_grid_thw[index].prod() // merge_length),
|
||||
1,
|
||||
)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
||||
|
||||
if video_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while self.video_token in text[i]:
|
||||
text[i] = text[i].replace(
|
||||
self.video_token,
|
||||
"<|placeholder|>" * (video_grid_thw[index].prod() // merge_length),
|
||||
1,
|
||||
)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
||||
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def batch_decode_all2all(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
decoded = self.tokenizer.batch_decode(*args, **kwargs)
|
||||
pattern = r'<\|vision_start\|>.*?<\|vision_end\|>'
|
||||
decoded_with_image_tag = [re.sub(pattern, '<image>', d, flags=re.DOTALL) for d in decoded]
|
||||
decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag]
|
||||
return decoded_with_image_tag
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(
|
||||
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
return names_from_processor + ["second_per_grid_ts"]
|
||||
|
||||
|
||||
__all__ = ["Qwen2_5_VLProcessor"]
|
||||
64
modeling/decoder/flux_decoder.py
Normal file
64
modeling/decoder/flux_decoder.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager
|
||||
from .flux_image_pipeline import FluxImagePipelineAll2All
|
||||
|
||||
class FluxDecoder:
|
||||
|
||||
def __init__(self, flux_all2all_modelpath, flux_path, device='cuda', torch_dtype=torch.bfloat16):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.pipe, self.adapter = self.get_pipe(flux_all2all_modelpath, flux_path, device, torch_dtype)
|
||||
|
||||
def get_pipe(self, flux_all2all_modelpath, flux_path, device="cuda", torch_dtype=torch.bfloat16):
|
||||
model_manager = ModelManager(torch_dtype=torch_dtype, device=device)
|
||||
model_manager.load_models([
|
||||
f"{flux_path}/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
f"{flux_path}/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
f"{flux_path}/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
f"{flux_path}/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||
])
|
||||
|
||||
state_dict = torch.load(flux_all2all_modelpath, weights_only=True, map_location='cpu')
|
||||
adapter_states = ['0.weight', '0.bias', '1.weight', '1.bias', '3.weight', '3.bias', '4.weight', '4.bias']
|
||||
adapter_state_dict = {}
|
||||
for key in adapter_states:
|
||||
adapter_state_dict[key] = state_dict.pop(key)
|
||||
|
||||
in_channel = 3584
|
||||
out_channel = 4096
|
||||
expand_ratio = 1
|
||||
adapter = torch.nn.Sequential(torch.nn.Linear(in_channel, out_channel * expand_ratio),
|
||||
torch.nn.LayerNorm(out_channel * expand_ratio), torch.nn.ReLU(),
|
||||
torch.nn.Linear(out_channel * expand_ratio, out_channel),
|
||||
torch.nn.LayerNorm(out_channel))
|
||||
adapter.load_state_dict(adapter_state_dict)
|
||||
adapter.to(device, dtype=torch_dtype)
|
||||
|
||||
pipe = FluxImagePipelineAll2All.from_model_manager(model_manager)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
return pipe, adapter
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_image_embeds(self,
|
||||
output_image_embeddings,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=50,
|
||||
seed=42,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
**pipe_kwargs):
|
||||
output_image_embeddings = output_image_embeddings.to(device=self.device, dtype=self.torch_dtype)
|
||||
image_embed = self.adapter(output_image_embeddings)
|
||||
image = self.pipe(prompt="",
|
||||
image_embed=image_embed,
|
||||
num_inference_steps=num_inference_steps,
|
||||
embedded_guidance=3.5,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=cfg_scale,
|
||||
height=height,
|
||||
width=width,
|
||||
seed=seed,
|
||||
**pipe_kwargs)
|
||||
return image
|
||||
192
modeling/decoder/flux_image_pipeline.py
Normal file
192
modeling/decoder/flux_image_pipeline.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.controlnets import ControlNetConfigUnit
|
||||
from diffsynth.prompters.flux_prompter import FluxPrompter
|
||||
from diffsynth.pipelines.flux_image import FluxImagePipeline, lets_dance_flux, TeaCache
|
||||
|
||||
|
||||
class FluxPrompterAll2All(FluxPrompter):
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda",
|
||||
t5_sequence_length=512,
|
||||
clip_only=False
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
# CLIP
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
|
||||
if clip_only:
|
||||
return None, pooled_prompt_emb, None
|
||||
# T5
|
||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
|
||||
# text_ids
|
||||
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
|
||||
return prompt_emb, pooled_prompt_emb, text_ids
|
||||
|
||||
|
||||
class FluxImagePipelineAll2All(FluxImagePipeline):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.prompter = FluxPrompterAll2All()
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, clip_only=False):
|
||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, clip_only=clip_only
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||
pipe = FluxImagePipelineAll2All(
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def prepare_prompts(self, prompt, image_embed, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
|
||||
# Extend prompt
|
||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
|
||||
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
|
||||
|
||||
# Encode prompts
|
||||
if image_embed is not None:
|
||||
image_embed = image_embed.to(self.torch_dtype)
|
||||
prompt_emb_posi = self.encode_prompt("", positive=True, clip_only=True)
|
||||
if len(image_embed.size()) == 2:
|
||||
image_embed = image_embed.unsqueeze(0)
|
||||
prompt_emb_posi['prompt_emb'] = image_embed
|
||||
prompt_emb_posi['text_ids'] = torch.zeros(image_embed.shape[0], image_embed.shape[1], 3).to(device=self.device, dtype=self.torch_dtype)
|
||||
else:
|
||||
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
embedded_guidance=3.5,
|
||||
t5_sequence_length=512,
|
||||
# Image
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=None,
|
||||
# image_embed
|
||||
image_embed=None,
|
||||
# Steps
|
||||
num_inference_steps=30,
|
||||
# local prompts
|
||||
local_prompts=(),
|
||||
masks=(),
|
||||
mask_scales=(),
|
||||
# ControlNet
|
||||
controlnet_image=None,
|
||||
controlnet_inpaint_mask=None,
|
||||
enable_controlnet_on_negative=False,
|
||||
# IP-Adapter
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
# EliGen
|
||||
eligen_entity_prompts=None,
|
||||
eligen_entity_masks=None,
|
||||
enable_eligen_on_negative=False,
|
||||
enable_eligen_inpaint=False,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
# Progress bar
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
|
||||
|
||||
# Prompt
|
||||
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, image_embed, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Entity control
|
||||
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
||||
|
||||
# IP-Adapter
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
|
||||
|
||||
# ControlNets
|
||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['dit', 'controlnet'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Positive side
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
|
||||
)
|
||||
|
||||
# Inpaint
|
||||
if enable_eligen_inpaint:
|
||||
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
|
||||
|
||||
# Classifier-free guidance
|
||||
if cfg_scale != 1.0:
|
||||
# Negative side
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Iterate
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.decode_image(latents, **tiler_kwargs)
|
||||
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
85
test.py
Normal file
85
test.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
import torch
|
||||
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict
|
||||
from qwen_vl_utils import smart_resize
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class NexusGenQwenVLEncoder(torch.nn.Module):
|
||||
def __init__(self, model_path, torch_dtype="auto", device="cpu"):
|
||||
super().__init__()
|
||||
model_config = AutoConfig.from_pretrained(model_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
|
||||
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
|
||||
return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()
|
||||
|
||||
def process_images(self, images=None):
|
||||
if images is None:
|
||||
return None
|
||||
# resize input to max_pixels to avoid oom
|
||||
for j in range(len(images)):
|
||||
input_image = images[j]
|
||||
input_w, input_h = input_image.size
|
||||
resized_height, resized_width = smart_resize(
|
||||
input_h,
|
||||
input_w,
|
||||
max_pixels=262640,
|
||||
)
|
||||
images[j] = input_image.resize((resized_width, resized_height))
|
||||
return images
|
||||
|
||||
def forward(self, prompt, images=None):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},],
|
||||
}]
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=self.process_images(images),
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(self.model.device)
|
||||
generation_image_grid_thw = torch.tensor([[1, 18, 18]]).to(self.model.device)
|
||||
outputs = self.model.generate(**inputs,
|
||||
max_new_tokens=1024,
|
||||
return_dict_in_generate=True,
|
||||
generation_image_grid_thw=generation_image_grid_thw)
|
||||
output_image_embeddings = outputs['output_image_embeddings']
|
||||
return output_image_embeddings
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
pipe.dit.load_state_dict(load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16), strict=False)
|
||||
|
||||
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
|
||||
|
||||
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
|
||||
adapter.load_state_dict(load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16), strict=False)
|
||||
|
||||
with torch.no_grad():
|
||||
instruction = "<|vision_start|><|image_pad|><|vision_end|> Transform the style to flat anime. Keep the color."
|
||||
emb = qwenvl(instruction, images=[Image.open("image_3.jpg").convert('RGB')])
|
||||
emb = adapter(emb)
|
||||
image = pipe("", image_emb=emb)
|
||||
image.save("image_4.jpg")
|
||||
Reference in New Issue
Block a user