mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #1272 from modelscope/zero3-fix
Support DeepSpeed ZeRO 3
This commit is contained in:
@@ -3,14 +3,14 @@ from ..vram.disk_map import DiskMap
|
||||
from ..vram.layers import enable_vram_management
|
||||
from .file import load_state_dict
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import ContextManagers
|
||||
|
||||
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||
config = {} if config is None else config
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||
with skip_model_initialization():
|
||||
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
|
||||
model = model_class(**config)
|
||||
# What is `module_map`?
|
||||
# This is a module mapping table for VRAM management.
|
||||
@@ -48,7 +48,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
state_dict = {i: state_dict[i] for i in state_dict}
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
|
||||
# Because at this stage, model parameters are partitioned across multiple GPUs.
|
||||
# Loading them directly could lead to excessive GPU memory consumption.
|
||||
if is_deepspeed_zero3_enabled():
|
||||
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||
_load_state_dict_into_zero3_model(model, state_dict)
|
||||
else:
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
# Why do we call `to()`?
|
||||
# Because some models override the behavior of `to()`,
|
||||
# especially those from libraries like Transformers.
|
||||
@@ -79,3 +86,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
|
||||
}
|
||||
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
|
||||
return model
|
||||
|
||||
|
||||
def get_init_context(torch_dtype, device):
|
||||
if is_deepspeed_zero3_enabled():
|
||||
from transformers.modeling_utils import set_zero3_state
|
||||
import deepspeed
|
||||
# Why do we use "deepspeed.zero.Init"?
|
||||
# Weight segmentation of the model can be performed on the CPU side
|
||||
# and loading the segmented weights onto the computing card
|
||||
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
|
||||
else:
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||
init_contexts = [skip_model_initialization()]
|
||||
|
||||
return init_contexts
|
||||
|
||||
@@ -18,8 +18,8 @@ class ModelLogger:
|
||||
|
||||
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
@@ -34,8 +34,8 @@ class ModelLogger:
|
||||
|
||||
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
|
||||
accelerator.wait_for_everyone()
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
|
||||
@@ -27,7 +27,7 @@ def launch_training_task(
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
|
||||
model.to(device=accelerator.device)
|
||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||
|
||||
for epoch_id in range(num_epochs):
|
||||
@@ -59,6 +59,7 @@ def launch_data_process_task(
|
||||
num_workers = args.dataset_num_workers
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
model.to(device=accelerator.device)
|
||||
model, dataloader = accelerator.prepare(model, dataloader)
|
||||
|
||||
for data_id, data in enumerate(tqdm(dataloader)):
|
||||
|
||||
@@ -5,6 +5,7 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_camera_controller import SimpleAdapter
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
@@ -379,27 +380,15 @@ class WanModel(torch.nn.Module):
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
if self.training:
|
||||
x = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x, context, t_mod, freqs
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
def torch_dfs(model: nn.Module, parent_name='root'):
|
||||
@@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module):
|
||||
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x,
|
||||
context,
|
||||
t_mod,
|
||||
seq_len_x,
|
||||
pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x,
|
||||
context,
|
||||
t_mod,
|
||||
seq_len_x,
|
||||
pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
||||
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
|
||||
x = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
|
||||
)
|
||||
x = gradient_checkpoint_forward(
|
||||
lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x
|
||||
)
|
||||
|
||||
x = x[:, :seq_len_x]
|
||||
x = self.head(x, t[:-1])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock
|
||||
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
class VaceWanAttentionBlock(DiTBlock):
|
||||
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||
@@ -62,26 +62,13 @@ class VaceWanModel(torch.nn.Module):
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.vace_blocks:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
c = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
c, x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
c = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
c, x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
c = block(c, x, context, t_mod, freqs)
|
||||
c = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
c, x, context, t_mod, freqs
|
||||
)
|
||||
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
@@ -171,7 +171,7 @@ class Resample(nn.Module):
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
@@ -298,7 +298,7 @@ class ResidualBlock(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
return x + h, feat_cache, feat_idx
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
@@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module):
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
|
||||
|
||||
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
@@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module):
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
return x_main + x_shortcut
|
||||
else:
|
||||
return x_main
|
||||
return x_main, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
@@ -586,14 +586,14 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -614,7 +614,7 @@ class Encoder3d(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Encoder3d_38(nn.Module):
|
||||
@@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module):
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
@@ -807,14 +807,14 @@ class Decoder3d(nn.Module):
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -835,7 +835,7 @@ class Decoder3d(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
|
||||
@@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module):
|
||||
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
@@ -990,11 +990,11 @@ class VideoVAE_(nn.Module):
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
@@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module):
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
||||
@@ -1303,11 +1303,11 @@ class VideoVAE38_(VideoVAE_):
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
@@ -1337,12 +1337,12 @@ class VideoVAE38_(VideoVAE_):
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
first_chunk=True)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
@@ -348,13 +348,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
with torch.inference_mode():
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
|
||||
|
||||
@@ -1321,11 +1321,6 @@ def model_fn_wan_video(
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
def create_custom_forward_vap(block, vap):
|
||||
def custom_forward(*inputs):
|
||||
return vap(block, *inputs)
|
||||
@@ -1339,32 +1334,24 @@ def model_fn_wan_video(
|
||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward_vap(block, vap),
|
||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||
use_reentrant=False,
|
||||
use_reentrant=False
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward_vap(block, vap),
|
||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||
use_reentrant=False,
|
||||
use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
|
||||
else:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
x = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x, context, t_mod, freqs
|
||||
)
|
||||
|
||||
|
||||
# VACE
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
@@ -1487,32 +1474,18 @@ def model_fn_wans2v(
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
x = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
||||
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
|
||||
x = gradient_checkpoint_forward(
|
||||
lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x
|
||||
)
|
||||
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
|
||||
@@ -9,6 +9,7 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
from ... import IS_NPU_AVAILABLE
|
||||
from ...core.device import parse_nccl_backend, parse_device_type
|
||||
from ...core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
def initialize_usp(device_type):
|
||||
@@ -87,11 +88,6 @@ def usp_dit_forward(self,
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
# Context Parallel
|
||||
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||
@@ -100,20 +96,13 @@ def usp_dit_forward(self,
|
||||
x = chunks[get_sequence_parallel_rank()]
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
if self.training:
|
||||
x = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x, context, t_mod, freqs
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user