support video-to-video-translation

This commit is contained in:
Artiprocher
2023-12-21 17:11:58 +08:00
parent f7f4c1038e
commit c1453281df
20 changed files with 1659 additions and 427 deletions

View File

@@ -1,4 +1,15 @@
import torch
from einops import rearrange
def low_version_attention(query, key, value, attn_bias=None):
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = torch.matmul(query, key.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
return attn @ value
class Attention(torch.nn.Module):
@@ -15,7 +26,7 @@ class Attention(torch.nn.Module):
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
@@ -36,3 +47,30 @@ class Attention(torch.nn.Module):
hidden_states = self.to_out(hidden_states)
return hidden_states
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
if attn_mask is not None:
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
else:
import xformers.ops as xops
hidden_states = xops.memory_efficient_attention(q, k, v)
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask)