mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
vram optimization
This commit is contained in:
@@ -291,17 +291,21 @@ class WanModel(torch.nn.Module):
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
t = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
|
||||
if self.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = self.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
@@ -315,11 +319,19 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
|
||||
@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type_as(x)
|
||||
return super().forward(x).type_as(x)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
|
||||
"""
|
||||
x: [B, L, C].
|
||||
"""
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
# compute attention
|
||||
p = self.attn_dropout if self.training else 0.0
|
||||
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
||||
x = x.reshape(b, s, c)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
||||
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
||||
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
|
||||
k, v = self.to_kv(x).chunk(2, dim=-1)
|
||||
|
||||
# compute attention
|
||||
x = flash_attention(q, k, v, version=2)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
x = x.reshape(b, 1, c)
|
||||
|
||||
# output
|
||||
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
|
||||
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
||||
|
||||
# forward
|
||||
dtype = next(iter(self.model.visual.parameters())).dtype
|
||||
videos = videos.to(dtype)
|
||||
out = self.model.visual(videos, use_31_block=True)
|
||||
return out
|
||||
|
||||
|
||||
@@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module):
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float().clamp_(-1, 1)
|
||||
values = values.clamp_(-1, 1)
|
||||
return values
|
||||
|
||||
|
||||
@@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module):
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float()
|
||||
return values
|
||||
|
||||
|
||||
def single_encode(self, video, device):
|
||||
video = video.to(device)
|
||||
x = self.model.encode(video, self.scale)
|
||||
return x.float()
|
||||
return x
|
||||
|
||||
|
||||
def single_decode(self, hidden_state, device):
|
||||
hidden_state = hidden_state.to(device)
|
||||
video = self.model.decode(hidden_state, self.scale)
|
||||
return video.float().clamp_(-1, 1)
|
||||
return video.clamp_(-1, 1)
|
||||
|
||||
|
||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
|
||||
Reference in New Issue
Block a user