from inspect import isfunction
from math import pi
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, einsum, nn

# helper functions


def exists(val):
    return val is not None


def broadcat(tensors, dim=-1):
    # concatenate tensors along `dim`, broadcasting every other dimension
    # (size-1 dims are expanded to the largest size seen across the tensors)
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all(
        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
    ), "invalid dimensions for broadcastable concatenation"
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


# rotary embedding helper functions


def rotate_half(x):
    # treat the last dimension as interleaved pairs and rotate each pair
    # by 90 degrees: (x1, x2) -> (-x2, x1)
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


def apply_rotary_emb(freqs, t, start_index=0):
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    # rotate only the slice [start_index, end_index); pass the rest through
    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return torch.cat((t_left, t, t_right), dim=-1)


# learned rotation helpers


def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    if exists(freq_ranges):
        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(rotations, "... r f -> ... (r f)")
    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(rotations, t, start_index=start_index)


# classes


class WanToDanceRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        custom_freqs=None,
        freqs_for="lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
    ):
        super().__init__()
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")

        self.cache = dict()

        if learned_freq:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs, persistent=False)

    def rotate_queries_or_keys(self, t, seq_dim=-2):
        device = t.device
        seq_len = t.shape[seq_dim]
        freqs = self.forward(
            lambda: torch.arange(seq_len, device=device), cache_key=seq_len
        )
        return apply_rotary_emb(freqs, t)

    def forward(self, t, cache_key=None):
        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

        if isfunction(t):
            t = t()

        freqs = self.freqs.to(t.device)
        # outer product of positions and frequencies, then duplicate each
        # frequency so it covers an (x1, x2) feature pair
        freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
(n r)", r=2) if exists(cache_key): self.cache[cache_key] = freqs return freqs class WanToDanceMusicEncoderLayer(nn.Module): def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = True, device=None, dtype=None, rotary=None, ) -> None: super().__init__() self.self_attn = nn.MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype ) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm_first = norm_first self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.activation = activation self.rotary = rotary self.use_rotary = rotary is not None # self-attention block def _sa_block( self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] ) -> Tensor: qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x x = self.self_attn( qk, qk, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0] return self.dropout1(x) # feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout2(x) def forward( self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: x = src if self.norm_first: self.norm1.to(device=x.device) self.norm2.to(device=x.device) x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) x = x + self._ff_block(self.norm2(x)) else: x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) x = self.norm2(x + self._ff_block(x)) return x