vram optimization

This commit is contained in:
Artiprocher
2025-03-10 17:11:11 +08:00
parent b548d7caf2
commit a05f647633
6 changed files with 83 additions and 51 deletions

View File

@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
return super().forward(x).type_as(x)
class SelfAttention(nn.Module):
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
x = flash_attention(q, k, v, num_heads=self.num_heads)
# output
x = self.proj(x)
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
k, v = self.to_kv(x).chunk(2, dim=-1)
# compute attention
x = flash_attention(q, k, v, version=2)
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = x.reshape(b, 1, c)
# output
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
dtype = next(iter(self.model.visual.parameters())).dtype
videos = videos.to(dtype)
out = self.model.visual(videos, use_31_block=True)
return out