This commit is contained in:
xuyixuan.xyx
2025-05-07 11:22:13 +08:00
parent 290ec469ca
commit f17558a4c4
4 changed files with 47 additions and 21 deletions

View File

@@ -1181,13 +1181,18 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
all_self_attns = () if output_attentions else None
next_decoder_cache = None
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_mask,
position_ids,
@@ -1196,7 +1201,19 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
use_cache,
cache_position,
position_embeddings,
use_reentrant=False,
)
# layer_outputs = self._gradient_checkpointing_func(
# decoder_layer.__call__,
# hidden_states,
# causal_mask,
# position_ids,
# past_key_values,
# output_attentions,
# use_cache,
# cache_position,
# position_embeddings,
# )
else:
layer_outputs = decoder_layer(
hidden_states,