mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #1272 from modelscope/zero3-fix
Support DeepSpeed ZeRO 3
This commit is contained in:
@@ -5,6 +5,7 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_camera_controller import SimpleAdapter
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
@@ -379,27 +380,15 @@ class WanModel(torch.nn.Module):
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
if self.training:
|
||||
x = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x, context, t_mod, freqs
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
def torch_dfs(model: nn.Module, parent_name='root'):
|
||||
@@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module):
|
||||
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x,
|
||||
context,
|
||||
t_mod,
|
||||
seq_len_x,
|
||||
pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x,
|
||||
context,
|
||||
t_mod,
|
||||
seq_len_x,
|
||||
pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
||||
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
|
||||
x = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
|
||||
)
|
||||
x = gradient_checkpoint_forward(
|
||||
lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x
|
||||
)
|
||||
|
||||
x = x[:, :seq_len_x]
|
||||
x = self.head(x, t[:-1])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock
|
||||
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
class VaceWanAttentionBlock(DiTBlock):
|
||||
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||
@@ -62,26 +62,13 @@ class VaceWanModel(torch.nn.Module):
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.vace_blocks:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
c = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
c, x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
c = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
c, x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
c = block(c, x, context, t_mod, freqs)
|
||||
c = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
c, x, context, t_mod, freqs
|
||||
)
|
||||
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
@@ -171,7 +171,7 @@ class Resample(nn.Module):
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
@@ -298,7 +298,7 @@ class ResidualBlock(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
return x + h, feat_cache, feat_idx
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
@@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module):
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
|
||||
|
||||
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
@@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module):
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
return x_main + x_shortcut
|
||||
else:
|
||||
return x_main
|
||||
return x_main, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
@@ -586,14 +586,14 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -614,7 +614,7 @@ class Encoder3d(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Encoder3d_38(nn.Module):
|
||||
@@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module):
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
@@ -807,14 +807,14 @@ class Decoder3d(nn.Module):
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -835,7 +835,7 @@ class Decoder3d(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
|
||||
@@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module):
|
||||
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
@@ -990,11 +990,11 @@ class VideoVAE_(nn.Module):
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
@@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module):
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
||||
@@ -1303,11 +1303,11 @@ class VideoVAE38_(VideoVAE_):
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
@@ -1337,12 +1337,12 @@ class VideoVAE38_(VideoVAE_):
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
first_chunk=True)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
Reference in New Issue
Block a user