[feature]:Add adaptation of all models to zero3

This commit is contained in:
feng0w0
2026-02-03 15:44:53 +08:00
parent 2070bbd925
commit ca9b5e64ea
4 changed files with 9 additions and 14 deletions

View File

@@ -21,7 +21,6 @@ def gradient_checkpoint_forward(
*args, *args,
**kwargs, **kwargs,
use_reentrant=False, use_reentrant=False,
determinism_check="none"
) )
elif use_gradient_checkpointing: elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint( model_output = torch.utils.checkpoint.checkpoint(
@@ -29,7 +28,6 @@ def gradient_checkpoint_forward(
*args, *args,
**kwargs, **kwargs,
use_reentrant=False, use_reentrant=False,
determinism_check="none"
) )
else: else:
model_output = model(*args, **kwargs) model_output = model(*args, **kwargs)

View File

@@ -607,7 +607,7 @@ class Generator(nn.Module):
def get_motion(self, img): def get_motion(self, img):
#motion_feat = self.enc.enc_motion(img) #motion_feat = self.enc.enc_motion(img)
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True, determinism_check="none") motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
motion = self.dec.direction(motion_feat) motion = self.dec.direction(motion_feat)
return motion return motion

View File

@@ -348,13 +348,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
attention_mask = torch.cat(all_attention_masks, dim=0).to(device) attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model # Forward pass through the model
with torch.no_grad(): output = text_encoder(
output = text_encoder( input_ids=input_ids,
input_ids=input_ids, attention_mask=attention_mask,
attention_mask=attention_mask, output_hidden_states=True,
output_hidden_states=True, use_cache=False,
use_cache=False, )
)
# Only use outputs from intermediate layers and stack them # Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1) out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)

View File

@@ -1334,15 +1334,13 @@ def model_fn_wan_video(
x, x_vap = torch.utils.checkpoint.checkpoint( x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap), create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False, use_reentrant=False
determinism_check="none"
) )
elif use_gradient_checkpointing: elif use_gradient_checkpointing:
x, x_vap = torch.utils.checkpoint.checkpoint( x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap), create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False, use_reentrant=False
determinism_check="none"
) )
else: else:
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)