Compare commits

...

7 Commits

Author SHA1 Message Date
Zhongjie Duan
2b72ae0e56 Merge pull request #568 from mi804/nexus-gen
update nexus-gen readme
2025-05-13 14:11:25 +08:00
mi804
90588dcf97 update nexus-gen readme 2025-05-13 11:42:01 +08:00
xuyixuan.xyx
91fbb24e17 refine training 2025-05-12 14:19:00 +08:00
xuyixuan.xyx
f17558a4c4 train 2025-05-07 11:22:13 +08:00
Artiprocher
290ec469ca train 2025-05-06 17:54:32 +08:00
Artiprocher
1ed676b076 train 2025-05-06 17:53:56 +08:00
Artiprocher
f7737aff98 nexus-gen 2025-04-30 17:09:15 +08:00
13 changed files with 3913 additions and 24 deletions

View File

@@ -42,6 +42,12 @@ Until now, DiffSynth-Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5) * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News ## News
- **May 1, 2025** 🔥🔥🔥 We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models.
- Paper: [Nexus-Gen: A Unified Model for Image Understanding, Generation, and Editing](https://arxiv.org/pdf/2504.21356)
- Github Repo: https://github.com/modelscope/Nexus-Gen
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-Gen), [HuggingFace](https://huggingface.co/modelscope/Nexus-Gen)
- Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details. - **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality. - **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.

View File

@@ -202,10 +202,10 @@ class FluxImagePipeline(BasePipeline):
return image return image
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512): def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, image_emb=None):
if self.text_encoder_1 is not None and self.text_encoder_2 is not None: if (self.text_encoder_1 is not None and self.text_encoder_2 is not None) or (image_emb is not None):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt( prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, image_emb=image_emb
) )
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
else: else:
@@ -358,13 +358,13 @@ class FluxImagePipeline(BasePipeline):
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale): def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb=None):
# Extend prompt # Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts # Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length) prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length, image_emb=image_emb)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@@ -432,6 +432,7 @@ class FluxImagePipeline(BasePipeline):
height=1024, height=1024,
width=1024, width=1024,
seed=None, seed=None,
image_emb=None,
# Steps # Steps
num_inference_steps=30, num_inference_steps=30,
# local prompts # local prompts
@@ -483,7 +484,7 @@ class FluxImagePipeline(BasePipeline):
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride) latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
# Prompt # Prompt
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale) prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb)
# Extra input # Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
@@ -679,6 +680,7 @@ def lets_dance_flux(
step1x_mask=None, step1x_mask=None,
step1x_reference_latents=None, step1x_reference_latents=None,
tea_cache: TeaCache = None, tea_cache: TeaCache = None,
use_gradient_checkpointing=False,
**kwargs **kwargs
): ):
if tiled: if tiled:
@@ -773,20 +775,32 @@ def lets_dance_flux(
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else: else:
tea_cache_update = False tea_cache_update = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tea_cache_update: if tea_cache_update:
hidden_states = tea_cache.update(hidden_states) hidden_states = tea_cache.update(hidden_states)
else: else:
# Joint Blocks # Joint Blocks
for block_id, block in enumerate(dit.blocks): for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block( if use_gradient_checkpointing:
hidden_states, hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
prompt_emb, create_custom_forward(block),
conditioning, hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
image_rotary_emb, use_reentrant=False,
attention_mask, )
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) else:
) hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet # ControlNet
if controlnet is not None and controlnet_frames is not None: if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id] hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -795,14 +809,21 @@ def lets_dance_flux(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks) num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks): for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block( if use_gradient_checkpointing:
hidden_states, hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
prompt_emb, create_custom_forward(block),
conditioning, hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
image_rotary_emb, use_reentrant=False,
attention_mask, )
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) else:
) hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet # ControlNet
if controlnet is not None and controlnet_frames is not None: if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]

View File

@@ -59,6 +59,7 @@ class FluxPrompter(BasePrompter):
positive=True, positive=True,
device="cuda", device="cuda",
t5_sequence_length=512, t5_sequence_length=512,
image_emb=None,
): ):
prompt = self.process_prompt(prompt, positive=positive) prompt = self.process_prompt(prompt, positive=positive)
@@ -66,7 +67,10 @@ class FluxPrompter(BasePrompter):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device) pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
# T5 # T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device) if image_emb is not None:
prompt_emb = image_emb
else:
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
# text_ids # text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)

