mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
[feature]:Add adaptation of all models to zero3
This commit is contained in:
@@ -171,7 +171,7 @@ class Resample(nn.Module):
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
@@ -298,7 +298,7 @@ class ResidualBlock(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
return x + h, feat_cache, feat_idx
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
@@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module):
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
|
||||
|
||||
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
@@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module):
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
return x_main + x_shortcut
|
||||
else:
|
||||
return x_main
|
||||
return x_main, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
@@ -586,14 +586,14 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -614,7 +614,7 @@ class Encoder3d(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Encoder3d_38(nn.Module):
|
||||
@@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module):
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
@@ -807,14 +807,14 @@ class Decoder3d(nn.Module):
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -835,7 +835,7 @@ class Decoder3d(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
|
||||
@@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module):
|
||||
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
@@ -990,11 +990,11 @@ class VideoVAE_(nn.Module):
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
Reference in New Issue
Block a user