Mirror of https://github.com/modelscope/DiffSynth-Studio.git
Synced 2026-03-19 14:58:12 +00:00

Compare commits: vram-bugfi...nexus-gen (7 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 2b72ae0e56 | |
| | 90588dcf97 | |
| | 91fbb24e17 | |
| | f17558a4c4 | |
| | 290ec469ca | |
| | 1ed676b076 | |
| | f7737aff98 | |
README.md

@@ -42,6 +42,12 @@ Until now, DiffSynth-Studio has supported the following models:
 * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
 
 ## News
 
+- **May 1, 2025** 🔥🔥🔥 We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models.
+    - Paper: [Nexus-Gen: A Unified Model for Image Understanding, Generation, and Editing](https://arxiv.org/pdf/2504.21356)
+    - Github Repo: https://github.com/modelscope/Nexus-Gen
+    - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-Gen), [HuggingFace](https://huggingface.co/modelscope/Nexus-Gen)
+    - Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
+
 - **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
 
 - **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
diffsynth/pipelines/flux_image.py

@@ -202,10 +202,10 @@ class FluxImagePipeline(BasePipeline):
         return image
 
 
-    def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
-        if self.text_encoder_1 is not None and self.text_encoder_2 is not None:
+    def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, image_emb=None):
+        if (self.text_encoder_1 is not None and self.text_encoder_2 is not None) or (image_emb is not None):
             prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
-                prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
+                prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, image_emb=image_emb
             )
             return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
         else:
@@ -358,13 +358,13 @@ class FluxImagePipeline(BasePipeline):
         return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
 
 
-    def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
+    def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb=None):
         # Extend prompt
         self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
         prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
 
         # Encode prompts
-        prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
+        prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length, image_emb=image_emb)
         prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
         prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
         return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@@ -432,6 +432,7 @@ class FluxImagePipeline(BasePipeline):
         height=1024,
         width=1024,
         seed=None,
+        image_emb=None,
         # Steps
         num_inference_steps=30,
         # local prompts
@@ -483,7 +484,7 @@ class FluxImagePipeline(BasePipeline):
         latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
 
         # Prompt
-        prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
+        prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb)
 
         # Extra input
         extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
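End to end, the threading above lets a caller hand the pipeline a precomputed image embedding in place of a T5 text embedding. A condensed usage sketch, mirroring the test.py script later in this diff (the model files are the ones test.py loads; a random tensor of the assumed shape, 81 tokens in FLUX's 4096-wide T5 space, stands in for a real Nexus-Gen embedding):

```python
import torch
from diffsynth import ModelManager, FluxImagePipeline

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
])
pipe = FluxImagePipeline.from_model_manager(model_manager)

# In real use, image_emb comes from the Nexus-Gen encoder + adapter (see test.py below).
# Shape (1, 81, 4096) is an assumption taken from the adapter defined in this diff.
image_emb = torch.randn(1, 81, 4096, dtype=torch.bfloat16, device="cuda")
image = pipe("", image_emb=image_emb)
image.save("from_image_emb.jpg")
```

Note that test.py deliberately does not load FLUX's T5 (`text_encoder_2`); the widened guard in `encode_prompt` is what allows this call to succeed anyway, because `image_emb` substitutes for the missing T5 output.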
@@ -679,6 +680,7 @@ def lets_dance_flux(
     step1x_mask=None,
     step1x_reference_latents=None,
     tea_cache: TeaCache = None,
+    use_gradient_checkpointing=False,
     **kwargs
 ):
     if tiled:
@@ -773,20 +775,32 @@ def lets_dance_flux(
         tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
     else:
         tea_cache_update = False
 
+    def create_custom_forward(module):
+        def custom_forward(*inputs):
+            return module(*inputs)
+        return custom_forward
+
     if tea_cache_update:
         hidden_states = tea_cache.update(hidden_states)
     else:
         # Joint Blocks
         for block_id, block in enumerate(dit.blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -795,14 +809,21 @@ def lets_dance_flux(
 
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+        num_joint_blocks = len(dit.blocks)
         for block_id, block in enumerate(dit.single_blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+                )
         # ControlNet
         if controlnet is not None and controlnet_frames is not None:
             hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
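For context on the pattern used in both hunks above: `torch.utils.checkpoint.checkpoint()` discards intermediate activations during the forward pass and re-runs the wrapped callable during backward, trading compute for memory. The `create_custom_forward` closure lets every argument (including the per-block IP-Adapter kwargs slot) pass through positionally, and `use_reentrant=False` selects the non-reentrant implementation, which handles inputs that do not require gradients. A minimal self-contained sketch with a stand-in block:

```python
import torch

# Stand-in for a DiT block; any nn.Module with a plain forward works.
block = torch.nn.Linear(16, 16)

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

x = torch.randn(2, 16, requires_grad=True)
# Activations inside `block` are not stored; they are recomputed during backward().
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
y.sum().backward()
```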
|
||||
@@ -59,6 +59,7 @@ class FluxPrompter(BasePrompter):
|
||||
positive=True,
|
||||
device="cuda",
|
||||
t5_sequence_length=512,
|
||||
image_emb=None,
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
@@ -66,7 +67,10 @@ class FluxPrompter(BasePrompter):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
|
||||
|
||||
# T5
|
||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
|
||||
if image_emb is not None:
|
||||
prompt_emb = image_emb
|
||||
else:
|
||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
|
||||
|
||||
# text_ids
|
||||
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
|
||||
|
||||
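Because `image_emb` is substituted one-for-one for the T5 embedding, the downstream `text_ids` tensor is automatically sized from it. A tiny sanity check of that construction (shapes are hypothetical: batch 1, 81 image tokens, dimension 4096):

```python
import torch

image_emb = torch.randn(1, 81, 4096)
prompt_emb = image_emb  # the branch added in this hunk
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(dtype=prompt_emb.dtype)
assert text_ids.shape == (1, 81, 3)
```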
modeling/ar/__init__.py (new file, empty)

modeling/ar/configuration_qwen2_5_vl.py (new file, 258 lines)

@@ -0,0 +1,258 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_qwen2_5_vl.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class Qwen2_5_VLVisionConfig(PretrainedConfig):
    model_type = "qwen2_5_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=32,
        hidden_size=3584,
        hidden_act="silu",
        intermediate_size=3420,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        temporal_patch_size=2,
        tokens_per_second=4,
        window_size=112,
        out_hidden_size=3584,
        fullatt_block_indexes=[7, 15, 23, 31],
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.tokens_per_second = tokens_per_second
        self.window_size = window_size
        self.fullatt_block_indexes = fullatt_block_indexes
        self.out_hidden_size = out_hidden_size


class Qwen2_5_VLConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 152064):
            Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2_5_VLModel`]
        hidden_size (`int`, *optional*, defaults to 8192):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 29568):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 80):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 80):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        vision_config (`Dict`, *optional*):
            The config for the visual encoder initialization.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE

    ```python
    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig

    >>> # Initializing a Qwen2_5_VL style configuration
    >>> configuration = Qwen2_5_VLConfig()

    >>> # Initializing a model from the Qwen2-VL-7B style configuration
    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_5_vl"
    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen2_5_VL`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
        # TODO: @raushan update config in the hub
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self, ignore_keys={"mrope_section"})

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


__all__ = ["Qwen2_5_VLConfig"]
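A small sketch of the backward-compatibility branch at the end of `__init__`: a legacy `rope_scaling` dict carrying `"type": "mrope"` is rewritten to `rope_type: "default"` before `rope_config_validation` runs (the `mrope_section` value below is illustrative, and the import assumes this repo's layout):

```python
from modeling.ar.configuration_qwen2_5_vl import Qwen2_5_VLConfig

config = Qwen2_5_VLConfig(
    num_hidden_layers=2,  # tiny, for illustration only
    rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]},
)
# rope_scaling now carries rope_type="default"; mrope_section is kept but
# excluded from validation via ignore_keys={"mrope_section"}.
print(config.rope_scaling)
```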
modeling/ar/modeling_qwen2_5_vl.py (new file, 2371 lines)

(File diff suppressed because it is too large.)

modeling/ar/processing_qwen2_5_vl.py (new file, 235 lines)

@@ -0,0 +1,235 @@
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List, Union

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput


class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
    fps: Union[List[float], float]


class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "videos_kwargs": {"fps": 2.0},
    }


class Qwen2_5_VLProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]

            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    text[i] = text[i].replace(
                        self.image_token,
                        "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    text[i] = text[i].replace(
                        self.video_token,
                        "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def batch_decode_all2all(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`], then
        replaces any <|vision_start|>...<|vision_end|> span with an <image> tag and strips <|im_end|> markers.
        """
        decoded = self.tokenizer.batch_decode(*args, **kwargs)
        pattern = r'<\|vision_start\|>.*?<\|vision_end\|>'
        decoded_with_image_tag = [re.sub(pattern, '<image>', d, flags=re.DOTALL) for d in decoded]
        decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag]
        return decoded_with_image_tag

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
            **kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode` method.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
        return names_from_processor + ["second_per_grid_ts"]


__all__ = ["Qwen2_5_VLProcessor"]
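The placeholder arithmetic in `__call__` above, in isolation: Nexus-Gen's 252x252 target image (see `test.py` later in this diff) becomes a `(t, h, w) = (1, 18, 18)` patch grid (252 divided by the vision config's `patch_size` of 14 is 18), and with `merge_size` 2 it expands the single `<|image_pad|>` token into `prod // merge_size**2 = 324 // 4 = 81` pad tokens, matching `num_img_tokens=81` in test.py:

```python
import torch

image_grid_thw = torch.tensor([1, 18, 18])  # grid for a 252x252 image, patch_size 14
merge_size = 2
num_pad_tokens = int(image_grid_thw.prod()) // merge_size**2
assert num_pad_tokens == 81
```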
modeling/decoder/flux_decoder.py (new file, 64 lines)

@@ -0,0 +1,64 @@
import torch
from diffsynth import ModelManager
from .flux_image_pipeline import FluxImagePipelineAll2All


class FluxDecoder:

    def __init__(self, flux_all2all_modelpath, flux_path, device='cuda', torch_dtype=torch.bfloat16):
        self.device = device
        self.torch_dtype = torch_dtype
        self.pipe, self.adapter = self.get_pipe(flux_all2all_modelpath, flux_path, device, torch_dtype)

    def get_pipe(self, flux_all2all_modelpath, flux_path, device="cuda", torch_dtype=torch.bfloat16):
        model_manager = ModelManager(torch_dtype=torch_dtype, device=device)
        model_manager.load_models([
            f"{flux_path}/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            f"{flux_path}/FLUX/FLUX.1-dev/text_encoder_2",
            f"{flux_path}/FLUX/FLUX.1-dev/ae.safetensors",
            f"{flux_path}/FLUX/FLUX.1-dev/flux1-dev.safetensors"
        ])

        state_dict = torch.load(flux_all2all_modelpath, weights_only=True, map_location='cpu')
        adapter_states = ['0.weight', '0.bias', '1.weight', '1.bias', '3.weight', '3.bias', '4.weight', '4.bias']
        adapter_state_dict = {}
        for key in adapter_states:
            adapter_state_dict[key] = state_dict.pop(key)

        in_channel = 3584
        out_channel = 4096
        expand_ratio = 1
        adapter = torch.nn.Sequential(torch.nn.Linear(in_channel, out_channel * expand_ratio),
                                      torch.nn.LayerNorm(out_channel * expand_ratio), torch.nn.ReLU(),
                                      torch.nn.Linear(out_channel * expand_ratio, out_channel),
                                      torch.nn.LayerNorm(out_channel))
        adapter.load_state_dict(adapter_state_dict)
        adapter.to(device, dtype=torch_dtype)

        pipe = FluxImagePipelineAll2All.from_model_manager(model_manager)
        pipe.dit.load_state_dict(state_dict)

        return pipe, adapter

    @torch.no_grad()
    def decode_image_embeds(self,
                            output_image_embeddings,
                            height=512,
                            width=512,
                            num_inference_steps=50,
                            seed=42,
                            negative_prompt="",
                            cfg_scale=1.0,
                            **pipe_kwargs):
        output_image_embeddings = output_image_embeddings.to(device=self.device, dtype=self.torch_dtype)
        image_embed = self.adapter(output_image_embeddings)
        image = self.pipe(prompt="",
                          image_embed=image_embed,
                          num_inference_steps=num_inference_steps,
                          embedded_guidance=3.5,
                          negative_prompt=negative_prompt,
                          cfg_scale=cfg_scale,
                          height=height,
                          width=width,
                          seed=seed,
                          **pipe_kwargs)
        return image
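A hypothetical usage sketch of this class; the checkpoint path below is a placeholder (no checkpoint file is named by this diff), and the embedding shape assumes 81 tokens in the Qwen2.5-VL hidden size of 3584, which the adapter projects to FLUX's 4096:

```python
import torch
from modeling.decoder.flux_decoder import FluxDecoder

decoder = FluxDecoder(
    flux_all2all_modelpath="models/DiffSynth-Studio/Nexus-Gen/decoder.bin",  # placeholder path
    flux_path="models",
)
embeds = torch.randn(81, 3584)  # assumed: 81 image tokens from the Qwen2.5-VL encoder
image = decoder.decode_image_embeds(embeds, height=512, width=512)
image.save("decoded.jpg")
```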
modeling/decoder/flux_image_pipeline.py (new file, 192 lines)

@@ -0,0 +1,192 @@
from typing import List
from tqdm import tqdm
import torch
from diffsynth.models import ModelManager
from diffsynth.controlnets import ControlNetConfigUnit
from diffsynth.prompters.flux_prompter import FluxPrompter
from diffsynth.pipelines.flux_image import FluxImagePipeline, lets_dance_flux, TeaCache


class FluxPrompterAll2All(FluxPrompter):
    def encode_prompt(
        self,
        prompt,
        positive=True,
        device="cuda",
        t5_sequence_length=512,
        clip_only=False
    ):
        prompt = self.process_prompt(prompt, positive=positive)
        # CLIP
        pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
        if clip_only:
            return None, pooled_prompt_emb, None
        # T5
        prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
        # text_ids
        text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
        return prompt_emb, pooled_prompt_emb, text_ids


class FluxImagePipelineAll2All(FluxImagePipeline):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.prompter = FluxPrompterAll2All()

    def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, clip_only=False):
        prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
            prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, clip_only=clip_only
        )
        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}

    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
        pipe = FluxImagePipelineAll2All(
            device=model_manager.device if device is None else device,
            torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
        )
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
        return pipe

    def prepare_prompts(self, prompt, image_embed, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
        # Extend prompt
        self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
        prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)

        # Encode prompts
        if image_embed is not None:
            image_embed = image_embed.to(self.torch_dtype)
            prompt_emb_posi = self.encode_prompt("", positive=True, clip_only=True)
            if len(image_embed.size()) == 2:
                image_embed = image_embed.unsqueeze(0)
            prompt_emb_posi['prompt_emb'] = image_embed
            prompt_emb_posi['text_ids'] = torch.zeros(image_embed.shape[0], image_embed.shape[1], 3).to(device=self.device, dtype=self.torch_dtype)
        else:
            prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
        prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
        prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
        return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt,
        negative_prompt="",
        cfg_scale=1.0,
        embedded_guidance=3.5,
        t5_sequence_length=512,
        # Image
        input_image=None,
        denoising_strength=1.0,
        height=1024,
        width=1024,
        seed=None,
        # image_embed
        image_embed=None,
        # Steps
        num_inference_steps=30,
        # local prompts
        local_prompts=(),
        masks=(),
        mask_scales=(),
        # ControlNet
        controlnet_image=None,
        controlnet_inpaint_mask=None,
        enable_controlnet_on_negative=False,
        # IP-Adapter
        ipadapter_images=None,
        ipadapter_scale=1.0,
        # EliGen
        eligen_entity_prompts=None,
        eligen_entity_masks=None,
        enable_eligen_on_negative=False,
        enable_eligen_inpaint=False,
        # TeaCache
        tea_cache_l1_thresh=None,
        # Tile
        tiled=False,
        tile_size=128,
        tile_stride=64,
        # Progress bar
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        height, width = self.check_resize_height_width(height, width)

        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)

        # Prompt
        prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, image_embed, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)

        # Extra input
        extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)

        # Entity control
        eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)

        # IP-Adapter
        ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)

        # ControlNets
        controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)

        # TeaCache
        tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}

        # Denoise
        self.load_models_to_device(['dit', 'controlnet'])
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Positive side
            inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
                dit=self.dit, controlnet=self.controlnet,
                hidden_states=latents, timestep=timestep,
                **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
            )
            noise_pred_posi = self.control_noise_via_local_prompts(
                prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
                special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
            )

            # Inpaint
            if enable_eligen_inpaint:
                noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)

            # Classifier-free guidance
            if cfg_scale != 1.0:
                # Negative side
                noise_pred_nega = lets_dance_flux(
                    dit=self.dit, controlnet=self.controlnet,
                    hidden_states=latents, timestep=timestep,
                    **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            # Iterate
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        self.load_models_to_device(['vae_decoder'])
        image = self.decode_image(latents, **tiler_kwargs)

        # Offload all models
        self.load_models_to_device([])
        return image
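The `clip_only` path exists because FLUX still needs a pooled CLIP vector for conditioning even when the sequence-level T5 embedding is replaced wholesale by the projected image embedding. A condensed view of what `prepare_prompts` assembles in that case (shapes are assumptions: 81 tokens, dim 4096, and a CLIP-L pooled dim of 768):

```python
import torch

pooled_prompt_emb = torch.randn(1, 768)           # from encode_prompt("", clip_only=True); dim assumed
image_embed = torch.randn(81, 4096).unsqueeze(0)  # projected Qwen embedding replaces prompt_emb
text_ids = torch.zeros(1, 81, 3)
prompt_emb_posi = {"prompt_emb": image_embed, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
```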
requirements.txt

@@ -1,7 +1,7 @@
 torch>=2.0.0
 torchvision
 cupy-cuda12x
-transformers==4.46.2
+transformers==4.49.0
 controlnet-aux==0.0.7
 imageio
 imageio[ffmpeg]
@@ -11,3 +11,4 @@ sentencepiece
 protobuf
 modelscope
 ftfy
+qwen_vl_utils
run_single.sh (new file, 4 lines)

@@ -0,0 +1,4 @@
accelerate launch \
    train.py \
    --output_path models/nexus_v3 \
    --steps_per_epoch 4000
test.py (new file, 312 lines)

@@ -0,0 +1,312 @@
from transformers import AutoConfig, AutoTokenizer
|
||||
import torch, json, os, torchvision
|
||||
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
|
||||
from qwen_vl_utils import smart_resize
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from torchvision.transforms import v2
|
||||
|
||||
|
||||
|
||||
class SingleTaskDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
base_path,
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
|
||||
height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
|
||||
):
|
||||
self.base_path = base_path
|
||||
self.keys = keys
|
||||
self.metadata = []
|
||||
self.bad_data = []
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.random = random
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.image_process = v2.Compose([
|
||||
v2.CenterCrop(size=(height, width)),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
if metadata_path is None:
|
||||
self.search_for_data("", report_data_log=True)
|
||||
self.report_data_log()
|
||||
else:
|
||||
with open(metadata_path, "r", encoding="utf-8-sig") as f:
|
||||
self.metadata = json.load(f)
|
||||
|
||||
|
||||
def report_data_log(self):
|
||||
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
|
||||
|
||||
|
||||
def dump_metadata(self, path):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def parse_json_file(self, absolute_path, relative_path):
|
||||
data_list = []
|
||||
with open(absolute_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
for image_1, image_2, instruction in self.keys:
|
||||
image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
|
||||
image_2 = os.path.join(relative_path, metadata[image_2])
|
||||
instruction = metadata[instruction]
|
||||
data_list.append((image_1, image_2, instruction))
|
||||
return data_list
|
||||
|
||||
|
||||
def search_for_data(self, path, report_data_log=False):
|
||||
now_path = os.path.join(self.base_path, path)
|
||||
if os.path.isfile(now_path) and path.endswith(".json"):
|
||||
try:
|
||||
data_list = self.parse_json_file(now_path, os.path.dirname(path))
|
||||
self.metadata.extend(data_list)
|
||||
except:
|
||||
self.bad_data.append(now_path)
|
||||
elif os.path.isdir(now_path):
|
||||
for sub_path in os.listdir(now_path):
|
||||
self.search_for_data(os.path.join(path, sub_path))
|
||||
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
|
||||
self.report_data_log()
|
||||
|
||||
|
||||
def load_image(self, image_path, skip_process=False):
|
||||
image_path = os.path.join(self.base_path, image_path)
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
width, height = image.size
|
||||
scale = max(self.width / width, self.height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
if skip_process:
|
||||
return image
|
||||
image = self.image_process(image)
|
||||
return image
|
||||
|
||||
|
||||
def load_data(self, data_id):
|
||||
image_1, image_2, instruction = self.metadata[data_id]
|
||||
image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
|
||||
image_2 = self.load_image(image_2)
|
||||
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
if self.random:
|
||||
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
|
||||
data = self.load_data(data_id)
|
||||
return data
|
||||
else:
|
||||
return self.load_data(data_id)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch if self.random else len(self.metadata)
|
||||
|
||||
|
||||
|
||||
class MultiTaskDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
|
||||
self.dataset_list = dataset_list
|
||||
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
|
||||
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
|
||||
data = self.dataset_list[dataset_id][data_id]
|
||||
return data
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
|
||||
|
||||
|
||||
class NexusGenQwenVLEncoder(torch.nn.Module):
|
||||
def __init__(self, model_path, torch_dtype="auto", device="cpu"):
|
||||
super().__init__()
|
||||
model_config = AutoConfig.from_pretrained(model_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
|
||||
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
|
||||
self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
|
||||
return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()
|
||||
|
||||
def process_images(self, images=None):
|
||||
if images is None:
|
||||
return None
|
||||
# resize input to max_pixels to avoid oom
|
||||
for j in range(len(images)):
|
||||
input_image = images[j]
|
||||
input_w, input_h = input_image.size
|
||||
resized_height, resized_width = smart_resize(
|
||||
input_h,
|
||||
input_w,
|
||||
max_pixels=262640,
|
||||
)
|
||||
images[j] = input_image.resize((resized_width, resized_height))
|
||||
return images
|
||||
|
||||
def forward(self, prompt, images=None, num_img_tokens=81):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": self.t2i_template if images is None else self.i2i_template
|
||||
},],
|
||||
}
|
||||
]
|
||||
images = self.process_images(images)
|
||||
target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
|
||||
if images is None:
|
||||
images = [target_image]
|
||||
else:
|
||||
images = images + [target_image]
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=images,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(self.model.device)
|
||||
|
||||
input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
|
||||
image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
|
||||
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
|
||||
input_image_embeds = image_embeds[:-num_img_tokens]
|
||||
|
||||
image_mask = inputs['input_ids'] == self.model.config.image_token_id
|
||||
indices = image_mask.cumsum(dim=1)
|
||||
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
|
||||
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
|
||||
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
|
||||
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
|
||||
|
||||
position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
|
||||
inputs['image_grid_thw'],
|
||||
attention_mask=inputs['attention_mask'])
|
||||
position_ids = position_ids.contiguous()
|
||||
outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
|
||||
output_image_embeddings = outputs.image_embeddings[:, :-1, :] # shift right
|
||||
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
|
||||
output_image_embeddings = output_image_embeddings.unsqueeze(0)
|
||||
return output_image_embeddings
|
||||
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
|
||||
# pipe.dit.load_state_dict(state_dict, strict=False)
|
||||
|
||||
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
|
||||
# adapter.load_state_dict(state_dict, strict=False)
|
||||
|
||||
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
|
||||
|
||||
sd = {}
|
||||
for i in range(1, 6):
|
||||
print(i)
|
||||
sd.update(load_state_dict(f"models/nexus_v3/epoch-19/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
|
||||
pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")})
|
||||
qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")})
|
||||
adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")})
|
||||
|
||||
|
||||
dataset = MultiTaskDataset(
    dataset_list=[
        SingleTaskDataset(
            "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
            keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
            height=1024, width=1024,
            metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
        ),
        SingleTaskDataset(
            "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
            keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
            height=1024, width=1024,
            metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
        ),
        SingleTaskDataset(
            "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
            keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
            height=1024, width=1024,
            metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
        ),
    ],
    dataset_weight=(4, 2, 1,),
    steps_per_epoch=100000
)

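# Qualitative check on up to ~100 samples: encode the instruction (plus the reference
# image for editing tasks) with Qwen2.5-VL, project through the adapter, let FLUX decode
# the predicted embeddings, and save reference/target/result as {id}_1/_2/_3.jpg.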
torch.manual_seed(0)
for data_id, data in enumerate(dataset):
    image_1 = data["image_1"]
    image_2 = data["image_2"].cpu().float().permute(1, 2, 0).numpy()
    image_2 = Image.fromarray(((image_2 / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
    instruction = data["instruction"]

    print(instruction)
    with torch.no_grad():
        if image_1 is None:
            # Text-to-image: no reference image is fed to the encoder.
            instruction = f"Generate an image according to the following description: {instruction}"
            emb = qwenvl(instruction, images=None)
        else:
            # Image editing: the reference image is injected at the <|image_pad|> placeholder.
            instruction = f"<|vision_start|><|image_pad|><|vision_end|> {instruction}"
            emb = qwenvl(instruction, images=[image_1])
        emb = adapter(emb)
        image_3 = pipe("", image_emb=emb)

    if image_1 is not None:
        image_1.save(f"data/output/{data_id}_1.jpg")
    image_2.save(f"data/output/{data_id}_2.jpg")
    image_3.save(f"data/output/{data_id}_3.jpg")
    if data_id >= 100:
        break

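# Standalone examples kept for reference: text-to-image generation, then editing the
# image produced by the first example.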
# with torch.no_grad():
#     instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
#     emb = qwenvl(instruction, images=None)
#     emb = adapter(emb)
#     image = pipe("", image_emb=emb)
#     image.save("image_1.jpg")

# with torch.no_grad():
#     instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
#     emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
#     emb = adapter(emb)
#     image = pipe("", image_emb=emb)
#     image.save("image_2.jpg")
421 train.py Normal file
@@ -0,0 +1,421 @@
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from accelerate import Accelerator
from tqdm import tqdm
import torch, os, json, argparse, torchvision
from PIL import Image
from torchvision.transforms import v2
from transformers import AutoConfig, AutoTokenizer
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from qwen_vl_utils import smart_resize
import numpy as np
import lightning as pl

os.environ["TOKENIZERS_PARALLELISM"] = "True"


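# Dataset over (source image, target image, instruction) triples described by per-sample
# JSON files; each entry of "keys" names the JSON fields for one task variant (forward
# edit, reverse edit, or text-to-image when the source image is None).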
class SingleTaskDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        base_path,
        keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
        height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
    ):
        self.base_path = base_path
        self.keys = keys
        self.metadata = []
        self.bad_data = []
        self.height = height
        self.width = width
        self.random = random
        self.steps_per_epoch = steps_per_epoch
        self.image_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        if metadata_path is None:
            self.search_for_data("", report_data_log=True)
            self.report_data_log()
        else:
            with open(metadata_path, "r", encoding="utf-8-sig") as f:
                self.metadata = json.load(f)


    def report_data_log(self):
        print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")


    def dump_metadata(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=4)


    def parse_json_file(self, absolute_path, relative_path):
        data_list = []
        with open(absolute_path, "r") as f:
            metadata = json.load(f)
        for image_1, image_2, instruction in self.keys:
            image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
            image_2 = os.path.join(relative_path, metadata[image_2])
            instruction = metadata[instruction]
            data_list.append((image_1, image_2, instruction))
        return data_list


    def search_for_data(self, path, report_data_log=False):
        now_path = os.path.join(self.base_path, path)
        if os.path.isfile(now_path) and path.endswith(".json"):
            try:
                data_list = self.parse_json_file(now_path, os.path.dirname(path))
                self.metadata.extend(data_list)
            except Exception:
                # JSON files without the expected keys are recorded and skipped.
                self.bad_data.append(now_path)
        elif os.path.isdir(now_path):
            for sub_path in os.listdir(now_path):
                self.search_for_data(os.path.join(path, sub_path))
                if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
                    self.report_data_log()


    def load_image(self, image_path, skip_process=False):
        image_path = os.path.join(self.base_path, image_path)
        image = Image.open(image_path).convert("RGB")
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        if skip_process:
            return image
        image = self.image_process(image)
        return image


    def load_data(self, data_id):
        image_1, image_2, instruction = self.metadata[data_id]
        image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
        image_2 = self.load_image(image_2)
        return {"image_1": image_1, "image_2": image_2, "instruction": instruction}


    def __getitem__(self, data_id):
        if self.random:
            while True:
                try:
                    data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
                    data = self.load_data(data_id)
                    return data
                except Exception:
                    # Unreadable samples are resampled instead of crashing the worker.
                    continue
        else:
            return self.load_data(data_id)


    def __len__(self):
        return self.steps_per_epoch if self.random else len(self.metadata)


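# Task mixture: __getitem__ first draws a dataset according to dataset_weight, then a
# random sample from it, so the epoch length is set by steps_per_epoch rather than by
# the underlying dataset sizes.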
class MultiTaskDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
        self.dataset_list = dataset_list
        self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
        self.steps_per_epoch = steps_per_epoch


    def __getitem__(self, data_id):
        dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
        data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
        data = self.dataset_list[dataset_id][data_id]
        return data


    def __len__(self):
        return self.steps_per_epoch


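# Wraps the customized Qwen2.5-VL (which exposes outputs.image_embeddings) as an
# embedding predictor: the assistant turn ends in a block of <|image_pad|> tokens, and
# the model's outputs at those positions become the image embedding handed to FLUX.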
class NexusGenQwenVLEncoder(torch.nn.Module):
    def __init__(self, model_path, torch_dtype="auto", device="cpu"):
        super().__init__()
        model_config = AutoConfig.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
        self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
        self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
        self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"

    @staticmethod
    def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
        return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()

    def process_images(self, images=None):
        if images is None:
            return None
        # Resize inputs to max_pixels to avoid OOM.
        for j in range(len(images)):
            input_image = images[j]
            input_w, input_h = input_image.size
            resized_height, resized_width = smart_resize(
                input_h,
                input_w,
                max_pixels=262640,
            )
            images[j] = input_image.resize((resized_width, resized_height))
        return images

    def forward(self, prompt, images=None, num_img_tokens=81):
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": self.t2i_template if images is None else self.i2i_template}],
            }
        ]
        images = self.process_images(images)
        # A blank 252x252 placeholder expands to exactly num_img_tokens (81) image tokens,
        # reserving the positions whose outputs become the predicted embedding.
        target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
        if images is None:
            images = [target_image]
        else:
            images = images + [target_image]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        inputs = self.processor(
            text=[text],
            images=images,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
        image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
        ground_truth_image_embeds = image_embeds[-num_img_tokens:]
        input_image_embeds = image_embeds[:-num_img_tokens]

        # Scatter the reference-image features into their <|image_pad|> slots; the trailing
        # placeholder slots (gt_image_mask) keep plain token embeddings and are predicted.
        image_mask = inputs['input_ids'] == self.model.config.image_token_id
        indices = image_mask.cumsum(dim=1)
        input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
        gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
        input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
        input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)

        position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
                                                    inputs['image_grid_thw'],
                                                    attention_mask=inputs['attention_mask'])
        position_ids = position_ids.contiguous()
        outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
        # Shift right: the output at position t predicts the embedding of token t+1.
        output_image_embeddings = outputs.image_embeddings[:, :-1, :]
        output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
        output_image_embeddings = output_image_embeddings.unsqueeze(0)
        return output_image_embeddings


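# LightningModule tying FLUX, the adapter, and Qwen2.5-VL together. As written, the VAE,
# the CLIP text encoder, and the whole Qwen2.5-VL encoder are frozen, leaving the FLUX
# DiT and the adapter as the trainable parts.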
class UnifiedModel(pl.LightningModule):
    def __init__(self, flux_text_encoder_path, flux_vae_path, flux_dit_path, flux_decoder_path, qwenvl_path):
        super().__init__()
        # Load models
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([
            flux_text_encoder_path,
            flux_vae_path,
            flux_dit_path
        ])
        self.pipe = FluxImagePipeline.from_model_manager(model_manager)

        state_dict = load_state_dict(flux_decoder_path, torch_dtype=torch.bfloat16)
        self.pipe.dit.load_state_dict(state_dict, strict=False)

        self.adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16)
        self.adapter.load_state_dict(state_dict, strict=False)

        self.qwenvl = NexusGenQwenVLEncoder.from_pretrained(qwenvl_path)

        self.pipe.vae_decoder.requires_grad_(False)
        self.pipe.vae_encoder.requires_grad_(False)
        self.pipe.text_encoder_1.requires_grad_(False)
        self.qwenvl.requires_grad_(False)
        self.qwenvl.model.visual.requires_grad_(False)
        self.pipe.train()
        self.adapter.train()
        self.qwenvl.train()
        self.qwenvl.model.visual.eval()
        # self.qwenvl.model.model.gradient_checkpointing = True

        self.pipe.scheduler.set_timesteps(1000, training=True)


    def training_step(self, batch, batch_idx):
        # Data
        text, image = batch["instruction"], batch["image_2"]
        image_ref = batch["image_1"]
        image = image.unsqueeze(0)

        # Prepare input parameters
        self.pipe.device = self.device
        latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        if image_ref is None:
            instruction = f"Generate an image according to the following description: {text}"
            images_ref = None
        else:
            instruction = f"<|vision_start|><|image_pad|><|vision_end|> {text}"
            images_ref = [image_ref]
        emb = self.qwenvl(instruction, images=images_ref)
        emb = self.adapter(emb)
        prompt_emb = self.pipe.encode_prompt("", positive=True, image_emb=emb)

        noise_pred = lets_dance_flux(
            self.pipe.denoising_model(),
            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
            image_emb=emb,
            use_gradient_checkpointing=False
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        return loss


    def forward(self, batch):
        return self.training_step(batch, 0)


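# Resume support: pick the newest "epoch-N" checkpoint folder under the output path,
# if any, and continue from epoch N + 1.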
def search_for_last_checkpoint(path):
    if not os.path.exists(path):
        return None, 0
    checkpoint_list = os.listdir(path)
    checkpoint_list = [int(checkpoint.split("-")[1]) for checkpoint in checkpoint_list if checkpoint.startswith("epoch")]
    if len(checkpoint_list) == 0:
        return None, 0
    else:
        max_epoch_id = max(checkpoint_list)
        return f"{path}/epoch-{max_epoch_id}/model.safetensors", max_epoch_id + 1


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="gradient_accumulation_steps",
    )
    parser.add_argument(
        "--steps_per_epoch",
        type=int,
        default=1000,
        help="steps_per_epoch",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./models",
        help="output_path",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="learning_rate",
    )
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    model = UnifiedModel(
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
        "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
        "models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin",
        "models/DiffSynth-Studio/Nexus-Gen",
    )
    # dataset and data loader
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    dataset = MultiTaskDataset(
        dataset_list=[
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
                height=1024, width=1024,
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
                keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
                height=1024, width=1024,
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                height=1024, width=1024,
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
            ),
        ],
        dataset_weight=(4, 2, 1,),
        steps_per_epoch=args.steps_per_epoch * accelerator.num_processes,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=1,
        collate_fn=lambda x: x[0]  # batch size 1: unwrap the singleton list
    )
    # train
    pretrained_path, start_epoch_id = search_for_last_checkpoint(args.output_path)
    if pretrained_path is not None:
        print(f"pretrained_path: {pretrained_path}")
        model.load_state_dict(load_state_dict(pretrained_path, torch_dtype=torch.bfloat16), assign=True, strict=False)

    model.to(torch.bfloat16)
    model.to(accelerator.device)

    trainable_modules = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(trainable_modules, lr=args.learning_rate)

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    for epoch in range(start_epoch_id, 100000):
        for batch in tqdm(train_loader, desc=f"epoch-{epoch}", disable=not accelerator.is_local_main_process):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(batch)
                accelerator.backward(loss)
                optimizer.step()
        path = args.output_path
        os.makedirs(path, exist_ok=True)
        accelerator.wait_for_everyone()
        accelerator.save_model(model, f"{path}/epoch-{epoch}", max_shard_size="10GB", safe_serialization=True)