[feature]:Add adaptation of all models to zero3

This commit is contained in:
feng0w0
2026-01-31 16:50:18 +08:00
parent 3140199c96
commit 2070bbd925
6 changed files with 58 additions and 144 deletions

View File

@@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
from ..core.gradient import gradient_checkpoint_forward
def torch_dfs(model: nn.Module, parent_name='root'):
@@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module):
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(self.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
)
x = gradient_checkpoint_forward(
lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x
)
x = x[:, :seq_len_x]
x = self.head(x, t[:-1])