mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
28 lines
802 B
Python
28 lines
802 B
Python
import torch
|
|
import torch.nn as nn
|
|
from .wan_video_dit import sinusoidal_embedding_1d
|
|
|
|
|
|
|
|
class WanMotionControllerModel(torch.nn.Module):
    """Embed a motion bucket id into a ``dim * 6``-wide conditioning vector.

    The id is scaled by 10 and encoded with ``sinusoidal_embedding_1d`` into
    ``freq_dim`` features. The result is passed through a SiLU MLP
    (``freq_dim -> dim -> dim -> dim * 6``).

    Args:
        freq_dim: Width of the sinusoidal embedding fed into the MLP.
        dim: Hidden width of the MLP; the output has ``dim * 6`` features.
    """

    def __init__(self, freq_dim=256, dim=1536):
        super().__init__()
        self.freq_dim = freq_dim
        self.linear = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.SiLU(),
            # NOTE(review): six dim-wide chunks are presumably consumed as
            # separate modulation vectors by the DiT — confirm at the caller.
            nn.Linear(dim, dim * 6),
        )

    def forward(self, motion_bucket_id):
        """Project the sinusoidal embedding of ``motion_bucket_id`` through the MLP.

        Args:
            motion_bucket_id: Tensor of bucket ids, passed (after scaling)
                straight to ``sinusoidal_embedding_1d``. Presumably a 1-D
                floating-point tensor of shape (batch,) — TODO confirm
                against callers.

        Returns:
            Tensor whose last dimension is ``dim * 6``.
        """
        # The x10 scale is hard-coded; presumably it matches what pretrained
        # checkpoints expect, so keep it unless retraining.
        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
        emb = self.linear(emb)
        return emb

    def init(self):
        """Zero the weight and bias of the final Linear layer.

        Afterwards ``forward`` returns an all-zero tensor for any input, until
        the layer is trained or overwritten by a loaded checkpoint.
        """
        # Assign zeros directly instead of multiplying by 0. If the parameters
        # hold NaN/Inf (e.g. memory allocated without initialization),
        # ``x * 0`` stays NaN. ``zeros_`` writes in place under no_grad and
        # keeps the parameters' dtype and device.
        last = self.linear[-1]
        nn.init.zeros_(last.weight)
        if last.bias is not None:
            nn.init.zeros_(last.bias)
|