Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-23 00:58:11 +00:00)
train
This commit is contained in:
@@ -1181,13 +1181,18 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
def create_custom_forward(module):
    """Wrap *module* in a plain function that forwards its arguments.

    In this file the returned function is passed as the first argument to
    ``torch.utils.checkpoint.checkpoint`` in place of the decoder layer.

    Args:
        module: Any callable (here, a decoder layer).

    Returns:
        A function that calls ``module`` with the positional and keyword
        arguments it receives and returns the result unchanged.
    """

    def custom_forward(*inputs, **kwargs):
        # Keyword arguments are forwarded too, so callers are not limited
        # to positional-only invocation.
        return module(*inputs, **kwargs)

    return custom_forward
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
@@ -1196,7 +1201,19 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# layer_outputs = self._gradient_checkpointing_func(
|
||||
# decoder_layer.__call__,
|
||||
# hidden_states,
|
||||
# causal_mask,
|
||||
# position_ids,
|
||||
# past_key_values,
|
||||
# output_attentions,
|
||||
# use_cache,
|
||||
# cache_position,
|
||||
# position_embeddings,
|
||||
# )
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user