[feature]: Add adaptation of all models to ZeRO-3

This commit is contained in:
feng0w0
2026-02-03 15:44:53 +08:00
parent 2070bbd925
commit ca9b5e64ea
4 changed files with 9 additions and 14 deletions

View File

@@ -21,7 +21,6 @@ def gradient_checkpoint_forward(
*args,
**kwargs,
use_reentrant=False,
determinism_check="none"
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
@@ -29,7 +28,6 @@ def gradient_checkpoint_forward(
*args,
**kwargs,
use_reentrant=False,
determinism_check="none"
)
else:
model_output = model(*args, **kwargs)

View File

@@ -607,7 +607,7 @@ class Generator(nn.Module):
# NOTE(review): this span is a rendered git diff hunk with the +/- markers
# stripped, not applied source — the two checkpoint() calls below are the
# removed (old) and added (new) versions of the SAME line; only the second
# one exists in the file after this commit.
def get_motion(self, img):
#motion_feat = self.enc.enc_motion(img)
# old line (removed by this commit): passed determinism_check="none" —
# presumably dropped because determinism_check is only meaningful for the
# non-reentrant checkpoint path (use_reentrant=False), and/or for
# compatibility with the torch version used under ZeRO-3; confirm against
# the torch.utils.checkpoint docs for the pinned torch release.
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True, determinism_check="none")
# new line (added by this commit): same checkpointed encoder call without
# the determinism_check kwarg.
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
motion = self.dec.direction(motion_feat)
return motion

View File

@@ -348,13 +348,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
with torch.no_grad():
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)

View File

@@ -1334,15 +1334,13 @@ def model_fn_wan_video(
x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False,
determinism_check="none"
use_reentrant=False
)
elif use_gradient_checkpointing:
x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False,
determinism_check="none"
use_reentrant=False
)
else:
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)