Merge pull request #1272 from modelscope/zero3-fix

Support DeepSpeed ZeRO 3
This commit is contained in:
Zhongjie Duan
2026-02-06 16:33:12 +08:00
committed by GitHub
26 changed files with 353 additions and 188 deletions

View File

@@ -3,14 +3,14 @@ from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from contextlib import contextmanager
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
with skip_model_initialization():
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
model = model_class(**config)
# What is `module_map`?
# This is a module mapping table for VRAM management.
@@ -48,7 +48,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
# Because at this stage, model parameters are partitioned across multiple GPUs.
# Loading them directly could lead to excessive GPU memory consumption.
if is_deepspeed_zero3_enabled():
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
_load_state_dict_into_zero3_model(model, state_dict)
else:
model.load_state_dict(state_dict, assign=True)
# Why do we call `to()`?
# Because some models override the behavior of `to()`,
# especially those from libraries like Transformers.
@@ -79,3 +86,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
}
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
return model
def get_init_context(torch_dtype, device):
if is_deepspeed_zero3_enabled():
from transformers.modeling_utils import set_zero3_state
import deepspeed
# Why do we use "deepspeed.zero.Init"?
# Weight segmentation of the model can be performed on the CPU side
# and loading the segmented weights onto the computing card
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
else:
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
init_contexts = [skip_model_initialization()]
return init_contexts

View File

@@ -18,8 +18,8 @@ class ModelLogger:
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
@@ -34,8 +34,8 @@ class ModelLogger:
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)

View File

@@ -27,7 +27,7 @@ def launch_training_task(
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch_id in range(num_epochs):
@@ -59,6 +59,7 @@ def launch_data_process_task(
num_workers = args.dataset_num_workers
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, dataloader = accelerator.prepare(model, dataloader)
for data_id, data in enumerate(tqdm(dataloader)):

View File

@@ -5,6 +5,7 @@ import math
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward
try:
import flash_attn_interface
@@ -379,27 +380,15 @@ class WanModel(torch.nn.Module):
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
if self.training:
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs
)
else:
x = block(x, context, t_mod, freqs)

View File

@@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
from ..core.gradient import gradient_checkpoint_forward
def torch_dfs(model: nn.Module, parent_name='root'):
@@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module):
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(self.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
)
x = gradient_checkpoint_forward(
lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x
)
x = x[:, :seq_len_x]
x = self.head(x, t[:-1])

View File

@@ -1,6 +1,6 @@
import torch
from .wan_video_dit import DiTBlock
from ..core.gradient import gradient_checkpoint_forward
class VaceWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
@@ -62,26 +62,13 @@ class VaceWanModel(torch.nn.Module):
dim=1) for u in c
])
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.vace_blocks:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
else:
c = block(c, x, context, t_mod, freqs)
c = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
c, x, context, t_mod, freqs
)
hints = torch.unbind(c)[:-1]
return hints

View File

@@ -171,7 +171,7 @@ class Resample(nn.Module):
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
return x, feat_cache, feat_idx
def init_weight(self, conv):
conv_weight = conv.weight
@@ -298,7 +298,7 @@ class ResidualBlock(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x + h
return x + h, feat_cache, feat_idx
class AttentionBlock(nn.Module):
@@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module):
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
class Up_ResidualBlock(nn.Module):
@@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module):
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
return x_main, feat_cache, feat_idx
class Encoder3d(nn.Module):
@@ -586,14 +586,14 @@ class Encoder3d(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -614,7 +614,7 @@ class Encoder3d(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x
return x, feat_cache, feat_idx
class Encoder3d_38(nn.Module):
@@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module):
else:
x = layer(x)
return x
return x, feat_cache, feat_idx
class Decoder3d(nn.Module):
@@ -807,14 +807,14 @@ class Decoder3d(nn.Module):
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -835,7 +835,7 @@ class Decoder3d(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x
return x, feat_cache, feat_idx
@@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module):
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
@@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x
return x, feat_cache, feat_idx
def count_conv3d(model):
@@ -990,11 +990,11 @@ class VideoVAE_(nn.Module):
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
@@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
@@ -1303,11 +1303,11 @@ class VideoVAE38_(VideoVAE_):
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
@@ -1337,12 +1337,12 @@ class VideoVAE38_(VideoVAE_):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)

View File

@@ -348,13 +348,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
with torch.inference_mode():
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)

View File

@@ -1321,11 +1321,6 @@ def model_fn_wan_video(
if tea_cache_update:
x = tea_cache.update(x)
else:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
def create_custom_forward_vap(block, vap):
def custom_forward(*inputs):
return vap(block, *inputs)
@@ -1339,32 +1334,24 @@ def model_fn_wan_video(
x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False,
use_reentrant=False
)
elif use_gradient_checkpointing:
x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False,
use_reentrant=False
)
else:
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
else:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs
)
# VACE
if vace_context is not None and block_id in vace.vace_layers_mapping:
@@ -1487,32 +1474,18 @@ def model_fn_wans2v(
return custom_forward
for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
x = gradient_checkpoint_forward(
lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x
)
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)

View File

@@ -9,6 +9,7 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ... import IS_NPU_AVAILABLE
from ...core.device import parse_nccl_backend, parse_device_type
from ...core.gradient import gradient_checkpoint_forward
def initialize_usp(device_type):
@@ -87,11 +88,6 @@ def usp_dit_forward(self,
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Context Parallel
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
@@ -100,20 +96,13 @@ def usp_dit_forward(self,
x = chunks[get_sequence_parallel_rank()]
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
if self.training:
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs
)
else:
x = block(x, context, t_mod, freqs)