From 2070bbd925daeac75f9c27bfb79d2b924e0fa936 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Sat, 31 Jan 2026 16:50:18 +0800 Subject: [PATCH] [feature]: Add adaptation of all models to ZeRO-3 --- diffsynth/models/wan_video_animate_adapter.py | 2 +- diffsynth/models/wan_video_dit.py | 27 +++----- diffsynth/models/wan_video_dit_s2v.py | 52 ++++----------- diffsynth/models/wan_video_vace.py | 29 +++------ diffsynth/pipelines/wan_video.py | 65 ++++++------------- .../utils/xfuser/xdit_context_parallel.py | 27 +++----- 6 files changed, 58 insertions(+), 144 deletions(-) diff --git a/diffsynth/models/wan_video_animate_adapter.py b/diffsynth/models/wan_video_animate_adapter.py index 3ace70d..8873aff 100644 --- a/diffsynth/models/wan_video_animate_adapter.py +++ b/diffsynth/models/wan_video_animate_adapter.py @@ -607,7 +607,7 @@ class Generator(nn.Module): def get_motion(self, img): #motion_feat = self.enc.enc_motion(img) - motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=False, determinism_check="none") motion = self.dec.direction(motion_feat) return motion diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 7386223..c11a1bc 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -5,6 +5,7 @@ import math from typing import Tuple, Optional from einops import rearrange from .wan_video_camera_controller import SimpleAdapter +from ..core.gradient import gradient_checkpoint_forward try: import flash_attn_interface @@ -379,27 +380,15 @@ class WanModel(torch.nn.Module): self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward for block in self.blocks: - if 
self.training and use_gradient_checkpointing: - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - else: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) + if self.training: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs + ) else: x = block(x, context, t_mod, freqs) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 8fbed8c..f4d1abe 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Tuple from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d +from ..core.gradient import gradient_checkpoint_forward def torch_dfs(model: nn.Module, parent_name='root'): @@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module): t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - for block_id, block in enumerate(self.blocks): - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, - context, - t_mod, - seq_len_x, - pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - elif 
use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, - context, - t_mod, - seq_len_x, - pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) - x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, seq_len_x, pre_compute_freqs[0] + ) + x = gradient_checkpoint_forward( + lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x), + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x + ) x = x[:, :seq_len_x] x = self.head(x, t[:-1]) diff --git a/diffsynth/models/wan_video_vace.py b/diffsynth/models/wan_video_vace.py index f3367f7..0e13183 100644 --- a/diffsynth/models/wan_video_vace.py +++ b/diffsynth/models/wan_video_vace.py @@ -1,6 +1,6 @@ import torch from .wan_video_dit import DiTBlock - +from ..core.gradient import gradient_checkpoint_forward class VaceWanAttentionBlock(DiTBlock): def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): @@ -62,26 +62,13 @@ class VaceWanModel(torch.nn.Module): dim=1) for u in c ]) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - for block in self.vace_blocks: - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - c = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - c, x, context, t_mod, freqs, - use_reentrant=False, - ) - elif use_gradient_checkpointing: - c = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - c, x, context, 
t_mod, freqs, - use_reentrant=False, - ) - else: - c = block(c, x, context, t_mod, freqs) + c = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + c, x, context, t_mod, freqs + ) + hints = torch.unbind(c)[:-1] return hints diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index edd6dff..f568701 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -1321,11 +1321,6 @@ def model_fn_wan_video( if tea_cache_update: x = tea_cache.update(x) else: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - def create_custom_forward_vap(block, vap): def custom_forward(*inputs): return vap(block, *inputs) @@ -1340,31 +1335,25 @@ def model_fn_wan_video( create_custom_forward_vap(block, vap), x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, use_reentrant=False, + determinism_check="none" ) elif use_gradient_checkpointing: x, x_vap = torch.utils.checkpoint.checkpoint( create_custom_forward_vap(block, vap), x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, use_reentrant=False, + determinism_check="none" ) else: x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) else: - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, freqs) + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs + ) + # VACE if vace_context is not None and block_id in 
vace.vace_layers_mapping: @@ -1487,32 +1476,18 @@ def model_fn_wans2v( return custom_forward for block_id, block in enumerate(dit.blocks): - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs[0], - use_reentrant=False, + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, seq_len_x, pre_compute_freqs[0] ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) - x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + x = gradient_checkpoint_forward( + lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x), + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x + ) if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 21dc3b3..cea55e4 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -6,6 +6,7 @@ from xfuser.core.distributed import (get_sequence_parallel_rank, 
get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention from ...core.device import parse_nccl_backend, parse_device_type +from ...core.gradient import gradient_checkpoint_forward def initialize_usp(device_type): @@ -81,11 +82,6 @@ def usp_dit_forward(self, self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward # Context Parallel chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) @@ -94,20 +90,13 @@ def usp_dit_forward(self, x = chunks[get_sequence_parallel_rank()] for block in self.blocks: - if self.training and use_gradient_checkpointing: - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - else: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) + if self.training: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs + ) else: x = block(x, context, t_mod, freqs)