controlnet

This commit is contained in:
Artiprocher
2025-03-21 11:09:56 +08:00
parent 6cd032e846
commit 105eaf0f49
4 changed files with 915 additions and 12 deletions

View File

@@ -0,0 +1,204 @@
import torch
import torch.nn as nn
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_dit import DiTBlock, precompute_freqs_cis_3d, MLP, sinusoidal_embedding_1d
from .utils import hash_state_dict_keys
class WanControlNetModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
has_image_input: bool,
):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.has_image_input = has_image_input
self.patch_size = patch_size
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(approximate='tanh'),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(), nn.Linear(dim, dim * 6))
self.blocks = nn.ModuleList([
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
for _ in range(num_layers)
])
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
self.controlnet_conv_in = torch.nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Linear(dim, dim, bias=False)
for _ in range(num_layers)
])
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
f=grid_size[0], h=grid_size[1], w=grid_size[2],
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
)
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
controlnet_conditioning: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
x = x + self.controlnet_conv_in(controlnet_conditioning)
x, (f, h, w) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
res_stack = []
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
res_stack.append(x)
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
return controlnet_res_stack
@staticmethod
def state_dict_converter():
return WanControlNetModelStateDictConverter()
class WanControlNetModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict
def from_base_model(self, state_dict):
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 16,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
else:
config = {}
state_dict_ = {}
dtype, device = None, None
for name, param in state_dict.items():
if name.startswith("head."):
continue
state_dict_[name] = param
dtype, device = param.dtype, param.device
for block_id in range(config["num_layers"]):
zeros = torch.zeros((config["dim"], config["dim"]), dtype=dtype, device=device)
state_dict_[f"controlnet_blocks.{block_id}.weight"] = zeros.clone()
state_dict_["controlnet_conv_in.weight"] = torch.zeros((config["in_dim"], config["in_dim"], 1, 1, 1), dtype=dtype, device=device)
state_dict_["controlnet_conv_in.bias"] = torch.zeros((config["in_dim"],), dtype=dtype, device=device)
return state_dict_, config

View File

@@ -17,6 +17,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_controlnet import WanControlNetModel
@@ -30,7 +31,8 @@ class WanVideoPipeline(BasePipeline):
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['text_encoder', 'dit', 'vae']
self.controlnet: WanControlNetModel = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet']
self.height_division_factor = 16
self.width_division_factor = 16
@@ -189,6 +191,11 @@ class WanVideoPipeline(BasePipeline):
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def prepare_controlnet(self, controlnet_frames, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
controlnet_conditioning = self.encode_video(controlnet_frames, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
return {"controlnet_conditioning": controlnet_conditioning}
@torch.no_grad()
@@ -212,6 +219,7 @@ class WanVideoPipeline(BasePipeline):
tile_stride=(15, 26),
tea_cache_l1_thresh=None,
tea_cache_model_id="",
controlnet_frames=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -252,6 +260,15 @@ class WanVideoPipeline(BasePipeline):
else:
image_emb = {}
# ControlNet
if self.controlnet is not None and controlnet_frames is not None:
self.load_models_to_device(['vae', 'controlnet'])
controlnet_frames = self.preprocess_images(controlnet_frames)
controlnet_frames = torch.stack(controlnet_frames, dim=2).to(dtype=self.torch_dtype, device=self.device)
controlnet_kwargs = self.prepare_controlnet(controlnet_frames)
else:
controlnet_kwargs = {}
# Extra input
extra_input = self.prepare_extra_input(latents)
@@ -260,14 +277,24 @@ class WanVideoPipeline(BasePipeline):
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device(["dit"])
self.load_models_to_device(["dit", "controlnet"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi)
noise_pred_posi = model_fn_wan_video(
self.dit, controlnet=self.controlnet,
x=latents, timestep=timestep,
**prompt_emb_posi, **image_emb, **extra_input,
**tea_cache_posi, **controlnet_kwargs
)
if cfg_scale != 1.0:
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega)
noise_pred_nega = model_fn_wan_video(
self.dit, controlnet=self.controlnet,
x=latents, timestep=timestep,
**prompt_emb_nega, **image_emb, **extra_input,
**tea_cache_nega, **controlnet_kwargs
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -340,14 +367,29 @@ class TeaCache:
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
controlnet: WanControlNetModel = None,
x: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None,
controlnet_conditioning: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
# ControlNet
if controlnet is not None and controlnet_conditioning is not None:
controlnet_res_stack = controlnet(
x, timestep=timestep, context=context, clip_feature=clip_feature, y=y,
controlnet_conditioning=controlnet_conditioning,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
else:
controlnet_res_stack = None
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
@@ -370,13 +412,35 @@ def model_fn_wan_video(
tea_cache_update = tea_cache.check(dit, x, t_mod)
else:
tea_cache_update = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tea_cache_update:
x = tea_cache.update(x)
else:
# blocks
for block in dit.blocks:
x = block(x, context, t_mod, freqs)
for block_id, block in enumerate(dit.blocks):
if dit.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
if controlnet_res_stack is not None:
x = x + controlnet_res_stack[block_id]
if tea_cache is not None:
tea_cache.store(x)