mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
ace-step train
This commit is contained in:
@@ -864,20 +864,13 @@ class AceStepDiTModel(nn.Module):
|
||||
layer_kwargs = flash_attn_kwargs
|
||||
|
||||
# Use gradient checkpointing if enabled
|
||||
if use_gradient_checkpointing or use_gradient_checkpointing_offload:
|
||||
layer_outputs = gradient_checkpoint_forward(
|
||||
layer_module,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
*layer_args,
|
||||
**layer_kwargs,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
*layer_args,
|
||||
**layer_kwargs,
|
||||
)
|
||||
|
||||
layer_outputs = gradient_checkpoint_forward(
|
||||
layer_module,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
*layer_args,
|
||||
**layer_kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions and self.layers[index_block].use_cross_attention:
|
||||
|
||||
Reference in New Issue
Block a user