From b8f05bb3425861ced8bc3777f7a2b998f5f82122 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 28 Jul 2025 11:09:33 +0800 Subject: [PATCH 1/8] tmp commit --- diffsynth/configs/model_config.py | 2 ++ diffsynth/models/flux_dit.py | 5 ++++- .../model_inference/Nexus-Gen-Generation.py | 21 +++++++++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 examples/flux/model_inference/Nexus-Gen-Generation.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index b60c200..9fa652d 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -69,6 +69,7 @@ from ..models.flux_value_control import SingleValueEncoder from ..lora.flux_lora import FluxLoraPatcher from ..models.flux_lora_encoder import FluxLoRAEncoder +from ..models.nexus_gen_projector import NexusGenAdapter model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -152,6 +153,7 @@ model_loader_configs = [ (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), (None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"), + (None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus-gen_adapter"], [FluxDiT, NexusGenAdapter], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index ea5ce21..3dd728d 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -2,7 +2,7 @@ import torch from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm from einops import rearrange from .tiler import TileWorker -from .utils import init_weights_on_device +from .utils import init_weights_on_device, hash_state_dict_keys def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0): batch_size, num_tokens = hidden_states.shape[0:2] @@ -662,6 +662,9 @@ class FluxDiTStateDictConverter: return state_dict_ def from_civitai(self, state_dict): + if hash_state_dict_keys(state_dict, with_shape=True) == "3e6c61b0f9471135fc9c6d6a98e98b6d": + dit_state_dict = {key.replace("pipe.dit.", ""): value for key, value in state_dict.items() if not key.startswith('adapter.')} + return dit_state_dict rename_dict = { "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight", diff --git a/examples/flux/model_inference/Nexus-Gen-Generation.py b/examples/flux/model_inference/Nexus-Gen-Generation.py new file mode 100644 index 0000000..102b7ef --- /dev/null +++ b/examples/flux/model_inference/Nexus-Gen-Generation.py @@ -0,0 +1,21 @@ +import importlib +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==0.49.0, please install it with `pip install transformers==0.49.0`." + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) From 2861ec4d9f9bbec6eacd93984d85b75a3da5be17 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 28 Jul 2025 16:18:38 +0800 Subject: [PATCH 2/8] tmp commit for nexus-gen edit --- diffsynth/configs/model_config.py | 7 +- diffsynth/models/flux_dit.py | 4 +- diffsynth/models/nexus_gen.py | 100 ++ diffsynth/models/nexus_gen_ar_model.py | 1143 +++++++++++++++++ diffsynth/models/nexus_gen_projector.py | 359 ++++++ diffsynth/pipelines/flux_image_new.py | 66 + .../flux/model_inference/Nexus-Gen-Editing.py | 34 + .../model_inference/Nexus-Gen-Generation.py | 14 +- 8 files changed, 1721 insertions(+), 6 deletions(-) create mode 100644 diffsynth/models/nexus_gen.py create mode 100644 diffsynth/models/nexus_gen_ar_model.py create mode 100644 diffsynth/models/nexus_gen_projector.py create mode 100644 examples/flux/model_inference/Nexus-Gen-Editing.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 9fa652d..d6517ba 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -69,7 +69,8 @@ from ..models.flux_value_control import SingleValueEncoder from ..lora.flux_lora import FluxLoraPatcher from ..models.flux_lora_encoder import FluxLoRAEncoder -from ..models.nexus_gen_projector import NexusGenAdapter +from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger +from ..models.nexus_gen import NexusGenAutoregressiveModel model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -153,7 +154,9 @@ model_loader_configs = [ (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), (None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"), - (None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus-gen_adapter"], [FluxDiT, NexusGenAdapter], "civitai"), + (None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus_gen_generation_adapter"], [FluxDiT, NexusGenAdapter], "civitai"), + (None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"), + (None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 3dd728d..0ec9b07 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -662,8 +662,8 @@ class FluxDiTStateDictConverter: return state_dict_ def from_civitai(self, state_dict): - if hash_state_dict_keys(state_dict, with_shape=True) == "3e6c61b0f9471135fc9c6d6a98e98b6d": - dit_state_dict = {key.replace("pipe.dit.", ""): value for key, value in state_dict.items() if not key.startswith('adapter.')} + if hash_state_dict_keys(state_dict, with_shape=True) in ["3e6c61b0f9471135fc9c6d6a98e98b6d", "63c969fd37cce769a90aa781fbff5f81"]: + dit_state_dict = {key.replace("pipe.dit.", ""): value for key, value in state_dict.items() if key.startswith('pipe.dit.')} return dit_state_dict rename_dict = { "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", diff --git a/diffsynth/models/nexus_gen.py b/diffsynth/models/nexus_gen.py new file mode 100644 index 0000000..f7a771e --- /dev/null +++ b/diffsynth/models/nexus_gen.py @@ -0,0 +1,100 @@ +import torch +from PIL import Image +from qwen_vl_utils import smart_resize +from transformers import AutoConfig +from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor + + +class NexusGenAutoregressiveModel(torch.nn.Module): + def __init__(self, model_path="models/DiffSynth-Studio/Nexus-GenV2", max_length=1024, max_pixels=262640, dtype=torch.bfloat16, device="cuda"): + super(NexusGenAutoregressiveModel, self).__init__() + self.max_length = max_length + self.max_pixels = max_pixels + model_config = AutoConfig.from_pretrained(model_path) + self.model = Qwen2_5_VLForConditionalGeneration(model_config) + self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path) + + + @staticmethod + def state_dict_converter(): + return NexusGenAutoregressiveModelStateDictConverter() + + def bound_image(self, image, max_pixels=262640): + resized_height, resized_width = smart_resize( + image.height, + image.width, + max_pixels=max_pixels, + ) + return image.resize((resized_width, resized_height)) + + def get_editing_msg(self, instruction): + if '' not in instruction: + instruction = ' ' + instruction + messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: "}] + return messages + + def get_generation_msg(self, instruction): + messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: "}] + return messages + + def forward(self, instruction, ref_image=None, num_img_tokens=81): + """ + Generate target embeddings for the given instruction and reference image. + """ + if ref_image is not None: + messages = self.get_editing_msg(instruction) + images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))] + output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens) + else: + messages = self.get_generation_msg(instruction) + images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))] + output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens) + + return output_image_embeddings + + def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81): + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + text = text.replace('', '<|vision_start|><|image_pad|><|vision_end|>') + inputs = processor( + text=[text], + images=images, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to(model.device) + + input_embeds = model.model.embed_tokens(inputs['input_ids']) + image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw']) + ground_truth_image_embeds = image_embeds[-num_img_tokens:] + input_image_embeds = image_embeds[:-num_img_tokens] + + image_mask = inputs['input_ids'] == model.config.image_token_id + indices = image_mask.cumsum(dim=1) + input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask) + gt_image_mask = torch.logical_and(image_mask, ~input_image_mask) + input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds) + input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds) + + image_prefill_embeds = model.image_prefill_embeds( + torch.arange(81, device=model.device).long() + ) + input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds) + + position_ids, _ = model.get_rope_index(inputs['input_ids'], + inputs['image_grid_thw'], + attention_mask=inputs['attention_mask']) + position_ids = position_ids.contiguous() + outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True) + output_image_embeddings = outputs.image_embeddings[:, :-1, :] + output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]] + return output_image_embeddings, input_image_embeds, inputs['image_grid_thw'] + + +class NexusGenAutoregressiveModelStateDictConverter: + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict = {"model." + key: value for key, value in state_dict.items()} + return state_dict + \ No newline at end of file diff --git a/diffsynth/models/nexus_gen_ar_model.py b/diffsynth/models/nexus_gen_ar_model.py new file mode 100644 index 0000000..d5a2973 --- /dev/null +++ b/diffsynth/models/nexus_gen_ar_model.py @@ -0,0 +1,1143 @@ +import os +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin, LogitsProcessorList, StoppingCriteriaList, GenerationConfig, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput +from transformers.utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.modeling_outputs import ModelOutput +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + QWEN2_5_VL_INPUTS_DOCSTRING, + ) + +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + image_embeddings: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vision_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + self.image_prefill_embeds = nn.Embedding(81, config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + image_embeddings: Optional[torch.Tensor] = None, + token_loss_weight: Optional[float] = 0.1, + img_loss_weight: Optional[float] = 1.0, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + # test feature + inputs_embeds = self.model.embed_tokens(input_ids) + # for image encoding and training + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + # position_ids [3, B, L] + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + image_embeds = self.vision_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + # prepare labels for logits + logits_labels = labels.clone().detach() + image_tokens = (labels == self.config.image_token_id) + logits_labels[image_tokens] = -100 + + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = logits_labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) * token_loss_weight + + shift_image_tokens_2d = (labels[..., 1:].contiguous() == self.config.image_token_id) # (B, L-1) + shifted_image_embeds = image_embeds[:, :-1, :].contiguous() # (B, L-1, D) + masked_image_embeds = shifted_image_embeds[shift_image_tokens_2d] # (num_image_tokens, D) + + mse_loss_fct = nn.MSELoss() + mse_loss_fct = mse_loss_fct.to(shift_logits.device) + if image_embeddings is None: + image_embeddings = torch.zeros_like(masked_image_embeds) + img_loss = mse_loss_fct(masked_image_embeds, image_embeddings) + + cos_sim = torch.cosine_similarity( + masked_image_embeds, + image_embeddings, + dim=-1 + ) + cos_loss = (1 - cos_sim).mean() + img_loss = 0.5 * img_loss + 0.5 * cos_loss + # fix nan for empty image tokens + if image_embeddings.size(0) == 0: + img_loss = img_loss.nan_to_num(0.0) + # combine the loss + loss = loss + img_loss_weight * img_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + image_embeddings=image_embeds, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + model_forward = self.__call__ + if isinstance(model_kwargs.get("past_key_values"), Cache): + is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache + is_compileable = is_compileable and not self.generation_config.disable_compile + if is_compileable and ( + self.device.type == "cuda" or generation_config.compile_config._compile_all_devices + ): + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = self.get_compiled_call(generation_config.compile_config) + + is_prefill = True + is_sampling_img = input_ids[:, -1] == self.config.vision_start_token_id + generation_image_grid_thw = model_kwargs.pop("generation_image_grid_thw", self.get_default_image_grid_thw()) + num_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) + output_image_embeddings = [] + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare prefilled embeds + model_inputs.update(self.prepare_prefilled_image_embeds(len(output_image_embeddings), num_img_tokens, is_sampling_img, **model_kwargs)) + + # parse position_ids from model_kwargs + model_inputs.update(self.prepare_image_position_ids(input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs)) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if is_prefill: + outputs = self(**model_inputs, return_dict=True) + is_prefill = False + else: + outputs = model_forward(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + # TODO: support batch image sampling + if bool(is_sampling_img) and len(output_image_embeddings) < num_img_tokens: + output_image_embeddings.append(outputs.image_embeddings[:, -1, :].unsqueeze(1)) + + if synced_gpus and this_peer_finished: + continue + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone().float() + next_token_logits = next_token_logits.to(input_ids.device) + + # do not sample token + next_token_logits[:, self.config.vision_end_token_id] = -float('inf') + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + # while not bool(is_sampling_img) and torch.any(next_tokens == self.config.vision_end_token_id): + # probs[:, self.config.vision_end_token_id] = 0 + # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + #TODO: support batch image sample + if num_img_tokens is not None: + cur_img_tokens = (input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1) + # check whether is sampling images + is_end_img = torch.logical_and(cur_img_tokens == num_img_tokens, is_sampling_img) + is_sampling_img = torch.logical_and(is_sampling_img, cur_img_tokens < num_img_tokens) + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to end sampling images + next_tokens[is_end_img] = self.config.vision_end_token_id + else: + # check whether to end sampling images + is_sampling_img = torch.logical_and(is_sampling_img, (next_tokens != self.config.vision_end_token_id)) + # replace the next token with the image token if is sampling image + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to start sampling images + is_sampling_img = torch.logical_or(is_sampling_img, (next_tokens == self.config.vision_start_token_id)) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + if streamer is not None: + streamer.put(next_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + # output the image embeddings + output_image_embeddings = torch.cat(output_image_embeddings, dim=1) if len(output_image_embeddings) > 0 else None + + if return_dict_in_generate: + return GenerateDecoderOnlyAll2AllOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + output_image_embeddings=output_image_embeddings, + ) + else: + return input_ids + + + def prepare_prefilled_image_embeds(self, cur_image_tokens, num_img_tokens, is_sampling_img, **model_kwargs): + if cur_image_tokens == 0 or cur_image_tokens > num_img_tokens or not bool(is_sampling_img): + return {} + # TODO: support batch image sample + image_idx = torch.tensor([cur_image_tokens-1]).to(self.device).long().unsqueeze(0) + inputs_embeds = self.image_prefill_embeds(image_idx) + return {"inputs_embeds": inputs_embeds} + + + def get_default_image_grid_thw(self,): + return torch.tensor([[1, 18, 18]]).to(self.device) + + + def get_num_image_tokens(self, image_grid_thw): + return int(torch.prod(image_grid_thw, dim=1).sum() // 4) + + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + num_img_tokens = model_kwargs.pop("generation_image_grid_thw", None) + super()._validate_model_kwargs(model_kwargs) + model_kwargs["generation_image_grid_thw"] = num_img_tokens + + def prepare_image_position_ids(self, input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs): + # Overwritten -- prepare position_ids for image tokens + cur_img_tokens = int((input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1)) + # TODO: support batch image sample + if cur_img_tokens > 0 and bool(is_sampling_img): + image_grid_thw = generation_image_grid_thw + if model_kwargs.get('image_grid_thw') is not None: + image_grid_thw = torch.cat([model_kwargs.get('image_grid_thw'), image_grid_thw]) + remaining_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) - cur_img_tokens + padding_ids = input_ids.new_full((1, remaining_img_tokens), fill_value=self.config.image_token_id) + padded_ids = torch.cat([input_ids, padding_ids], dim=1) + position_ids, _ = self.get_rope_index(padded_ids, image_grid_thw, None, None) + if model_kwargs.get("use_cache", True): + position_ids = position_ids[:, :, input_ids.shape[1] - 1].unsqueeze(-1) + else: + position_ids = position_ids[:, :, :input_ids.shape[1]] + return {"position_ids": position_ids} + return {} + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + image_embeddings=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-5-VL position_ids are prepared with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] + + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def batch_decode_all2all(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + decoded = self.tokenizer.batch_decode(*args, **kwargs) + pattern = r'<\|vision_start\|>.*?<\|vision_end\|>' + decoded_with_image_tag = [re.sub(pattern, '', d, flags=re.DOTALL) for d in decoded] + decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag] + return decoded_with_image_tag + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + +__all__ = ["Qwen2_5_VLProcessor"] diff --git a/diffsynth/models/nexus_gen_projector.py b/diffsynth/models/nexus_gen_projector.py new file mode 100644 index 0000000..b35ff3f --- /dev/null +++ b/diffsynth/models/nexus_gen_projector.py @@ -0,0 +1,359 @@ +import math +import torch +import torch.nn as nn +from typing import Optional, Tuple +from transformers.activations import ACT2FN +from transformers.modeling_rope_utils import _compute_default_rope_parameters +from transformers import AutoConfig + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = _compute_default_rope_parameters + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NexusGenImageEmbeddingMerger(nn.Module): + def __init__(self, model_path="models/DiffSynth-Studio/Nexus-GenV2", num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'): + super().__init__() + config = AutoConfig.from_pretrained(model_path) + self.config = config + self.num_layers = num_layers + self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)]) + self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps), + nn.Linear(config.hidden_size, out_channel * expand_ratio), + Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps), + ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel), + Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps)) + self.base_grid = torch.tensor([[1, 72, 72]], device=device) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device) + + def get_position_ids(self, image_grid_thw): + """ + Generates position ids for the input embeddings grid. + modified from the qwen2_vl mrope. + """ + batch_size = image_grid_thw.shape[0] + spatial_merge_size = self.config.vision_config.spatial_merge_size + t, h, w = ( + image_grid_thw[0][0], + image_grid_thw[0][1], + image_grid_thw[0][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + scale_h = self.base_grid[0][1].item() / h.item() + scale_w = self.base_grid[0][2].item() / w.item() + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + time_tensor = expanded_range * self.config.vision_config.tokens_per_second + t_index = time_tensor.long().flatten().to(image_grid_thw.device) + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w + # 3, B, L + position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2) + return position_ids + + def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None): + position_ids = self.get_position_ids(embeds_grid) + hidden_states = embeds + if ref_embeds is not None: + position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid) + position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1) + hidden_states = torch.cat((embeds, ref_embeds), dim=1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings) + + hidden_states = self.projector(hidden_states) + return hidden_states + + @staticmethod + def state_dict_converter(): + return NexusGenMergerStateDictConverter() + + +class NexusGenMergerStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')} + return merger_state_dict + + +class NexusGenAdapter(nn.Module): + """ + Adapter for Nexus-Gen generation decoder. + """ + def __init__(self, input_dim=3584, output_dim=4096): + super(NexusGenAdapter, self).__init__() + self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim), + nn.LayerNorm(output_dim), nn.ReLU(), + nn.Linear(output_dim, output_dim), + nn.LayerNorm(output_dim)) + + def forward(self, x): + return self.adapter(x) + + @staticmethod + def state_dict_converter(): + return NexusGenAdapterStateDictConverter() + + +class NexusGenAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')} + return adapter_state_dict diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index 2abd16c..36d7922 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -22,6 +22,8 @@ from ..models.flux_value_control import MultiValueEncoder from ..models.flux_infiniteyou import InfiniteYouImageProjector from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock from ..models.tiler import FastTileWorker +from ..models.nexus_gen import NexusGenAutoregressiveModel +from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser @@ -94,6 +96,9 @@ class FluxImagePipeline(BasePipeline): self.ipadapter_image_encoder = None self.qwenvl = None self.step1x_connector: Qwen2Connector = None + self.nexus_gen: NexusGenAutoregressiveModel = None + self.nexus_gen_generation_adapter: NexusGenAdapter = None + self.nexus_gen_editing_adapter: NexusGenImageEmbeddingMerger = None self.value_controller: MultiValueEncoder = None self.infinityou_processor: InfinitYou = None self.image_proj_model: InfiniteYouImageProjector = None @@ -113,6 +118,7 @@ class FluxImagePipeline(BasePipeline): FluxImageUnit_ControlNet(), FluxImageUnit_IPAdapter(), FluxImageUnit_EntityControl(), + FluxImageUnit_NexusGen(), FluxImageUnit_TeaCache(), FluxImageUnit_Flex(), FluxImageUnit_Step1x(), @@ -397,6 +403,9 @@ class FluxImagePipeline(BasePipeline): pipe.infinityou_processor = InfinitYou(device=device) pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher") pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder") + pipe.nexus_gen = model_manager.fetch_model("nexus_gen_llm") + pipe.nexus_gen_generation_adapter = model_manager.fetch_model("nexus_gen_generation_adapter") + pipe.nexus_gen_editing_adapter = model_manager.fetch_model("nexus_gen_editing_adapter") # ControlNet controlnets = [] @@ -468,6 +477,8 @@ class FluxImagePipeline(BasePipeline): value_controller_inputs: Union[list[float], float] = None, # Step1x step1x_reference_image: Image.Image = None, + # NexusGen + nexus_gen_reference_image: Image.Image = None, # LoRA Encoder lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None, lora_encoder_scale: float = 1.0, @@ -504,6 +515,7 @@ class FluxImagePipeline(BasePipeline): "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop, "value_controller_inputs": value_controller_inputs, "step1x_reference_image": step1x_reference_image, + "nexus_gen_reference_image": nexus_gen_reference_image, "lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale, "tea_cache_l1_thresh": tea_cache_l1_thresh, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, @@ -764,6 +776,60 @@ class FluxImageUnit_EntityControl(PipelineUnit): return inputs_shared, inputs_posi, inputs_nega +class FluxImageUnit_NexusGen(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"), + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if pipe.nexus_gen is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + if inputs_shared.get("nexus_gen_reference_image", None) is None: + assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set." + embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0) + inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed) + inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype) + else: + assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set." + embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"]) + embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long) + ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long) + + inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid) + inputs_posi["text_ids"] = self.get_editing_text_ids( + inputs_shared["latents"], + embeds_grid[0][1].item(), embeds_grid[0][2].item(), + ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(), + ) + return inputs_shared, inputs_posi, inputs_nega + + + def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width): + # prepare text ids for target and reference embeddings + batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width + embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype) + + batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width + ref_embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0 + ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype) + + text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1) + return text_ids + + class FluxImageUnit_Step1x(PipelineUnit): def __init__(self): super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder")) diff --git a/examples/flux/model_inference/Nexus-Gen-Editing.py b/examples/flux/model_inference/Nexus-Gen-Editing.py new file mode 100644 index 0000000..603ac33 --- /dev/null +++ b/examples/flux/model_inference/Nexus-Gen-Editing.py @@ -0,0 +1,34 @@ +import importlib +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import snapshot_download + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." + +snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +prompt = "给猫加一副太阳镜" +ref_image = Image.open("cat.png").convert("RGB") +image = pipe( + prompt=prompt, negative_prompt="", + seed=0, cfg_scale=1.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("cat_glasses.jpg") diff --git a/examples/flux/model_inference/Nexus-Gen-Generation.py b/examples/flux/model_inference/Nexus-Gen-Generation.py index 102b7ef..07ef1d2 100644 --- a/examples/flux/model_inference/Nexus-Gen-Generation.py +++ b/examples/flux/model_inference/Nexus-Gen-Generation.py @@ -1,21 +1,31 @@ import importlib import torch from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import snapshot_download if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") else: import transformers - assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==0.49.0, please install it with `pip install transformers==0.49.0`." + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." +snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ], ) + +prompt = "一只可爱的猫咪" +image = pipe( + prompt=prompt, negative_prompt="", + seed=0, cfg_scale=3, num_inference_steps=50, + height=1024, width=1024, +) +image.save("cat.jpg") From 8ef91b36728947e3563af2925a0a6c861a6309a3 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 29 Jul 2025 13:28:42 +0800 Subject: [PATCH 3/8] support training for eligen and nexusgen --- README.md | 4 +-- README_zh.md | 4 +-- diffsynth/models/nexus_gen.py | 11 +++--- diffsynth/pipelines/flux_image_new.py | 3 +- diffsynth/trainers/utils.py | 9 +++-- .../flux/model_inference/Nexus-Gen-Editing.py | 11 +++--- .../Nexus-Gen-Editing.py | 36 +++++++++++++++++++ .../full/FLUX.1-NexusGen-Edit.sh | 14 ++++++++ .../full/accelerate_config_zero2offload.yaml | 22 ++++++++++++ .../lora/FLUX.1-NexusGen-Edit.sh | 17 +++++++++ .../model_training/lora/FLUX.1-dev-EliGen.sh | 17 +++++++++ .../validate_full/Nexus-Gen-Editing.py | 28 +++++++++++++++ .../validate_lora/FLUX.1-dev-EliGen.py | 33 +++++++++++++++++ .../validate_lora/Nexus-Gen-Editing.py | 26 ++++++++++++++ 14 files changed, 218 insertions(+), 17 deletions(-) create mode 100644 examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py create mode 100644 examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh create mode 100644 examples/flux/model_training/full/accelerate_config_zero2offload.yaml create mode 100644 examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh create mode 100644 examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh create mode 100644 examples/flux/model_training/validate_full/Nexus-Gen-Editing.py create mode 100644 examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py create mode 100644 examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py diff --git a/README.md b/README.md index 11403a5..f592abb 100644 --- a/README.md +++ b/README.md @@ -96,12 +96,12 @@ image.save("image.jpg") |[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| |[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| -|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-||| +|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)| - +|[Nexus-Gen-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py)| diff --git a/README_zh.md b/README_zh.md index 650d2ec..dc1b514 100644 --- a/README_zh.md +++ b/README_zh.md @@ -98,12 +98,12 @@ image.save("image.jpg") |[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| |[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| -|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-||| +|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)| - +|[Nexus-Gen-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py)| ### Wan 系列 diff --git a/diffsynth/models/nexus_gen.py b/diffsynth/models/nexus_gen.py index f7a771e..31475c7 100644 --- a/diffsynth/models/nexus_gen.py +++ b/diffsynth/models/nexus_gen.py @@ -14,7 +14,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module): self.model = Qwen2_5_VLForConditionalGeneration(model_config) self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path) - + @staticmethod def state_dict_converter(): return NexusGenAutoregressiveModelStateDictConverter() @@ -34,6 +34,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module): return messages def get_generation_msg(self, instruction): + instruction = "Generate an image according to the following description: {}".format(instruction) messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: "}] return messages @@ -80,9 +81,10 @@ class NexusGenAutoregressiveModel(torch.nn.Module): ) input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds) - position_ids, _ = model.get_rope_index(inputs['input_ids'], - inputs['image_grid_thw'], - attention_mask=inputs['attention_mask']) + position_ids, _ = model.get_rope_index( + inputs['input_ids'], + inputs['image_grid_thw'], + attention_mask=inputs['attention_mask']) position_ids = position_ids.contiguous() outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True) output_image_embeddings = outputs.image_embeddings[:, :-1, :] @@ -97,4 +99,3 @@ class NexusGenAutoregressiveModelStateDictConverter: def from_civitai(self, state_dict): state_dict = {"model." + key: value for key, value in state_dict.items()} return state_dict - \ No newline at end of file diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index 36d7922..8f9ec61 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -767,9 +767,10 @@ class FluxImageUnit_EntityControl(PipelineUnit): if eligen_entity_prompts is None or eligen_entity_masks is None: return inputs_shared, inputs_posi, inputs_nega pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], - inputs_shared["t5_sequence_length"], inputs_shared["eligen_enable_on_negative"], inputs_shared["cfg_scale"]) + inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"]) inputs_posi.update(eligen_kwargs_posi) if inputs_shared.get("cfg_scale", 1.0) != 1.0: inputs_nega.update(eligen_kwargs_nega) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index b171857..07e3664 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -120,8 +120,13 @@ class ImageDataset(torch.utils.data.Dataset): data = self.data[data_id % len(self.data)].copy() for key in self.data_file_keys: if key in data: - path = os.path.join(self.base_path, data[key]) - data[key] = self.load_data(path) + if isinstance(data[key], list): + print(f"Loading multiple files for key '{key}'.") + path = [os.path.join(self.base_path, p) for p in data[key]] + data[key] = [self.load_data(p) for p in path] + else: + path = os.path.join(self.base_path, data[key]) + data[key] = self.load_data(path) if data[key] is None: warnings.warn(f"cannot load file {data[key]}.") return None diff --git a/examples/flux/model_inference/Nexus-Gen-Editing.py b/examples/flux/model_inference/Nexus-Gen-Editing.py index 603ac33..f24f0c0 100644 --- a/examples/flux/model_inference/Nexus-Gen-Editing.py +++ b/examples/flux/model_inference/Nexus-Gen-Editing.py @@ -2,7 +2,7 @@ import importlib import torch from PIL import Image from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig -from modelscope import snapshot_download +from modelscope import snapshot_download, dataset_snapshot_download if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") @@ -23,12 +23,13 @@ pipe = FluxImagePipeline.from_pretrained( ], ) -prompt = "给猫加一副太阳镜" -ref_image = Image.open("cat.png").convert("RGB") +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg") +ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB") +prompt = "Add a crown." image = pipe( prompt=prompt, negative_prompt="", - seed=0, cfg_scale=1.0, num_inference_steps=50, + seed=42, cfg_scale=2.0, num_inference_steps=50, nexus_gen_reference_image=ref_image, height=512, width=512, ) -image.save("cat_glasses.jpg") +image.save("cat_crown.jpg") diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py new file mode 100644 index 0000000..70a543f --- /dev/null +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py @@ -0,0 +1,36 @@ +import importlib +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import snapshot_download, dataset_snapshot_download + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." + +snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg") +ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB") +prompt = "Add a crown." +image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=2.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("cat_crown.jpg") diff --git a/examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh b/examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh new file mode 100644 index 0000000..ab1c324 --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \ + --data_file_keys "image,nexus_gen_reference_image" \ + --max_pixels 262144 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-NexusGen-Edit_full" \ + --trainable_models "dit" \ + --extra_inputs "nexus_gen_reference_image" \ + --use_gradient_checkpointing_offload diff --git a/examples/flux/model_training/full/accelerate_config_zero2offload.yaml b/examples/flux/model_training/full/accelerate_config_zero2offload.yaml new file mode 100644 index 0000000..8a75f3d --- /dev/null +++ b/examples/flux/model_training/full/accelerate_config_zero2offload.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: 'cpu' + offload_param_device: 'cpu' + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh b/examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh new file mode 100644 index 0000000..3e6eac1 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \ + --data_file_keys "image,nexus_gen_reference_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-NexusGen-Edit_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --extra_inputs "nexus_gen_reference_image" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh b/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh new file mode 100644 index 0000000..10a18e0 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_eligen.json \ + --data_file_keys "image,eligen_entity_masks" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-dev-EliGen_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --extra_inputs "eligen_entity_masks,eligen_entity_prompts" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/validate_full/Nexus-Gen-Editing.py b/examples/flux/model_training/validate_full/Nexus-Gen-Editing.py new file mode 100644 index 0000000..5f7a2d2 --- /dev/null +++ b/examples/flux/model_training/validate_full/Nexus-Gen-Editing.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-NexusGen-Edit_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +ref_image = Image.open("data/example_image_dataset/nexus_gen/image_1.png").convert("RGB") +prompt = "Add a pair of sunglasses." +image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=2.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("NexusGen-Edit_full.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py new file mode 100644 index 0000000..7df3db2 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-EliGen_lora/epoch-4.safetensors", alpha=1) + +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))] +# generate image +image = pipe( + prompt=global_prompt, + cfg_scale=1.0, + num_inference_steps=50, + embedded_guidance=3.5, + seed=42, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, +) +image.save(f"EliGen_lora.png") diff --git a/examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py b/examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py new file mode 100644 index 0000000..21c376f --- /dev/null +++ b/examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py @@ -0,0 +1,26 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-NexusGen-Edit_lora/epoch-4.safetensors", alpha=1) + +ref_image = Image.open("data/example_image_dataset/nexus_gen/image_1.png").convert("RGB") +prompt = "Add a pair of sunglasses." +image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=1.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("NexusGen-Edit_lora.jpg") From 7df48fc2b56ae6b84fb70fb19e938e8995606030 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 29 Jul 2025 13:33:14 +0800 Subject: [PATCH 4/8] remove debug out --- diffsynth/trainers/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 07e3664..8e51f18 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -121,7 +121,6 @@ class ImageDataset(torch.utils.data.Dataset): for key in self.data_file_keys: if key in data: if isinstance(data[key], list): - print(f"Loading multiple files for key '{key}'.") path = [os.path.join(self.base_path, p) for p in data[key]] data[key] = [self.load_data(p) for p in path] else: From 9c51623fc2b653b01cee0ec175a98c719fe6bd5d Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 29 Jul 2025 18:47:16 +0800 Subject: [PATCH 5/8] refine code --- README.md | 3 +- README_zh.md | 3 +- diffsynth/models/nexus_gen.py | 72 +++++++++++++++++-- diffsynth/models/nexus_gen_projector.py | 9 ++- diffsynth/pipelines/flux_image_new.py | 4 ++ examples/flux/README.md | 5 +- examples/flux/README_zh.md | 3 +- .../flux/model_inference/Nexus-Gen-Editing.py | 6 +- .../model_inference/Nexus-Gen-Generation.py | 5 +- .../Nexus-Gen-Generation.py | 32 +++++++++ .../{FLUX.1-NexusGen-Edit.sh => Nexus-Gen.sh} | 0 .../{FLUX.1-NexusGen-Edit.sh => Nexus-Gen.sh} | 0 .../{Nexus-Gen-Editing.py => Nexus-Gen.py} | 0 .../{Nexus-Gen-Editing.py => Nexus-Gen.py} | 0 14 files changed, 124 insertions(+), 18 deletions(-) create mode 100644 examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py rename examples/flux/model_training/full/{FLUX.1-NexusGen-Edit.sh => Nexus-Gen.sh} (100%) rename examples/flux/model_training/lora/{FLUX.1-NexusGen-Edit.sh => Nexus-Gen.sh} (100%) rename examples/flux/model_training/validate_full/{Nexus-Gen-Editing.py => Nexus-Gen.py} (100%) rename examples/flux/model_training/validate_lora/{Nexus-Gen-Editing.py => Nexus-Gen.py} (100%) diff --git a/README.md b/README.md index f592abb..dfea6d1 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,8 @@ image.save("image.jpg") |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)| -|[Nexus-Gen-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py)| +|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)| + diff --git a/README_zh.md b/README_zh.md index dc1b514..2aae18a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -103,7 +103,8 @@ image.save("image.jpg") |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)| -|[Nexus-Gen-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py)| +|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)| + ### Wan 系列 diff --git a/diffsynth/models/nexus_gen.py b/diffsynth/models/nexus_gen.py index 31475c7..0110398 100644 --- a/diffsynth/models/nexus_gen.py +++ b/diffsynth/models/nexus_gen.py @@ -1,18 +1,77 @@ import torch from PIL import Image -from qwen_vl_utils import smart_resize -from transformers import AutoConfig -from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor class NexusGenAutoregressiveModel(torch.nn.Module): - def __init__(self, model_path="models/DiffSynth-Studio/Nexus-GenV2", max_length=1024, max_pixels=262640, dtype=torch.bfloat16, device="cuda"): + def __init__(self, max_length=1024, max_pixels=262640): super(NexusGenAutoregressiveModel, self).__init__() + from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration + from transformers import Qwen2_5_VLConfig self.max_length = max_length self.max_pixels = max_pixels - model_config = AutoConfig.from_pretrained(model_path) + model_config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) self.model = Qwen2_5_VLForConditionalGeneration(model_config) - self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path) + self.processor = None + + + def load_processor(self, path): + from .nexus_gen_ar_model import Qwen2_5_VLProcessor + self.processor = Qwen2_5_VLProcessor.from_pretrained(path) @staticmethod @@ -20,6 +79,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module): return NexusGenAutoregressiveModelStateDictConverter() def bound_image(self, image, max_pixels=262640): + from qwen_vl_utils import smart_resize resized_height, resized_width = smart_resize( image.height, image.width, diff --git a/diffsynth/models/nexus_gen_projector.py b/diffsynth/models/nexus_gen_projector.py index b35ff3f..0adbafb 100644 --- a/diffsynth/models/nexus_gen_projector.py +++ b/diffsynth/models/nexus_gen_projector.py @@ -2,9 +2,8 @@ import math import torch import torch.nn as nn from typing import Optional, Tuple -from transformers.activations import ACT2FN -from transformers.modeling_rope_utils import _compute_default_rope_parameters -from transformers import AutoConfig + + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -39,6 +38,7 @@ class Qwen2_5_VLRotaryEmbedding(nn.Module): self.original_max_seq_len = config.max_position_embeddings self.config = config + from transformers.modeling_rope_utils import _compute_default_rope_parameters self.rope_init_fn = _compute_default_rope_parameters inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) @@ -181,6 +181,7 @@ class Qwen2_5_VLAttention(nn.Module): class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() + from transformers.activations import ACT2FN self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -254,6 +255,8 @@ class Qwen2_5_VLDecoderLayer(nn.Module): class NexusGenImageEmbeddingMerger(nn.Module): def __init__(self, model_path="models/DiffSynth-Studio/Nexus-GenV2", num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'): super().__init__() + from transformers import AutoConfig + from transformers.activations import ACT2FN config = AutoConfig.from_pretrained(model_path) self.config = config self.num_layers = num_layers diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index 8f9ec61..b750509 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -375,6 +375,7 @@ class FluxImagePipeline(BasePipeline): torch_dtype: torch.dtype = torch.bfloat16, device: Union[str, torch.device] = "cuda", model_configs: list[ModelConfig] = [], + nexus_gen_processor_config: ModelConfig = None, ): # Download and load models model_manager = ModelManager() @@ -406,6 +407,9 @@ class FluxImagePipeline(BasePipeline): pipe.nexus_gen = model_manager.fetch_model("nexus_gen_llm") pipe.nexus_gen_generation_adapter = model_manager.fetch_model("nexus_gen_generation_adapter") pipe.nexus_gen_editing_adapter = model_manager.fetch_model("nexus_gen_editing_adapter") + if nexus_gen_processor_config is not None and pipe.nexus_gen is not None: + nexus_gen_processor_config.download_if_necessary() + pipe.nexus_gen.load_processor(nexus_gen_processor_config.path) # ControlNet controlnets = [] diff --git a/examples/flux/README.md b/examples/flux/README.md index a66e2bc..4ef0947 100644 --- a/examples/flux/README.md +++ b/examples/flux/README.md @@ -43,18 +43,19 @@ image.save("image.jpg") |Model ID|Extra Args|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training| |-|-|-|-|-|-|-|-| -|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev )||[code](./model_inference/FLUX.1-dev.py)|[code](./model_inference_low_vram/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)| +|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_inference_low_vram/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)| |[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)| |[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| |[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| |[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| |[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| -|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-||| +|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./model_training/validate_lora/FLUX.1-dev-EliGen.py)| |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./model_inference/Step1X-Edit.py)|[code](./model_inference_low_vram/Step1X-Edit.py)|[code](./model_training/full/Step1X-Edit.sh)|[code](./model_training/validate_full/Step1X-Edit.py)|[code](./model_training/lora/Step1X-Edit.sh)|[code](./model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./model_inference/FLEX.2-preview.py)|[code](./model_inference_low_vram/FLEX.2-preview.py)|[code](./model_training/full/FLEX.2-preview.sh)|[code](./model_training/validate_full/FLEX.2-preview.py)|[code](./model_training/lora/FLEX.2-preview.sh)|[code](./model_training/validate_lora/FLEX.2-preview.py)| +|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./model_inference/Nexus-Gen-Editing.py)|[code](./model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./model_training/full/Nexus-Gen.sh)|[code](./model_training/validate_full/Nexus-Gen.py)|[code](./model_training/lora/Nexus-Gen.sh)|[code](./model_training/validate_lora/Nexus-Gen.py)| ## Model Inference diff --git a/examples/flux/README_zh.md b/examples/flux/README_zh.md index 3d3dc35..2e7b645 100644 --- a/examples/flux/README_zh.md +++ b/examples/flux/README_zh.md @@ -50,11 +50,12 @@ image.save("image.jpg") |[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| |[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| -|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-||| +|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./model_training/validate_lora/FLUX.1-dev-EliGen.py)| |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./model_inference/Step1X-Edit.py)|[code](./model_inference_low_vram/Step1X-Edit.py)|[code](./model_training/full/Step1X-Edit.sh)|[code](./model_training/validate_full/Step1X-Edit.py)|[code](./model_training/lora/Step1X-Edit.sh)|[code](./model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./model_inference/FLEX.2-preview.py)|[code](./model_inference_low_vram/FLEX.2-preview.py)|[code](./model_training/full/FLEX.2-preview.sh)|[code](./model_training/validate_full/FLEX.2-preview.py)|[code](./model_training/lora/FLEX.2-preview.sh)|[code](./model_training/validate_lora/FLEX.2-preview.py)| +|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./model_inference/Nexus-Gen-Editing.py)|[code](./model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./model_training/full/Nexus-Gen.sh)|[code](./model_training/validate_full/Nexus-Gen.py)|[code](./model_training/lora/Nexus-Gen.sh)|[code](./model_training/validate_lora/Nexus-Gen.py)| ## 模型推理 diff --git a/examples/flux/model_inference/Nexus-Gen-Editing.py b/examples/flux/model_inference/Nexus-Gen-Editing.py index f24f0c0..c9ab88c 100644 --- a/examples/flux/model_inference/Nexus-Gen-Editing.py +++ b/examples/flux/model_inference/Nexus-Gen-Editing.py @@ -2,7 +2,8 @@ import importlib import torch from PIL import Image from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig -from modelscope import snapshot_download, dataset_snapshot_download +from modelscope import dataset_snapshot_download + if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") @@ -10,7 +11,7 @@ else: import transformers assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." -snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") + pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -21,6 +22,7 @@ pipe = FluxImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ], + nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), ) dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg") diff --git a/examples/flux/model_inference/Nexus-Gen-Generation.py b/examples/flux/model_inference/Nexus-Gen-Generation.py index 07ef1d2..dfe6880 100644 --- a/examples/flux/model_inference/Nexus-Gen-Generation.py +++ b/examples/flux/model_inference/Nexus-Gen-Generation.py @@ -1,7 +1,7 @@ import importlib import torch from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig -from modelscope import snapshot_download + if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") @@ -9,7 +9,7 @@ else: import transformers assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." -snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") + pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -20,6 +20,7 @@ pipe = FluxImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ], + nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), ) prompt = "一只可爱的猫咪" diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py new file mode 100644 index 0000000..053b22b --- /dev/null +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py @@ -0,0 +1,32 @@ +import importlib +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import snapshot_download + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." + +snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.enable_vram_management() + +prompt = "一只可爱的猫咪" +image = pipe( + prompt=prompt, negative_prompt="", + seed=0, cfg_scale=3, num_inference_steps=50, + height=1024, width=1024, +) +image.save("cat.jpg") diff --git a/examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh b/examples/flux/model_training/full/Nexus-Gen.sh similarity index 100% rename from examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh rename to examples/flux/model_training/full/Nexus-Gen.sh diff --git a/examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh b/examples/flux/model_training/lora/Nexus-Gen.sh similarity index 100% rename from examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh rename to examples/flux/model_training/lora/Nexus-Gen.sh diff --git a/examples/flux/model_training/validate_full/Nexus-Gen-Editing.py b/examples/flux/model_training/validate_full/Nexus-Gen.py similarity index 100% rename from examples/flux/model_training/validate_full/Nexus-Gen-Editing.py rename to examples/flux/model_training/validate_full/Nexus-Gen.py diff --git a/examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py b/examples/flux/model_training/validate_lora/Nexus-Gen.py similarity index 100% rename from examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py rename to examples/flux/model_training/validate_lora/Nexus-Gen.py From 03c8fd5e61dbeba7d1788f9388c4a90ae04e8278 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 29 Jul 2025 18:49:18 +0800 Subject: [PATCH 6/8] refine code --- .../Nexus-Gen-Editing.py | 16 +++++++++------- .../Nexus-Gen-Generation.py | 15 ++++++++------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py index 70a543f..313ce3c 100644 --- a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py @@ -2,7 +2,8 @@ import importlib import torch from PIL import Image from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig -from modelscope import snapshot_download, dataset_snapshot_download +from modelscope import dataset_snapshot_download + if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") @@ -10,17 +11,18 @@ else: import transformers assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." -snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") + pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"), ], + nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), ) pipe.enable_vram_management() diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py index 053b22b..c865271 100644 --- a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py @@ -1,7 +1,7 @@ import importlib import torch from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig -from modelscope import snapshot_download + if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") @@ -9,17 +9,18 @@ else: import transformers assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." -snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") + pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"), ], + nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), ) pipe.enable_vram_management() From 87ab7d020b4f5e8acfe1a60dede2ec7bf4e9dba4 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 29 Jul 2025 20:02:34 +0800 Subject: [PATCH 7/8] refine code --- diffsynth/models/nexus_gen_projector.py | 61 ++++++++++++++++++- diffsynth/pipelines/flux_image_new.py | 2 +- .../flux/model_inference/Nexus-Gen-Editing.py | 2 +- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/diffsynth/models/nexus_gen_projector.py b/diffsynth/models/nexus_gen_projector.py index 0adbafb..d69b3e1 100644 --- a/diffsynth/models/nexus_gen_projector.py +++ b/diffsynth/models/nexus_gen_projector.py @@ -253,11 +253,66 @@ class Qwen2_5_VLDecoderLayer(nn.Module): class NexusGenImageEmbeddingMerger(nn.Module): - def __init__(self, model_path="models/DiffSynth-Studio/Nexus-GenV2", num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'): + def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'): super().__init__() - from transformers import AutoConfig + from transformers import Qwen2_5_VLConfig from transformers.activations import ACT2FN - config = AutoConfig.from_pretrained(model_path) + config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) self.config = config self.num_layers = num_layers self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)]) diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index b750509..9384624 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -375,7 +375,7 @@ class FluxImagePipeline(BasePipeline): torch_dtype: torch.dtype = torch.bfloat16, device: Union[str, torch.device] = "cuda", model_configs: list[ModelConfig] = [], - nexus_gen_processor_config: ModelConfig = None, + nexus_gen_processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), ): # Download and load models model_manager = ModelManager() diff --git a/examples/flux/model_inference/Nexus-Gen-Editing.py b/examples/flux/model_inference/Nexus-Gen-Editing.py index c9ab88c..10351d5 100644 --- a/examples/flux/model_inference/Nexus-Gen-Editing.py +++ b/examples/flux/model_inference/Nexus-Gen-Editing.py @@ -22,7 +22,7 @@ pipe = FluxImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ], - nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), + nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), ) dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg") From 2ed3860085fdf629e8874444c812e18524b21e38 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 29 Jul 2025 20:10:08 +0800 Subject: [PATCH 8/8] refine code --- examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py | 2 +- examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py index 313ce3c..7dd3193 100644 --- a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py @@ -22,7 +22,7 @@ pipe = FluxImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"), ], - nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), + nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), ) pipe.enable_vram_management() diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py index c865271..25feb23 100644 --- a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py @@ -20,7 +20,7 @@ pipe = FluxImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"), ], - nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), + nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), ) pipe.enable_vram_management()