This commit is contained in:
CD22104
2025-06-11 17:24:09 +08:00
parent 6e977e1181
commit b1afff1728
133 changed files with 954 additions and 9 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -133,6 +133,8 @@ model_loader_configs = [
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -5,6 +5,10 @@ import math
from typing import Tuple, Optional
from einops import rearrange
from .utils import hash_state_dict_keys
from dchen.camera_adapter import SimpleAdapter
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
@@ -273,6 +277,8 @@ class WanModel(torch.nn.Module):
has_image_input: bool,
has_image_pos_emb: bool = False,
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
):
super().__init__()
self.dim = dim
@@ -309,8 +315,17 @@ class WanModel(torch.nn.Module):
self.has_image_pos_emb = has_image_pos_emb
self.has_ref_conv = has_ref_conv
def patchify(self, x: torch.Tensor):
if add_control_adapter:
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
else:
self.control_adapter = None
def patchify(self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None):
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)
x = [u + v for u, v in zip(x, y_camera)]
x = x[0].unsqueeze(0)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
@@ -612,6 +627,42 @@ class WanModelStateDictConverter:
"eps": 1e-6,
"has_ref_conv": True
}
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
# 1.3B PAI control-camera v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 32,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
"has_ref_conv": False,
"add_control_adapter": True,
"in_dim_control_adapter": 24,
}
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
# 14B PAI control-camera v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 32,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_ref_conv": False,
"add_control_adapter": True,
"in_dim_control_adapter": 24,
}
else:
config = {}
return state_dict, config

Binary file not shown.

Binary file not shown.

View File

@@ -24,7 +24,6 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
from ..lora import GeneralLoRALoader
class BasePipeline(torch.nn.Module):
def __init__(
@@ -208,6 +207,7 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
WanVideoUnit_FunCamera(),
WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(),
WanVideoUnit_SpeedControl(),
@@ -473,6 +473,8 @@ class WanVideoPipeline(BasePipeline):
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
# Camera control
control_camera_video: Optional[torch.Tensor] = None
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -491,6 +493,7 @@ class WanVideoPipeline(BasePipeline):
"end_image": end_image,
"input_video": input_video, "denoising_strength": denoising_strength,
"control_video": control_video, "reference_image": reference_image,
"control_camera_video": control_camera_video,
"vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
@@ -653,15 +656,17 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit):
class WanVideoUnit_ImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "control_camera_video","latents"),
onload_model_names=("image_encoder", "vae")
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride, control_camera_video,latents):
if input_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
clip_context = pipe.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, 1:] = 0
@@ -673,14 +678,13 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
@@ -730,6 +734,37 @@ class WanVideoUnit_FunReference(PipelineUnit):
clip_feature = pipe.image_encoder.encode_image([clip_feature])
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
class WanVideoUnit_FunCamera(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("control_camera_video", "cfg_merge", "num_frames", "height", "width", "input_image", "latents"),
onload_model_names=("vae")
)
def process(self, pipe: WanVideoPipeline, control_camera_video, cfg_merge, num_frames, height, width, input_image, latents):
if control_camera_video is None:
return {}
control_camera_video = control_camera_video[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
control_camera_latents = torch.concat(
[
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
control_camera_video[:, :, 1:]
], dim=2
).transpose(1, 2)
b, f, c, h, w = control_camera_latents.shape
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
input_image = input_image.resize((width, height))
input_latents = pipe.preprocess_video([input_image])
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
y = torch.zeros_like(latents).to(pipe.device)
if latents.size()[2] != 1:
y[:, :, :1] = input_latents
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"control_camera_latents": control_camera_latents, "control_camera_latents_input": control_camera_latents_input, "y":y}
class WanVideoUnit_SpeedControl(PipelineUnit):
@@ -954,6 +989,8 @@ def model_fn_wan_video(
cfg_merge: bool = False,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
control_camera_latents = None,
control_camera_latents_input = None,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1000,13 +1037,14 @@ def model_fn_wan_video(
x = torch.concat([x] * context.shape[0], dim=0)
if timestep.shape[0] != context.shape[0]:
timestep = torch.concat([timestep] * context.shape[0], dim=0)
if dit.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = dit.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
x, (f, h, w) = dit.patchify(x)
# Add camera control
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
# Reference image
if reference_latents is not None:

Binary file not shown.

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More