import torch
from PIL import Image


class NexusGenAutoregressiveModel(torch.nn.Module):
    def __init__(self, max_length=1024, max_pixels=262640):
        super(NexusGenAutoregressiveModel, self).__init__()
        from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
        from transformers import Qwen2_5_VLConfig
        self.max_length = max_length
        self.max_pixels = max_pixels
        # Backbone configuration; the sizes match the Qwen2.5-VL-7B architecture.
        model_config = Qwen2_5_VLConfig(**{
            "_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
            "architectures": ["Qwen2_5_VLForConditionalGeneration"],
            "attention_dropout": 0.0,
            "auto_map": {
                "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
                "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
                "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration",
            },
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "hidden_act": "silu",
            "hidden_size": 3584,
            "image_token_id": 151655,
            "initializer_range": 0.02,
            "intermediate_size": 18944,
            "max_position_embeddings": 128000,
            "max_window_layers": 28,
            "model_type": "qwen2_5_vl",
            "num_attention_heads": 28,
            "num_hidden_layers": 28,
            "num_key_value_heads": 4,
            "pad_token_id": 151643,
            "rms_norm_eps": 1e-06,
            "rope_scaling": {
                "mrope_section": [16, 24, 24],
                "rope_type": "default",
                "type": "default",
            },
            "rope_theta": 1000000.0,
            "sliding_window": 32768,
            "tie_word_embeddings": False,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.49.0",
            "use_cache": False,
            "use_sliding_window": False,
            "video_token_id": 151656,
            "vision_config": {
                "hidden_size": 1280,
                "in_chans": 3,
                "model_type": "qwen2_5_vl",
                "spatial_patch_size": 14,
                "tokens_per_second": 2,
                "torch_dtype": "bfloat16",
            },
            "vision_end_token_id": 151653,
            "vision_start_token_id": 151652,
            "vision_token_id": 151654,
            "vocab_size": 152064,
        })
        self.model = Qwen2_5_VLForConditionalGeneration(model_config)
        self.processor = None

    def load_processor(self, path):
        from .nexus_gen_ar_model import Qwen2_5_VLProcessor
        self.processor = Qwen2_5_VLProcessor.from_pretrained(path)

    @staticmethod
    def state_dict_converter():
        return NexusGenAutoregressiveModelStateDictConverter()
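    # Note (an assumption about the surrounding DiffSynth-Studio loader, not
    # guaranteed by this file): the framework is expected to call
    # state_dict_converter().from_civitai(...) to remap raw checkpoint keys
    # before they are loaded into this wrapper.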

    def bound_image(self, image, max_pixels=262640):
        from qwen_vl_utils import smart_resize
        resized_height, resized_width = smart_resize(
            image.height,
            image.width,
            max_pixels=max_pixels,
        )
        return image.resize((resized_width, resized_height))
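    # Sizing note (added explanation, not original code): smart_resize rounds
    # height/width to multiples of 28 (14-pixel patches merged 2x2) while keeping
    # the area at or below max_pixels. For example, a 1024x1024 input is bounded
    # to 504x504 under the default 262640-pixel budget, and a 252x252 image yields
    # exactly (252 // 28) ** 2 = 81 vision tokens.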

    def get_editing_msg(self, instruction):
        if '<image>' not in instruction:
            instruction = '<image> ' + instruction
        messages = [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": "Here is the image: <image>"},
        ]
        return messages

    def get_generation_msg(self, instruction):
        instruction = "Generate an image according to the following description: {}".format(instruction)
        messages = [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": "Here is an image based on the description: <image>"},
        ]
        return messages
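    # For orientation (assumes the standard Qwen2.5-VL chat template; the exact
    # rendering depends on the processor): the messages above render roughly as
    #   <|im_start|>user\n{instruction}<|im_end|>\n
    #   <|im_start|>assistant\n...: <image><|im_end|>\n
    # possibly preceded by a default system turn. get_target_embeddings then
    # replaces each '<image>' with '<|vision_start|><|image_pad|><|vision_end|>'.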

    def forward(self, instruction, ref_image=None, num_img_tokens=81):
        """
        Generate target embeddings for the given instruction and optional reference image.
        """
        if ref_image is not None:
            # Editing: the bounded reference image plus a 252x252 white placeholder
            # whose 81 vision tokens stand in for the image to be generated.
            messages = self.get_editing_msg(instruction)
            images = [self.bound_image(ref_image, self.max_pixels), Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
            output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
        else:
            # Text-to-image: only the white placeholder is passed to the vision tower.
            messages = self.get_generation_msg(instruction)
            images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
            output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
        return output_image_embeddings
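    # A minimal usage sketch (hypothetical checkpoint path; assumes the weights
    # and processor have been loaded, e.g. from "DiffSynth-Studio/Nexus-GenV2"):
    #
    #   model = NexusGenAutoregressiveModel().eval()
    #   model.load_processor("models/Nexus-GenV2")
    #   with torch.no_grad():
    #       target_embeds, ref_embeds, grid_thw = model("a cat wearing a tiny hat")
    #
    # Despite the singular name, forward returns the full 3-tuple produced by
    # get_target_embeddings; target_embeds is what conditions the downstream
    # image decoder.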

    def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
        inputs = processor(
            text=[text],
            images=images,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)

        input_embeds = model.model.embed_tokens(inputs['input_ids'])
        image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
        # The last num_img_tokens rows come from the white placeholder image and
        # correspond to the image being predicted; everything before them belongs
        # to the reference image(s).
        ground_truth_image_embeds = image_embeds[-num_img_tokens:]
        input_image_embeds = image_embeds[:-num_img_tokens]

        # Split the <|image_pad|> positions into reference slots and target slots
        # by counting image tokens from the left.
        image_mask = inputs['input_ids'] == model.config.image_token_id
        indices = image_mask.cumsum(dim=1)
        input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
        gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
        input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
        input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)

        # Target slots are filled with learned prefill embeddings (one per slot
        # index) instead of the placeholder's visual features.
        image_prefill_embeds = model.image_prefill_embeds(
            torch.arange(num_img_tokens, device=model.device).long()
        )
        input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)

        position_ids, _ = model.get_rope_index(
            inputs['input_ids'],
            inputs['image_grid_thw'],
            attention_mask=inputs['attention_mask'])
        position_ids = position_ids.contiguous()
        outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
        # Shift by one so each target slot's embedding is predicted from the
        # previous position, then keep only the target-slot predictions.
        output_image_embeddings = outputs.image_embeddings[:, :-1, :]
        output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
        return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
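    # Worked example of the bookkeeping above (assumed sizes, editing path): a
    # 504x504 reference image contributes (504 // 28) ** 2 = 324 vision tokens and
    # the 252x252 placeholder contributes 81, so image_embeds has 405 rows. The
    # first 324 are scattered into the prompt as reference features, while the
    # last 81 slots are prefilled and their embeddings predicted autoregressively.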


class NexusGenAutoregressiveModelStateDictConverter:
    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        # Prefix every checkpoint key with "model." so the keys line up with the
        # wrapper's `self.model` submodule.
        state_dict = {"model." + key: value for key, value in state_dict.items()}
        return state_dict
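# A minimal sketch of the converter in isolation (tensor value is illustrative):
#
#   converter = NexusGenAutoregressiveModel.state_dict_converter()
#   remapped = converter.from_civitai({"lm_head.weight": torch.zeros(152064, 3584)})
#   assert "model.lm_head.weight" in remapped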