mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from .wan_video_dit import sinusoidal_embedding_1d
|
|
|
|
|
|
|
|
class WanMotionControllerModel(torch.nn.Module):
    """Turn a motion bucket id into a ``dim * 6``-wide conditioning vector.

    The id is scaled by 10, embedded with ``sinusoidal_embedding_1d`` into a
    ``freq_dim``-wide vector, and projected by a small MLP
    (Linear -> SiLU -> Linear -> SiLU -> Linear) whose final layer emits
    ``dim * 6`` features.

    Args:
        freq_dim: Width of the sinusoidal embedding fed into the MLP.
        dim: Hidden width of the MLP; the output width is ``dim * 6``.
    """

    def __init__(self, freq_dim: int = 256, dim: int = 1536):
        super().__init__()
        self.freq_dim = freq_dim
        # Keep the layer order unchanged: ``nn.Sequential`` indices form the
        # checkpoint keys (e.g. ``linear.0.weight`` ... ``linear.4.weight``).
        self.linear = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim * 6),
        )

    def forward(self, motion_bucket_id):
        """Embed ``motion_bucket_id`` and project it through the MLP.

        Args:
            motion_bucket_id: Motion bucket id(s); presumably a 1-D tensor as
                expected by ``sinusoidal_embedding_1d`` -- TODO confirm with
                callers.

        Returns:
            The MLP output; its last dimension is ``dim * 6``.
            NOTE(review): presumably consumed by the caller as six
            ``dim``-wide chunks -- confirm against the consumer.
        """
        # NOTE(review): the fixed x10 scale on the id is presumably what the
        # pretrained weights expect -- keep it unless retraining.
        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
        return self.linear(emb)

    def init(self) -> None:
        """Set every parameter of the final projection layer to exactly zero.

        Only ``self.linear[-1]`` is touched; earlier layers keep their values.
        Parameters are modified in place, so existing references to them
        (e.g. held by an optimizer) remain valid.
        """
        # Use ``zero_()`` rather than multiplying by 0: ``NaN * 0`` and
        # ``inf * 0`` are NaN, so multiplication would not clear non-finite
        # values (e.g. weights coming from uninitialized storage).
        with torch.no_grad():
            for param in self.linear[-1].parameters():
                param.zero_()

    @staticmethod
    def state_dict_converter():
        """Return the converter object used to adapt checkpoints for this module."""
        return WanMotionControllerModelDictConverter()
|
|
|
|
|
|
|
|
class WanMotionControllerModelDictConverter:
    """State-dict converter for ``WanMotionControllerModel`` that changes nothing.

    Both entry points hand back the very same dict object they receive;
    no keys are renamed and no tensors are modified.
    """

    def __init__(self):
        # Stateless: there is nothing to configure.
        return None

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` unchanged (identity conversion)."""
        return state_dict

    def from_civitai(self, state_dict):
        """Return ``state_dict`` unchanged (identity conversion)."""
        return state_dict
|
|
|