vram optimization

This commit is contained in:
Artiprocher
2025-03-10 17:11:11 +08:00
parent b548d7caf2
commit a05f647633
6 changed files with 83 additions and 51 deletions

View File

@@ -291,17 +291,21 @@ class WanModel(torch.nn.Module):
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
x, (f, h, w) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
@@ -315,11 +319,19 @@ class WanModel(torch.nn.Module):
for block in self.blocks:
if self.training and use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)

View File

@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
return super().forward(x).type_as(x)
class SelfAttention(nn.Module):
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
x = flash_attention(q, k, v, num_heads=self.num_heads)
# output
x = self.proj(x)
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
k, v = self.to_kv(x).chunk(2, dim=-1)
# compute attention
x = flash_attention(q, k, v, version=2)
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = x.reshape(b, 1, c)
# output
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
dtype = next(iter(self.model.visual.parameters())).dtype
videos = videos.to(dtype)
out = self.model.visual(videos, use_31_block=True)
return out

View File

@@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module):
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float().clamp_(-1, 1)
values = values.clamp_(-1, 1)
return values
@@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module):
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float()
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x.float()
return x
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.float().clamp_(-1, 1)
return video.clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):