support motion controller

Artiprocher
2025-03-24 19:07:58 +08:00
parent 105eaf0f49
commit 05094710e3
3 changed files with 742 additions and 6 deletions

diffsynth/models/wan_video_motion_controller.py (new file)

@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+from .wan_video_dit import sinusoidal_embedding_1d
+
+
+class WanMotionControllerModel(torch.nn.Module):
+    def __init__(self, freq_dim=256, dim=1536):
+        super().__init__()
+        self.freq_dim = freq_dim
+        self.linear = nn.Sequential(
+            nn.Linear(freq_dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim * 6),
+        )
+
+    def forward(self, motion_bucket_id):
+        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
+        emb = self.linear(emb)
+        return emb
+
+    def init(self):
+        state_dict = self.linear[-1].state_dict()
+        state_dict = {i: state_dict[i] * 0 for i in state_dict}
+        self.linear[-1].load_state_dict(state_dict)

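A quick sketch of how the new module behaves (illustration only, not part of the commit; the import path is hypothetical and assumes the file above is on sys.path):

import torch
from wan_video_motion_controller import WanMotionControllerModel  # hypothetical import path

model = WanMotionControllerModel(freq_dim=256, dim=1536)
model.init()  # zeroes the last Linear layer, so the controller initially outputs all zeros

motion_bucket_id = torch.Tensor((20,))  # shape (1,), the same wrapping prepare_motion_bucket_id does below
emb = model(motion_bucket_id)           # sinusoidal embedding of 10 * bucket id, then a three-layer MLP
print(emb.shape)                        # torch.Size([1, 9216]), i.e. (batch, dim * 6)

Because init() zeroes the final layer, the term this module later adds to the DiT's modulation starts as a no-op, the usual safe initialization for a controller fine-tuned on top of a frozen base model.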
diffsynth/pipelines/wan_video.py

@@ -18,6 +18,7 @@ from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
 from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
 from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
 from ..models.wan_video_controlnet import WanControlNetModel
+from ..models.wan_video_motion_controller import WanMotionControllerModel
@@ -32,7 +33,8 @@ class WanVideoPipeline(BasePipeline):
         self.dit: WanModel = None
         self.vae: WanVideoVAE = None
         self.controlnet: WanControlNetModel = None
-        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet']
+        self.motion_controller: WanMotionControllerModel = None
+        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet', 'motion_controller']
         self.height_division_factor = 16
         self.width_division_factor = 16
@@ -196,6 +198,11 @@ class WanVideoPipeline(BasePipeline):
     def prepare_controlnet(self, controlnet_frames, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
         controlnet_conditioning = self.encode_video(controlnet_frames, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
         return {"controlnet_conditioning": controlnet_conditioning}
+
+    def prepare_motion_bucket_id(self, motion_bucket_id):
+        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+        return {"motion_bucket_id": motion_bucket_id}
+
     @torch.no_grad()
@@ -214,6 +221,7 @@ class WanVideoPipeline(BasePipeline):
         cfg_scale=5.0,
         num_inference_steps=50,
         sigma_shift=5.0,
+        motion_bucket_id=None,
         tiled=True,
         tile_size=(30, 52),
         tile_stride=(15, 26),
@@ -269,6 +277,12 @@ class WanVideoPipeline(BasePipeline):
         else:
             controlnet_kwargs = {}
 
+        # Motion Controller
+        if self.motion_controller is not None and motion_bucket_id is not None:
+            motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+        else:
+            motion_kwargs = {}
+
         # Extra input
         extra_input = self.prepare_extra_input(latents)
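For context, a hedged usage sketch of the new keyword at the pipeline level (the pipeline construction is assumed and illustrative; only the motion_bucket_id argument comes from this diff):

video = pipe(  # pipe: a WanVideoPipeline instance, assumed already loaded
    prompt="a cat running on the grass",
    num_inference_steps=50,
    cfg_scale=5.0,
    motion_bucket_id=20,  # scalar; prepare_motion_bucket_id wraps it into a (1,) tensor on the pipeline's device/dtype
)

If the motion controller is not loaded or motion_bucket_id is None, motion_kwargs stays empty and the denoising call path is unchanged.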
@@ -277,23 +291,23 @@ class WanVideoPipeline(BasePipeline):
         tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
 
         # Denoise
-        self.load_models_to_device(["dit", "controlnet"])
+        self.load_models_to_device(["dit", "controlnet", "motion_controller"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
 
             # Inference
             noise_pred_posi = model_fn_wan_video(
-                self.dit, controlnet=self.controlnet,
+                self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
                 x=latents, timestep=timestep,
                 **prompt_emb_posi, **image_emb, **extra_input,
-                **tea_cache_posi, **controlnet_kwargs
+                **tea_cache_posi, **controlnet_kwargs, **motion_kwargs,
             )
             if cfg_scale != 1.0:
                 noise_pred_nega = model_fn_wan_video(
-                    self.dit, controlnet=self.controlnet,
+                    self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
                     x=latents, timestep=timestep,
                     **prompt_emb_nega, **image_emb, **extra_input,
-                    **tea_cache_nega, **controlnet_kwargs
+                    **tea_cache_nega, **controlnet_kwargs, **motion_kwargs,
                 )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
@@ -368,6 +382,7 @@ class TeaCache:
 def model_fn_wan_video(
     dit: WanModel,
     controlnet: WanControlNetModel = None,
+    motion_controller: WanMotionControllerModel = None,
     x: torch.Tensor = None,
     timestep: torch.Tensor = None,
     context: torch.Tensor = None,
@@ -375,6 +390,7 @@ def model_fn_wan_video(
     y: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
     controlnet_conditioning: Optional[torch.Tensor] = None,
+    motion_bucket_id: Optional[torch.Tensor] = None,
     use_gradient_checkpointing: bool = False,
     use_gradient_checkpointing_offload: bool = False,
     **kwargs,
@@ -392,6 +408,8 @@ def model_fn_wan_video(
     t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
     t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+    if motion_bucket_id is not None and motion_controller is not None:
+        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
     context = dit.text_embedding(context)
 
     if dit.has_image_input:
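A standalone shape check of the t_mod update above (a sketch with illustrative values; dim=1536 is the module's default):

import torch

batch, dim = 1, 1536
t_mod = torch.randn(batch, 6, dim)        # dit.time_projection(t).unflatten(1, (6, dim))
motion_emb = torch.zeros(batch, dim * 6)  # stands in for motion_controller(motion_bucket_id)
t_mod = t_mod + motion_emb.unflatten(1, (6, dim))  # shapes match: (batch, 6, dim)

Since WanMotionControllerModel.init() zeroes the last layer, this addition contributes nothing until the controller is trained, so existing checkpoints behave exactly as before.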