mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:32:27 +00:00
vram optimization
This commit is contained in:
@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type_as(x)
|
||||
return super().forward(x).type_as(x)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
|
||||
"""
|
||||
x: [B, L, C].
|
||||
"""
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
# compute attention
|
||||
p = self.attn_dropout if self.training else 0.0
|
||||
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
||||
x = x.reshape(b, s, c)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
||||
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
||||
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
|
||||
k, v = self.to_kv(x).chunk(2, dim=-1)
|
||||
|
||||
# compute attention
|
||||
x = flash_attention(q, k, v, version=2)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
x = x.reshape(b, 1, c)
|
||||
|
||||
# output
|
||||
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
|
||||
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
||||
|
||||
# forward
|
||||
dtype = next(iter(self.model.visual.parameters())).dtype
|
||||
videos = videos.to(dtype)
|
||||
out = self.model.visual(videos, use_31_block=True)
|
||||
return out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user