From ca9b5e64eafeab61b0f5427bf121c0c708655183 Mon Sep 17 00:00:00 2001
From: feng0w0
Date: Tue, 3 Feb 2026 15:44:53 +0800
Subject: [PATCH] [feature]:Add adaptation of all models to zero3

---
 diffsynth/core/gradient/gradient_checkpoint.py |  2 --
 diffsynth/models/wan_video_animate_adapter.py  |  2 +-
 diffsynth/pipelines/flux2_image.py             | 13 ++++++-------
 diffsynth/pipelines/wan_video.py               |  6 ++----
 4 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/diffsynth/core/gradient/gradient_checkpoint.py b/diffsynth/core/gradient/gradient_checkpoint.py
index d252573..b356415 100644
--- a/diffsynth/core/gradient/gradient_checkpoint.py
+++ b/diffsynth/core/gradient/gradient_checkpoint.py
@@ -21,7 +21,6 @@ def gradient_checkpoint_forward(
                 *args,
                 **kwargs,
                 use_reentrant=False,
-                determinism_check="none"
             )
     elif use_gradient_checkpointing:
         model_output = torch.utils.checkpoint.checkpoint(
@@ -29,7 +28,6 @@ def gradient_checkpoint_forward(
             *args,
             **kwargs,
             use_reentrant=False,
-            determinism_check="none"
         )
     else:
         model_output = model(*args, **kwargs)
diff --git a/diffsynth/models/wan_video_animate_adapter.py b/diffsynth/models/wan_video_animate_adapter.py
index 8873aff..3ace70d 100644
--- a/diffsynth/models/wan_video_animate_adapter.py
+++ b/diffsynth/models/wan_video_animate_adapter.py
@@ -607,7 +607,7 @@ class Generator(nn.Module):
 
     def get_motion(self, img):
         #motion_feat = self.enc.enc_motion(img)
-        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True, determinism_check="none")
+        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
         motion = self.dec.direction(motion_feat)
         return motion
 
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index c68dcb9..bea6b7c 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -348,13 +348,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
         attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
 
         # Forward pass through the model
-        with torch.no_grad():
-            output = text_encoder(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                output_hidden_states=True,
-                use_cache=False,
-            )
+        output = text_encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            use_cache=False,
+        )
 
         # Only use outputs from intermediate layers and stack them
         out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index f568701..bbc479e 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -1334,15 +1334,13 @@ def model_fn_wan_video(
                 x, x_vap = torch.utils.checkpoint.checkpoint(
                     create_custom_forward_vap(block, vap),
                     x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
-                    use_reentrant=False,
-                    determinism_check="none"
+                    use_reentrant=False
                 )
         elif use_gradient_checkpointing:
             x, x_vap = torch.utils.checkpoint.checkpoint(
                 create_custom_forward_vap(block, vap),
                 x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
-                use_reentrant=False,
-                determinism_check="none"
+                use_reentrant=False
             )
         else:
             x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)