import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers.cache_utils import Cache
from transformers.generation import (
    GenerationMixin,
    LogitsProcessorList,
    StoppingCriteriaList,
    GenerationConfig,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)
from transformers.utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.modeling_outputs import ModelOutput
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
    Qwen2_5_VLModel,
    Qwen2_5_VLPreTrainedModel,
    QWEN2_5_VL_INPUTS_DOCSTRING,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput


GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Qwen2_5_VLConfig"


@dataclass
class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction), combined with the weighted image-embedding
            regression loss computed in `Qwen2_5_VLForConditionalGeneration.forward`.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.hidden_size)`):
            Output of the `vision_head` projection applied to the last hidden state at every position.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    image_embeddings: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None


class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
    """Qwen2.5-VL causal LM extended with a `vision_head` that regresses an image embedding per position.

    During generation (`_sample`), the embeddings predicted at image-token positions are collected and returned
    as `output_image_embeddings`.
    """

    _tied_weights_keys = ["lm_head.weight"]
    config_class = Qwen2_5_VLConfig
    _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.model = Qwen2_5_VLModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Projects each hidden state to a predicted image embedding; trained by the image loss in `forward`.
        self.vision_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.rope_deltas = None  # cache rope_deltas here
        # One learned input embedding per generated image-token slot. 81 == 1 * 18 * 18 // 4, i.e. the token count
        # of the default generation grid returned by `get_default_image_grid_thw` (with spatial merge size 2).
        self.image_prefill_embeds = nn.Embedding(81, config.hidden_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

        Explanation:
            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

            For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
            Examples:
                input_ids: [T T T T T], here T is for text.
                temporal position_ids: [0, 1, 2, 3, 4]
                height position_ids: [0, 1, 2, 3, 4]
                width position_ids: [0, 1, 2, 3, 4]

            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
            and 1D rotary position embedding for text part.
            Examples:
                Temporal (Time): 3 patches, representing different segments of the video in time.
                Height: 2 patches, dividing each frame vertically.
                Width: 2 patches, dividing each frame horizontally.
                We also have some important parameters:
                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
                tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
                temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
                text temporal position_ids: [101, 102, 103, 104, 105]
                text height position_ids: [101, 102, 103, 104, 105]
                text width position_ids: [101, 102, 103, 104, 105]
                Here we calculate the text start position_ids as the max vision position_ids plus 1.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
                it.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            # Image / video grids are consumed in order across the whole batch.
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    # Text preceding this vision block gets 1D positions (identical on all three axes).
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)

                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second

                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    # Skip past the merged vision tokens of this block.
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                # Trailing text after the last vision block.
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        image_embeddings: Optional[torch.Tensor] = None,
        token_loss_weight: Optional[float] = 0.1,
        img_loss_weight: Optional[float] = 1.0,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            Positions labelled with `config.image_token_id` are excluded from the cross-entropy loss and instead
            supervise the image-embedding regression loss.
        image_embeddings (`torch.Tensor` of shape `(num_image_tokens, config.hidden_size)`, *optional*):
            Regression targets for the `vision_head` outputs at the image-token positions of `labels`, in
            row-major (batch, position) order. When omitted while `labels` is given, an all-zeros target is used.
        token_loss_weight (`float`, *optional*, defaults to 0.1):
            Scale applied to the cross-entropy text loss.
        img_loss_weight (`float`, *optional*, defaults to 1.0):
            Scale applied to the image loss `0.5 * MSE + 0.5 * (1 - cosine_similarity)`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

        >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # test feature
            inputs_embeds = self.model.embed_tokens(input_ids)
            # for image encoding and training
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )

                mask = input_ids == self.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)

                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                mask = input_ids == self.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)

                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
        # position_ids [3, B, L]

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        image_embeds = self.vision_head(hidden_states)

        loss = None
        if labels is not None:
            # Text loss: image-token positions are excluded from cross-entropy (label -100).
            logits_labels = labels.clone().detach()
            image_tokens = labels == self.config.image_token_id
            logits_labels[image_tokens] = -100

            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = logits_labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels) * token_loss_weight

            # Image loss: the vision-head output at position n regresses the embedding of image token n + 1.
            shift_image_tokens_2d = labels[..., 1:].contiguous() == self.config.image_token_id  # (B, L-1)
            shifted_image_embeds = image_embeds[:, :-1, :].contiguous()  # (B, L-1, D)
            masked_image_embeds = shifted_image_embeds[shift_image_tokens_2d]  # (num_image_tokens, D)

            if image_embeddings is None:
                image_embeddings = torch.zeros_like(masked_image_embeds)
            # Targets may live on another device / dtype than the predictions (e.g. fp32 features vs. a bf16
            # model); MSELoss and cosine_similarity require them to match.
            image_embeddings = image_embeddings.to(masked_image_embeds.device, masked_image_embeds.dtype)

            mse_loss_fct = nn.MSELoss()
            img_loss = mse_loss_fct(masked_image_embeds, image_embeddings)
            cos_sim = torch.cosine_similarity(masked_image_embeds, image_embeddings, dim=-1)
            cos_loss = (1 - cos_sim).mean()
            img_loss = 0.5 * img_loss + 0.5 * cos_loss
            # Both terms are means over zero elements (NaN) when the batch has no image tokens.
            if masked_image_embeds.size(0) == 0:
                img_loss = img_loss.nan_to_num(0.0)

            # combine the loss
            loss = loss + img_loss_weight * img_loss.to(loss.device)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5_VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            image_embeddings=image_embeds,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        In addition to text, this override generates images: once the last token is `vision_start_token_id`, the
        next `get_num_image_tokens(generation_image_grid_thw)` tokens are forced to `image_token_id`, followed by
        `vision_end_token_id`. The `vision_head` outputs at those steps are collected and returned as
        `output_image_embeddings`.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`. May include
                `generation_image_grid_thw` (a `(num_images, 3)` grid); when absent or `None`,
                `get_default_image_grid_thw()` is used.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        model_forward = self.__call__
        if isinstance(model_kwargs.get("past_key_values"), Cache):
            is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
            is_compileable = is_compileable and not self.generation_config.disable_compile
            if is_compileable and (
                self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
            ):
                os.environ["TOKENIZERS_PARALLELISM"] = "0"
                model_forward = self.get_compiled_call(generation_config.compile_config)

        is_prefill = True
        # A prompt ending in <vision_start> begins image generation immediately.
        is_sampling_img = input_ids[:, -1] == self.config.vision_start_token_id
        # The key may be present with value None (e.g. re-inserted during kwarg validation); fall back to default.
        generation_image_grid_thw = model_kwargs.pop("generation_image_grid_thw", None)
        if generation_image_grid_thw is None:
            generation_image_grid_thw = self.get_default_image_grid_thw()
        num_img_tokens = self.get_num_image_tokens(generation_image_grid_thw)
        output_image_embeddings = []
        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare prefilled embeds
            model_inputs.update(
                self.prepare_prefilled_image_embeds(
                    len(output_image_embeddings), num_img_tokens, is_sampling_img, **model_kwargs
                )
            )

            # parse position_ids from model_kwargs
            model_inputs.update(
                self.prepare_image_position_ids(input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs)
            )

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            # TODO: support batch image sampling
            if bool(is_sampling_img) and len(output_image_embeddings) < num_img_tokens:
                output_image_embeddings.append(outputs.image_embeddings[:, -1, :].unsqueeze(1))

            if synced_gpus and this_peer_finished:
                continue

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone().float()
            next_token_logits = next_token_logits.to(input_ids.device)
            # never sample <vision_end> freely; it is only emitted when an image is complete (see below)
            next_token_logits[:, self.config.vision_end_token_id] = -float('inf')

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # TODO: support batch image sample
            if num_img_tokens is not None:
                # number of tokens emitted since the most recent <vision_start>
                cur_img_tokens = (input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1)
                # check whether is sampling images
                is_end_img = torch.logical_and(cur_img_tokens == num_img_tokens, is_sampling_img)
                is_sampling_img = torch.logical_and(is_sampling_img, cur_img_tokens < num_img_tokens)
                next_tokens[is_sampling_img] = self.config.image_token_id
                # check whether to end sampling images
                next_tokens[is_end_img] = self.config.vision_end_token_id
            else:
                # Only reachable if `get_num_image_tokens` is overridden to return None.
                # check whether to end sampling images
                is_sampling_img = torch.logical_and(is_sampling_img, (next_tokens != self.config.vision_end_token_id))
                # replace the next token with the image token if is sampling image
                next_tokens[is_sampling_img] = self.config.image_token_id
            # check whether to start sampling images
            is_sampling_img = torch.logical_or(is_sampling_img, (next_tokens == self.config.vision_start_token_id))

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        # output the image embeddings
        output_image_embeddings = (
            torch.cat(output_image_embeddings, dim=1) if len(output_image_embeddings) > 0 else None
        )

        if return_dict_in_generate:
            # NOTE(review): GenerateDecoderOnlyAll2AllOutput is neither imported nor defined in the visible part of
            # this file -- confirm it is defined elsewhere in the module, otherwise this raises NameError.
            return GenerateDecoderOnlyAll2AllOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
                output_image_embeddings=output_image_embeddings,
            )
        else:
            return input_ids

    def prepare_prefilled_image_embeds(self, cur_image_tokens, num_img_tokens, is_sampling_img, **model_kwargs):
        """Return the learned input embedding for the next generated image token.

        Args:
            cur_image_tokens: number of image embeddings generated so far for the current image.
            num_img_tokens: total number of image tokens the current image will contain.
            is_sampling_img: boolean tensor; currently only a single-element (batch size 1) tensor is supported.

        Returns:
            `{"inputs_embeds": ...}` built from slot `cur_image_tokens - 1` of `image_prefill_embeds`, or `{}` when
            no image is being sampled or no image token has been generated yet.

        Raises:
            ValueError: if the slot index exceeds the size of `image_prefill_embeds` (instead of an opaque
                embedding index error / device-side assert).
        """
        if cur_image_tokens == 0 or cur_image_tokens > num_img_tokens or not bool(is_sampling_img):
            return {}
        num_slots = self.image_prefill_embeds.num_embeddings
        if cur_image_tokens > num_slots:
            raise ValueError(
                f"Cannot prefill image token {cur_image_tokens}: `image_prefill_embeds` only has {num_slots} slots, "
                f"but the generation grid requires {num_img_tokens} image tokens."
            )
        # TODO: support batch image sample
        image_idx = torch.tensor([cur_image_tokens - 1]).to(self.device).long().unsqueeze(0)
        inputs_embeds = self.image_prefill_embeds(image_idx)
        return {"inputs_embeds": inputs_embeds}

    def get_default_image_grid_thw(self,):
        """Default (t, h, w) grid for generated images: a single 1x18x18 grid."""
        return torch.tensor([[1, 18, 18]]).to(self.device)

    def get_num_image_tokens(self, image_grid_thw):
        """Number of LLM image tokens for the given `(num_images, 3)` grids after spatial merging.

        Uses the same `spatial_merge_size` as `get_rope_index`, so both agree on how many tokens an image spans.
        """
        merge_size = self.config.vision_config.spatial_merge_size
        return int(torch.prod(image_grid_thw, dim=1).sum() // (merge_size**2))

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validate kwargs while tolerating the custom `generation_image_grid_thw` kwarg consumed by `_sample`.

        The key is hidden from the base-class check and restored afterwards only if the caller supplied it, so an
        absent key is never turned into an explicit `None`.
        """
        missing = object()
        generation_image_grid_thw = model_kwargs.pop("generation_image_grid_thw", missing)
        try:
            super()._validate_model_kwargs(model_kwargs)
        finally:
            if generation_image_grid_thw is not missing:
                model_kwargs["generation_image_grid_thw"] = generation_image_grid_thw

    def prepare_image_position_ids(self, input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs):
        """Compute 3D rope position ids for the current step while an image is being generated.

        The sequence is right-padded with the image tokens that are still to be generated, so that
        `get_rope_index` sees a complete image block; the position(s) for the current input are then sliced out.

        Returns:
            `{"position_ids": ...}` while sampling an image (after at least one image token), otherwise `{}`.
        """
        # Overwritten -- prepare position_ids for image tokens
        # number of tokens emitted since the most recent <vision_start>
        cur_img_tokens = int((input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1))
        # TODO: support batch image sample
        if cur_img_tokens > 0 and bool(is_sampling_img):
            image_grid_thw = generation_image_grid_thw
            if model_kwargs.get('image_grid_thw') is not None:
                image_grid_thw = torch.cat([model_kwargs.get('image_grid_thw'), image_grid_thw])
            remaining_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) - cur_img_tokens
            padding_ids = input_ids.new_full((1, remaining_img_tokens), fill_value=self.config.image_token_id)
            padded_ids = torch.cat([input_ids, padding_ids], dim=1)
            position_ids, _ = self.get_rope_index(padded_ids, image_grid_thw, None, None)
            if model_kwargs.get("use_cache", True):
                # with a KV cache only the newest token is fed to the model
                position_ids = position_ids[:, :, input_ids.shape[1] - 1].unsqueeze(-1)
            else:
                position_ids = position_ids[:, :, :input_ids.shape[1]]
            return {"position_ids": position_ids}
        return {}

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        image_embeddings=None,
        **kwargs,
    ):
        """Build per-step model inputs.

        `image_embeddings` is accepted explicitly so it is not forwarded to the model through `**kwargs`.
        Vision inputs are dropped after the pre-fill step (`cache_position[0] != 0`).
        """
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )

        # Qwen2-5-VL position_ids are prepared with rope_deltas in forward
        model_inputs["position_ids"] = None

        # Without a cache position we cannot tell decode from pre-fill, so keep the vision inputs.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs

    def _get_image_nums_and_video_nums(
        self,
        input_ids: Optional[torch.LongTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.

        Returns:
            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        vision_start_mask = input_ids == vision_start_token_id
        # marks the token immediately following each <vision_start>
        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
        image_mask = input_ids == image_token_id
        video_mask = input_ids == video_token_id
        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)

        return image_nums, video_nums

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        # Overwritten -- Support for expanding tensors without a batch size dimension
        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
        # pixel_values.shape[0] is sum(seqlen_images for samples)
        # image_grid_thw.shape[0] is sum(num_images for samples)

        if expand_size == 1:
            return input_ids, model_kwargs

        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]

        def _expand_dict_for_generation_visual(dict_to_expand):
            image_grid_thw = model_kwargs.get("image_grid_thw", None)
            video_grid_thw = model_kwargs.get("video_grid_thw", None)
            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)

            def _repeat_interleave_samples(x, lengths, repeat_times):
                samples = torch.split(x, lengths)
                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
                return result

            for key in dict_to_expand:
                if key == "pixel_values":
                    # split images into samples
                    samples = torch.split(image_grid_thw, list(image_nums))
                    # compute the sequence length of images for each sample
lengths = [torch.prod(sample, dim=1).sum() for sample in samples] dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "image_grid_thw": # get the num of images for each sample lengths = list(image_nums) dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "pixel_values_videos": samples = torch.split(video_grid_thw, list(video_nums)) lengths = [torch.prod(sample, dim=1).sum() for sample in samples] dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "video_grid_thw": lengths = list(video_nums) dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "second_per_grid_ts": if not isinstance(dict_to_expand[key], list): raise TypeError( f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." ) tensor = torch.tensor(dict_to_expand[key]) lengths = list(video_nums) tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) dict_to_expand[key] = tensor.tolist() return dict_to_expand def _expand_dict_for_generation(dict_to_expand): for key in dict_to_expand: if ( key != "cache_position" and dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) and key not in visual_keys ): dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand # input_ids is required for expanding visual inputs # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. 
if input_ids is not None and input_ids.numel() != 0: model_kwargs = _expand_dict_for_generation_visual(model_kwargs) if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) model_kwargs = _expand_dict_for_generation(model_kwargs) if is_encoder_decoder: if model_kwargs.get("encoder_outputs") is None: raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) return input_ids, model_kwargs __all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): fps: Union[List[float], float] class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): videos_kwargs: Qwen2_5_VLVideosProcessorKwargs _defaults = { "text_kwargs": { "padding": False, }, "videos_kwargs": {"fps": 2.0}, } class Qwen2_5_VLProcessor(ProcessorMixin): r""" Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. Args: image_processor ([`Qwen2VLImageProcessor`], *optional*): The image processor is a required input. tokenizer ([`Qwen2TokenizerFast`], *optional*): The tokenizer is a required input. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. 
""" attributes = ["image_processor", "tokenizer"] valid_kwargs = ["chat_template"] image_processor_class = "AutoImageProcessor" tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token super().__init__(image_processor, tokenizer, chat_template=chat_template) def __call__( self, images: ImageInput = None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, videos: VideoInput = None, **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. Args: images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of videos to be prepared. 
Each video can be a 4D NumPy array or PyTorch tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. 
""" output_kwargs = self._merge_kwargs( Qwen2_5_VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) if images is not None: image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] else: image_inputs = {} image_grid_thw = None if videos is not None: videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) if isinstance(fps, (int, float)): second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] else: raise ValueError( f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." 
) videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) else: videos_inputs = {} video_grid_thw = None if not isinstance(text, list): text = [text] if image_grid_thw is not None: merge_length = self.image_processor.merge_size**2 index = 0 for i in range(len(text)): while self.image_token in text[i]: text[i] = text[i].replace( self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1, ) index += 1 text[i] = text[i].replace("<|placeholder|>", self.image_token) if video_grid_thw is not None: merge_length = self.image_processor.merge_size**2 index = 0 for i in range(len(text)): while self.video_token in text[i]: text[i] = text[i].replace( self.video_token, "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1, ) index += 1 text[i] = text[i].replace("<|placeholder|>", self.video_token) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information. """ return self.tokenizer.batch_decode(*args, **kwargs) def batch_decode_all2all(self, *args, **kwargs): """ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information. """ decoded = self.tokenizer.batch_decode(*args, **kwargs) pattern = r'<\|vision_start\|>.*?<\|vision_end\|>' decoded_with_image_tag = [re.sub(pattern, '', d, flags=re.DOTALL) for d in decoded] decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag] return decoded_with_image_tag def decode(self, *args, **kwargs): """ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of this method for more information. 
""" return self.tokenizer.decode(*args, **kwargs) def post_process_image_text_to_text( self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs ): """ Post-process the output of the model to decode the text. Args: generated_outputs (`torch.Tensor` or `np.ndarray`): The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` or `(sequence_length,)`. skip_special_tokens (`bool`, *optional*, defaults to `True`): Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. **kwargs: Additional arguments to be passed to the tokenizer's `batch_decode method`. Returns: `List[str]`: The decoded text. """ return self.tokenizer.batch_decode( generated_outputs, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs, ) @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) return names_from_processor + ["second_per_grid_ts"] __all__ = ["Qwen2_5_VLProcessor"]