mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 07:18:14 +00:00
controlnet
This commit is contained in:
204
diffsynth/models/wan_video_controlnet.py
Normal file
204
diffsynth/models/wan_video_controlnet.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_dit import DiTBlock, precompute_freqs_cis_3d, MLP, sinusoidal_embedding_1d
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
|
||||
class WanControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
in_dim: int,
|
||||
ffn_dim: int,
|
||||
out_dim: int,
|
||||
text_dim: int,
|
||||
freq_dim: int,
|
||||
eps: float,
|
||||
patch_size: Tuple[int, int, int],
|
||||
num_heads: int,
|
||||
num_layers: int,
|
||||
has_image_input: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.freq_dim = freq_dim
|
||||
self.has_image_input = has_image_input
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim),
|
||||
nn.GELU(approximate='tanh'),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
self.blocks = nn.ModuleList([
|
||||
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
head_dim = dim // num_heads
|
||||
self.freqs = precompute_freqs_cis_3d(head_dim)
|
||||
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
|
||||
|
||||
self.controlnet_conv_in = torch.nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
|
||||
self.controlnet_blocks = torch.nn.ModuleList([
|
||||
torch.nn.Linear(dim, dim, bias=False)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||
return rearrange(
|
||||
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
||||
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
||||
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
controlnet_conditioning: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
t = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
|
||||
if self.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = self.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
x = x + self.controlnet_conv_in(controlnet_conditioning)
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
res_stack = []
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
res_stack.append(x)
|
||||
|
||||
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||
return controlnet_res_stack
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanControlNetModelStateDictConverter()
|
||||
|
||||
|
||||
class WanControlNetModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_base_model(self, state_dict):
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 16,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 16,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
state_dict_ = {}
|
||||
dtype, device = None, None
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("head."):
|
||||
continue
|
||||
state_dict_[name] = param
|
||||
dtype, device = param.dtype, param.device
|
||||
for block_id in range(config["num_layers"]):
|
||||
zeros = torch.zeros((config["dim"], config["dim"]), dtype=dtype, device=device)
|
||||
state_dict_[f"controlnet_blocks.{block_id}.weight"] = zeros.clone()
|
||||
state_dict_["controlnet_conv_in.weight"] = torch.zeros((config["in_dim"], config["in_dim"], 1, 1, 1), dtype=dtype, device=device)
|
||||
state_dict_["controlnet_conv_in.bias"] = torch.zeros((config["in_dim"],), dtype=dtype, device=device)
|
||||
return state_dict_, config
|
||||
Reference in New Issue
Block a user