0
modeling/ar/__init__.py Normal file
View File

View File

@@ -0,0 +1,258 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_qwen2_5_vl.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class Qwen2_5_VLVisionConfig(PretrainedConfig):
model_type = "qwen2_5_vl"
base_config_key = "vision_config"
def __init__(
self,
depth=32,
hidden_size=3584,
hidden_act="silu",
intermediate_size=3420,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
tokens_per_second=4,
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.tokens_per_second = tokens_per_second
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
class Qwen2_5_VLConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 152064):
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2_5_VLModel`]
hidden_size (`int`, *optional*, defaults to 8192):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 29568):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 80):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 80):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
vision_config (`Dict`, *optional*):
The config for the visual encoder initialization.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
```python
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
>>> # Initializing a Qwen2_5_VL style configuration
>>> configuration = Qwen2_5_VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_5_vl"
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2_5_VL`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
# TODO: @raushan update config in the hub
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self, ignore_keys={"mrope_section"})
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
__all__ = ["Qwen2_5_VLConfig"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,235 @@
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List, Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
fps: Union[List[float], float]
class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
},
"videos_kwargs": {"fps": 2.0},
}
class Qwen2_5_VLProcessor(ProcessorMixin):
r"""
Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
[`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
[`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
videos: VideoInput = None,
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Qwen2_5_VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
image_grid_thw = None
if videos is not None:
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
if isinstance(fps, (int, float)):
second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]
else:
raise ValueError(
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
)
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
else:
videos_inputs = {}
video_grid_thw = None
if not isinstance(text, list):
text = [text]
if image_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.image_token in text[i]:
text[i] = text[i].replace(
self.image_token,
"<|placeholder|>" * (image_grid_thw[index].prod() // merge_length),
1,
)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.image_token)
if video_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.video_token in text[i]:
text[i] = text[i].replace(
self.video_token,
"<|placeholder|>" * (video_grid_thw[index].prod() // merge_length),
1,
)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.video_token)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def batch_decode_all2all(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
decoded = self.tokenizer.batch_decode(*args, **kwargs)
pattern = r'<\|vision_start\|>.*?<\|vision_end\|>'
decoded_with_image_tag = [re.sub(pattern, '<image>', d, flags=re.DOTALL) for d in decoded]
decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag]
return decoded_with_image_tag
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
return names_from_processor + ["second_per_grid_ts"]
__all__ = ["Qwen2_5_VLProcessor"]

View File

@@ -0,0 +1,64 @@
import torch
from diffsynth import ModelManager
from .flux_image_pipeline import FluxImagePipelineAll2All
class FluxDecoder:
def __init__(self, flux_all2all_modelpath, flux_path, device='cuda', torch_dtype=torch.bfloat16):
self.device = device
self.torch_dtype = torch_dtype
self.pipe, self.adapter = self.get_pipe(flux_all2all_modelpath, flux_path, device, torch_dtype)
def get_pipe(self, flux_all2all_modelpath, flux_path, device="cuda", torch_dtype=torch.bfloat16):
model_manager = ModelManager(torch_dtype=torch_dtype, device=device)
model_manager.load_models([
f"{flux_path}/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
f"{flux_path}/FLUX/FLUX.1-dev/text_encoder_2",
f"{flux_path}/FLUX/FLUX.1-dev/ae.safetensors",
f"{flux_path}/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
state_dict = torch.load(flux_all2all_modelpath, weights_only=True, map_location='cpu')
adapter_states = ['0.weight', '0.bias', '1.weight', '1.bias', '3.weight', '3.bias', '4.weight', '4.bias']
adapter_state_dict = {}
for key in adapter_states:
adapter_state_dict[key] = state_dict.pop(key)
in_channel = 3584
out_channel = 4096
expand_ratio = 1
adapter = torch.nn.Sequential(torch.nn.Linear(in_channel, out_channel * expand_ratio),
torch.nn.LayerNorm(out_channel * expand_ratio), torch.nn.ReLU(),
torch.nn.Linear(out_channel * expand_ratio, out_channel),
torch.nn.LayerNorm(out_channel))
adapter.load_state_dict(adapter_state_dict)
adapter.to(device, dtype=torch_dtype)
pipe = FluxImagePipelineAll2All.from_model_manager(model_manager)
pipe.dit.load_state_dict(state_dict)
return pipe, adapter
@torch.no_grad()
def decode_image_embeds(self,
output_image_embeddings,
height=512,
width=512,
num_inference_steps=50,
seed=42,
negative_prompt="",
cfg_scale=1.0,
**pipe_kwargs):
output_image_embeddings = output_image_embeddings.to(device=self.device, dtype=self.torch_dtype)
image_embed = self.adapter(output_image_embeddings)
image = self.pipe(prompt="",
image_embed=image_embed,
num_inference_steps=num_inference_steps,
embedded_guidance=3.5,
negative_prompt=negative_prompt,
cfg_scale=cfg_scale,
height=height,
width=width,
seed=seed,
**pipe_kwargs)
return image

View File

@@ -0,0 +1,192 @@
from typing import List
from tqdm import tqdm
import torch
from diffsynth.models import ModelManager
from diffsynth.controlnets import ControlNetConfigUnit
from diffsynth.prompters.flux_prompter import FluxPrompter
from diffsynth.pipelines.flux_image import FluxImagePipeline, lets_dance_flux, TeaCache
class FluxPrompterAll2All(FluxPrompter):
def encode_prompt(
self,
prompt,
positive=True,
device="cuda",
t5_sequence_length=512,
clip_only=False
):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
if clip_only:
return None, pooled_prompt_emb, None
# T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
# text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
return prompt_emb, pooled_prompt_emb, text_ids
class FluxImagePipelineAll2All(FluxImagePipeline):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.prompter = FluxPrompterAll2All()
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, clip_only=False):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, clip_only=clip_only
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
pipe = FluxImagePipelineAll2All(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
return pipe
def prepare_prompts(self, prompt, image_embed, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
if image_embed is not None:
image_embed = image_embed.to(self.torch_dtype)
prompt_emb_posi = self.encode_prompt("", positive=True, clip_only=True)
if len(image_embed.size()) == 2:
image_embed = image_embed.unsqueeze(0)
prompt_emb_posi['prompt_emb'] = image_embed
prompt_emb_posi['text_ids'] = torch.zeros(image_embed.shape[0], image_embed.shape[1], 3).to(device=self.device, dtype=self.torch_dtype)
else:
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@torch.no_grad()
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
t5_sequence_length=512,
# Image
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
seed=None,
# image_embed
image_embed=None,
# Steps
num_inference_steps=30,
# local prompts
local_prompts=(),
masks=(),
mask_scales=(),
# ControlNet
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
# IP-Adapter
ipadapter_images=None,
ipadapter_scale=1.0,
# EliGen
eligen_entity_prompts=None,
eligen_entity_masks=None,
enable_eligen_on_negative=False,
enable_eligen_inpaint=False,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
tiled=False,
tile_size=128,
tile_stride=64,
# Progress bar
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
# Prompt
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, image_embed, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Entity control
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
# IP-Adapter
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
# ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
)
# Inpaint
if enable_eligen_inpaint:
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
# Classifier-free guidance
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Iterate
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, **tiler_kwargs)
# Offload all models
self.load_models_to_device([])
return image

View File

@@ -1,7 +1,7 @@
torch>=2.0.0 torch>=2.0.0
torchvision torchvision
cupy-cuda12x cupy-cuda12x
transformers==4.46.2 transformers==4.49.0
controlnet-aux==0.0.7 controlnet-aux==0.0.7
imageio imageio
imageio[ffmpeg] imageio[ffmpeg]
@@ -11,3 +11,4 @@ sentencepiece
protobuf protobuf
modelscope modelscope
ftfy ftfy
qwen_vl_utils

4
run_single.sh Normal file
View File

@@ -0,0 +1,4 @@
accelerate launch \
train.py \
--output_path models/nexus_v3 \
--steps_per_epoch 4000

312
test.py Normal file
View File

@@ -0,0 +1,312 @@
from transformers import AutoConfig, AutoTokenizer
import torch, json, os, torchvision
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
from qwen_vl_utils import smart_resize
from PIL import Image
import numpy as np
from torchvision.transforms import v2
class SingleTaskDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path,
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
):
self.base_path = base_path
self.keys = keys
self.metadata = []
self.bad_data = []
self.height = height
self.width = width
self.random = random
self.steps_per_epoch = steps_per_epoch
self.image_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if metadata_path is None:
self.search_for_data("", report_data_log=True)
self.report_data_log()
else:
with open(metadata_path, "r", encoding="utf-8-sig") as f:
self.metadata = json.load(f)
def report_data_log(self):
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
def dump_metadata(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
def parse_json_file(self, absolute_path, relative_path):
data_list = []
with open(absolute_path, "r") as f:
metadata = json.load(f)
for image_1, image_2, instruction in self.keys:
image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
image_2 = os.path.join(relative_path, metadata[image_2])
instruction = metadata[instruction]
data_list.append((image_1, image_2, instruction))
return data_list
def search_for_data(self, path, report_data_log=False):
now_path = os.path.join(self.base_path, path)
if os.path.isfile(now_path) and path.endswith(".json"):
try:
data_list = self.parse_json_file(now_path, os.path.dirname(path))
self.metadata.extend(data_list)
except:
self.bad_data.append(now_path)
elif os.path.isdir(now_path):
for sub_path in os.listdir(now_path):
self.search_for_data(os.path.join(path, sub_path))
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
self.report_data_log()
def load_image(self, image_path, skip_process=False):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
if skip_process:
return image
image = self.image_process(image)
return image
def load_data(self, data_id):
image_1, image_2, instruction = self.metadata[data_id]
image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
image_2 = self.load_image(image_2)
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
def __getitem__(self, data_id):
if self.random:
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
data = self.load_data(data_id)
return data
else:
return self.load_data(data_id)
def __len__(self):
return self.steps_per_epoch if self.random else len(self.metadata)
class MultiTaskDataset(torch.utils.data.Dataset):
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
self.dataset_list = dataset_list
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, data_id):
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
data = self.dataset_list[dataset_id][data_id]
return data
def __len__(self):
return self.steps_per_epoch
class NexusGenQwenVLEncoder(torch.nn.Module):
def __init__(self, model_path, torch_dtype="auto", device="cpu"):
super().__init__()
model_config = AutoConfig.from_pretrained(model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
@staticmethod
def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()
def process_images(self, images=None):
if images is None:
return None
# resize input to max_pixels to avoid oom
for j in range(len(images)):
input_image = images[j]
input_w, input_h = input_image.size
resized_height, resized_width = smart_resize(
input_h,
input_w,
max_pixels=262640,
)
images[j] = input_image.resize((resized_width, resized_height))
return images
def forward(self, prompt, images=None, num_img_tokens=81):
messages = [
{
"role": "user",
"content": [{
"type": "text",
"text": prompt
},],
},
{
"role": "assistant",
"content": [{
"type": "text",
"text": self.t2i_template if images is None else self.i2i_template
},],
}
]
images = self.process_images(images)
target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
if images is None:
images = [target_image]
else:
images = images + [target_image]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
inputs = self.processor(
text=[text],
images=images,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
input_image_embeds = image_embeds[:-num_img_tokens]
image_mask = inputs['input_ids'] == self.model.config.image_token_id
indices = image_mask.cumsum(dim=1)
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :] # shift right
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
output_image_embeddings = output_image_embeddings.unsqueeze(0)
return output_image_embeddings
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
# state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
# pipe.dit.load_state_dict(state_dict, strict=False)
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
# adapter.load_state_dict(state_dict, strict=False)
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
sd = {}
for i in range(1, 6):
print(i)
sd.update(load_state_dict(f"models/nexus_v3/epoch-19/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")})
qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")})
adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")})
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
),
],
dataset_weight=(4, 2, 1,),
steps_per_epoch=100000
)
torch.manual_seed(0)
for data_id, data in enumerate(dataset):
image_1 = data["image_1"]
image_2 = data["image_2"].cpu().float().permute(1, 2, 0).numpy()
image_2 = Image.fromarray(((image_2 / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
instruction = data["instruction"]
print(instruction)
if image_1 is None:
with torch.no_grad():
instruction = f"Generate an image according to the following description: {instruction}"
emb = qwenvl(instruction, images=None)
emb = adapter(emb)
image_3 = pipe("", image_emb=emb)
else:
with torch.no_grad():
instruction = f"<|vision_start|><|image_pad|><|vision_end|> {instruction}"
emb = qwenvl(instruction, images=[image_1])
emb = adapter(emb)
image_3 = pipe("", image_emb=emb)
if image_1 is not None:
image_1.save(f"data/output/{data_id}_1.jpg")
image_2.save(f"data/output/{data_id}_2.jpg")
image_3.save(f"data/output/{data_id}_3.jpg")
if data_id >= 100:
break
# with torch.no_grad():
# instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
# emb = qwenvl(instruction, images=None)
# emb = adapter(emb)
# image = pipe("", image_emb=emb)
# image.save("image_1.jpg")
# with torch.no_grad():
# instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
# emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
# emb = adapter(emb)
# image = pipe("", image_emb=emb)
# image.save("image_2.jpg")

421
train.py Normal file
View File

@@ -0,0 +1,421 @@
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
from diffsynth.pipelines.flux_image import lets_dance_flux
from accelerate import Accelerator
from tqdm import tqdm
import torch, os, json, torchvision
from PIL import Image
from torchvision.transforms import v2
from transformers import AutoConfig, AutoTokenizer
import torch
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
from qwen_vl_utils import smart_resize
from PIL import Image
import numpy as np
import lightning as pl
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class SingleTaskDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path,
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
):
self.base_path = base_path
self.keys = keys
self.metadata = []
self.bad_data = []
self.height = height
self.width = width
self.random = random
self.steps_per_epoch = steps_per_epoch
self.image_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if metadata_path is None:
self.search_for_data("", report_data_log=True)
self.report_data_log()
else:
with open(metadata_path, "r", encoding="utf-8-sig") as f:
self.metadata = json.load(f)
def report_data_log(self):
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
def dump_metadata(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
def parse_json_file(self, absolute_path, relative_path):
data_list = []
with open(absolute_path, "r") as f:
metadata = json.load(f)
for image_1, image_2, instruction in self.keys:
image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
image_2 = os.path.join(relative_path, metadata[image_2])
instruction = metadata[instruction]
data_list.append((image_1, image_2, instruction))
return data_list
def search_for_data(self, path, report_data_log=False):
now_path = os.path.join(self.base_path, path)
if os.path.isfile(now_path) and path.endswith(".json"):
try:
data_list = self.parse_json_file(now_path, os.path.dirname(path))
self.metadata.extend(data_list)
except:
self.bad_data.append(now_path)
elif os.path.isdir(now_path):
for sub_path in os.listdir(now_path):
self.search_for_data(os.path.join(path, sub_path))
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
self.report_data_log()
def load_image(self, image_path, skip_process=False):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
if skip_process:
return image
image = self.image_process(image)
return image
def load_data(self, data_id):
image_1, image_2, instruction = self.metadata[data_id]
image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
image_2 = self.load_image(image_2)
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
def __getitem__(self, data_id):
if self.random:
while True:
try:
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
data = self.load_data(data_id)
return data
except:
continue
else:
return self.load_data(data_id)
def __len__(self):
return self.steps_per_epoch if self.random else len(self.metadata)
class MultiTaskDataset(torch.utils.data.Dataset):
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
self.dataset_list = dataset_list
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, data_id):
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
data = self.dataset_list[dataset_id][data_id]
return data
def __len__(self):
return self.steps_per_epoch
class NexusGenQwenVLEncoder(torch.nn.Module):
def __init__(self, model_path, torch_dtype="auto", device="cpu"):
super().__init__()
model_config = AutoConfig.from_pretrained(model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
@staticmethod
def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()
def process_images(self, images=None):
if images is None:
return None
# resize input to max_pixels to avoid oom
for j in range(len(images)):
input_image = images[j]
input_w, input_h = input_image.size
resized_height, resized_width = smart_resize(
input_h,
input_w,
max_pixels=262640,
)
images[j] = input_image.resize((resized_width, resized_height))
return images
def forward(self, prompt, images=None, num_img_tokens=81):
messages = [
{
"role": "user",
"content": [{
"type": "text",
"text": prompt
},],
},
{
"role": "assistant",
"content": [{
"type": "text",
"text": self.t2i_template if images is None else self.i2i_template
},],
}
]
images = self.process_images(images)
target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
if images is None:
images = [target_image]
else:
images = images + [target_image]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
inputs = self.processor(
text=[text],
images=images,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
input_image_embeds = image_embeds[:-num_img_tokens]
image_mask = inputs['input_ids'] == self.model.config.image_token_id
indices = image_mask.cumsum(dim=1)
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :] # shift right
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
output_image_embeddings = output_image_embeddings.unsqueeze(0)
return output_image_embeddings
class UnifiedModel(pl.LightningModule):
def __init__(self, flux_text_encoder_path, flux_vae_path, flux_dit_path, flux_decoder_path, qwenvl_path):
super().__init__()
# Load models
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
flux_text_encoder_path,
flux_vae_path,
flux_dit_path
])
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
state_dict = load_state_dict(flux_decoder_path, torch_dtype=torch.bfloat16)
self.pipe.dit.load_state_dict(state_dict, strict=False)
self.adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16)
self.adapter.load_state_dict(state_dict, strict=False)
self.qwenvl = NexusGenQwenVLEncoder.from_pretrained(qwenvl_path)
self.pipe.vae_decoder.requires_grad_(False)
self.pipe.vae_encoder.requires_grad_(False)
self.pipe.text_encoder_1.requires_grad_(False)
self.qwenvl.requires_grad_(False)
self.qwenvl.model.visual.requires_grad_(False)
self.pipe.train()
self.adapter.train()
self.qwenvl.train()
self.qwenvl.model.visual.eval()
# self.qwenvl.model.model.gradient_checkpointing = True
self.pipe.scheduler.set_timesteps(1000, training=True)
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
image = image.unsqueeze(0)
# Prepare input parameters
self.pipe.device = self.device
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
if image_ref is None:
instruction = f"Generate an image according to the following description: {text}"
images_ref = None
else:
instruction = f"<|vision_start|><|image_pad|><|vision_end|> {text}"
images_ref = [image_ref]
emb = self.qwenvl(instruction, images=images_ref)
emb = self.adapter(emb)
prompt_emb = self.pipe.encode_prompt("", positive=True, image_emb=emb)
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
image_emb=emb,
use_gradient_checkpointing=False
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
return loss
def forward(self, batch):
return self.training_step(batch, 0)
def search_for_last_checkpoint(path):
if not os.path.exists(path):
return None, 0
checkpoint_list = os.listdir(path)
checkpoint_list = [int(checkpoint.split("-")[1]) for checkpoint in checkpoint_list if checkpoint.startswith("epoch")]
if len(checkpoint_list) == 0:
return None, 0
else:
max_epoch_id = max(checkpoint_list)
return f"{path}/epoch-{max_epoch_id}/model.safetensors", max_epoch_id + 1
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="gradient_accumulation_steps",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=1000,
help="steps_per_epoch",
)
parser.add_argument(
"--output_path",
type=str,
default="./models",
help="output_path",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="learning_rate",
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = UnifiedModel(
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
"models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin",
"models/DiffSynth-Studio/Nexus-Gen",
)
# dataset and data loader
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
),
],
dataset_weight=(4, 2, 1,),
steps_per_epoch=args.steps_per_epoch * accelerator.num_processes,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=1,
collate_fn=lambda x: x[0]
)
# train
pretrained_path, start_epoch_id = search_for_last_checkpoint(args.output_path)
if pretrained_path is not None:
print(f"pretrained_path: {pretrained_path}")
model.load_state_dict(load_state_dict(pretrained_path, torch_dtype=torch.bfloat16), assign=True, strict=False)
model.to(torch.bfloat16)
model.to(accelerator.device)
trainable_modules = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=args.learning_rate)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
for epoch in range(start_epoch_id, 100000):
for batch in tqdm(train_loader, desc=f"epoch-{epoch}", disable=not accelerator.is_local_main_process):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(batch)
accelerator.backward(loss)
optimizer.step()
path = args.output_path
os.makedirs(path, exist_ok=True)
accelerator.wait_for_everyone()
accelerator.save_model(model, f"{path}/epoch-{epoch}", max_shard_size="10GB", safe_serialization=True)