mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
wan-fun-v1.1 reference control
This commit is contained in:
@@ -131,6 +131,8 @@ model_loader_configs = [
|
|||||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||||
|
|||||||
@@ -272,6 +272,7 @@ class WanModel(torch.nn.Module):
|
|||||||
num_layers: int,
|
num_layers: int,
|
||||||
has_image_input: bool,
|
has_image_input: bool,
|
||||||
has_image_pos_emb: bool = False,
|
has_image_pos_emb: bool = False,
|
||||||
|
has_ref_conv: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@@ -303,7 +304,10 @@ class WanModel(torch.nn.Module):
|
|||||||
|
|
||||||
if has_image_input:
|
if has_image_input:
|
||||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
||||||
|
if has_ref_conv:
|
||||||
|
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
|
||||||
self.has_image_pos_emb = has_image_pos_emb
|
self.has_image_pos_emb = has_image_pos_emb
|
||||||
|
self.has_ref_conv = has_ref_conv
|
||||||
|
|
||||||
def patchify(self, x: torch.Tensor):
|
def patchify(self, x: torch.Tensor):
|
||||||
x = self.patch_embedding(x)
|
x = self.patch_embedding(x)
|
||||||
@@ -532,6 +536,7 @@ class WanModelStateDictConverter:
|
|||||||
"eps": 1e-6
|
"eps": 1e-6
|
||||||
}
|
}
|
||||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||||
|
# 1.3B PAI control
|
||||||
config = {
|
config = {
|
||||||
"has_image_input": True,
|
"has_image_input": True,
|
||||||
"patch_size": [1, 2, 2],
|
"patch_size": [1, 2, 2],
|
||||||
@@ -546,6 +551,7 @@ class WanModelStateDictConverter:
|
|||||||
"eps": 1e-6
|
"eps": 1e-6
|
||||||
}
|
}
|
||||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||||
|
# 14B PAI control
|
||||||
config = {
|
config = {
|
||||||
"has_image_input": True,
|
"has_image_input": True,
|
||||||
"patch_size": [1, 2, 2],
|
"patch_size": [1, 2, 2],
|
||||||
@@ -574,6 +580,38 @@ class WanModelStateDictConverter:
|
|||||||
"eps": 1e-6,
|
"eps": 1e-6,
|
||||||
"has_image_pos_emb": True
|
"has_image_pos_emb": True
|
||||||
}
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
|
||||||
|
# 1.3B PAI control v1.1
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 48,
|
||||||
|
"dim": 1536,
|
||||||
|
"ffn_dim": 8960,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 12,
|
||||||
|
"num_layers": 30,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"has_ref_conv": True
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
|
||||||
|
# 14B PAI control v1.1
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 48,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"has_ref_conv": True
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
config = {}
|
config = {}
|
||||||
return state_dict, config
|
return state_dict, config
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
torch.nn.Conv3d: AutoWrappedModule,
|
torch.nn.Conv3d: AutoWrappedModule,
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
RMSNorm: AutoWrappedModule,
|
RMSNorm: AutoWrappedModule,
|
||||||
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
},
|
},
|
||||||
module_config = dict(
|
module_config = dict(
|
||||||
offload_dtype=dtype,
|
offload_dtype=dtype,
|
||||||
@@ -237,6 +238,18 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_reference_image(self, reference_image, height, width):
|
||||||
|
if reference_image is not None:
|
||||||
|
self.load_models_to_device(["vae"])
|
||||||
|
reference_image = reference_image.resize((width, height))
|
||||||
|
reference_image = self.preprocess_images([reference_image])
|
||||||
|
reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
reference_latents = self.vae.encode(reference_image, device=self.device)
|
||||||
|
return {"reference_latents": reference_latents}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
if control_video is not None:
|
if control_video is not None:
|
||||||
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
@@ -339,6 +352,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
end_image=None,
|
end_image=None,
|
||||||
input_video=None,
|
input_video=None,
|
||||||
control_video=None,
|
control_video=None,
|
||||||
|
reference_image=None,
|
||||||
vace_video=None,
|
vace_video=None,
|
||||||
vace_video_mask=None,
|
vace_video_mask=None,
|
||||||
vace_reference_image=None,
|
vace_reference_image=None,
|
||||||
@@ -398,6 +412,9 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
else:
|
else:
|
||||||
image_emb = {}
|
image_emb = {}
|
||||||
|
|
||||||
|
# Reference image
|
||||||
|
reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
|
||||||
|
|
||||||
# ControlNet
|
# ControlNet
|
||||||
if control_video is not None:
|
if control_video is not None:
|
||||||
self.load_models_to_device(["image_encoder", "vae"])
|
self.load_models_to_device(["image_encoder", "vae"])
|
||||||
@@ -435,14 +452,14 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||||
x=latents, timestep=timestep,
|
x=latents, timestep=timestep,
|
||||||
**prompt_emb_posi, **image_emb, **extra_input,
|
**prompt_emb_posi, **image_emb, **extra_input,
|
||||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||||
)
|
)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = model_fn_wan_video(
|
noise_pred_nega = model_fn_wan_video(
|
||||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||||
x=latents, timestep=timestep,
|
x=latents, timestep=timestep,
|
||||||
**prompt_emb_nega, **image_emb, **extra_input,
|
**prompt_emb_nega, **image_emb, **extra_input,
|
||||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||||
)
|
)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
@@ -526,6 +543,7 @@ def model_fn_wan_video(
|
|||||||
context: torch.Tensor = None,
|
context: torch.Tensor = None,
|
||||||
clip_feature: Optional[torch.Tensor] = None,
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
|
reference_latents = None,
|
||||||
vace_context = None,
|
vace_context = None,
|
||||||
vace_scale = 1.0,
|
vace_scale = 1.0,
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
@@ -552,6 +570,12 @@ def model_fn_wan_video(
|
|||||||
|
|
||||||
x, (f, h, w) = dit.patchify(x)
|
x, (f, h, w) = dit.patchify(x)
|
||||||
|
|
||||||
|
# Reference image
|
||||||
|
if reference_latents is not None:
|
||||||
|
reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
|
||||||
|
x = torch.concat([reference_latents, x], dim=1)
|
||||||
|
f += 1
|
||||||
|
|
||||||
freqs = torch.cat([
|
freqs = torch.cat([
|
||||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||||
@@ -580,6 +604,10 @@ def model_fn_wan_video(
|
|||||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
tea_cache.store(x)
|
tea_cache.store(x)
|
||||||
|
|
||||||
|
if reference_latents is not None:
|
||||||
|
x = x[:, reference_latents.shape[1]:]
|
||||||
|
f -= 1
|
||||||
|
|
||||||
x = dit.head(x, t)
|
x = dit.head(x, t)
|
||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
|
|||||||
35
examples/wanvideo/wan_fun_reference_control.py
Normal file
35
examples/wanvideo/wan_fun_reference_control.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||||
|
from modelscope import snapshot_download, dataset_snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
# Download models
|
||||||
|
# snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-V1.1-1.3B-Control")
|
||||||
|
|
||||||
|
# Load models
|
||||||
|
model_manager = ModelManager(device="cpu")
|
||||||
|
model_manager.load_models(
|
||||||
|
[
|
||||||
|
"models/PAI/Wan2.1-Fun-V1.1-14B-Control/diffusion_pytorch_model.safetensors",
|
||||||
|
"models/PAI/Wan2.1-Fun-V1.1-14B-Control/models_t5_umt5-xxl-enc-bf16.pth",
|
||||||
|
"models/PAI/Wan2.1-Fun-V1.1-14B-Control/Wan2.1_VAE.pth",
|
||||||
|
"models/PAI/Wan2.1-Fun-V1.1-14B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
||||||
|
],
|
||||||
|
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||||
|
)
|
||||||
|
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||||
|
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||||
|
|
||||||
|
# Control-to-video
|
||||||
|
control_video = VideoData("xxx/pose.mp4", height=832, width=480)
|
||||||
|
control_video = [control_video[i] for i in range(49)]
|
||||||
|
video = pipe(
|
||||||
|
prompt="一位年轻女性穿着一件粉色的连衣裙,裙子上有白色的装饰和粉色的纽扣。她的头发是紫色的,头上戴着一个红色的大蝴蝶结,显得非常可爱和精致。她还戴着一个红色的领结,整体造型充满了少女感和活力。她的表情温柔,双手轻轻交叉放在身前,姿态优雅。背景是简单的灰色,没有任何多余的装饰,使得人物更加突出。她的妆容清淡自然,突显了她的清新气质。整体画面给人一种甜美、梦幻的感觉,仿佛置身于童话世界中。",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
num_inference_steps=50,
|
||||||
|
reference_image=Image.open("xxx/6.png").convert("RGB").resize((480, 832)),
|
||||||
|
control_video=control_video, height=832, width=480, num_frames=49,
|
||||||
|
seed=1, tiled=True
|
||||||
|
)
|
||||||
|
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||||
Reference in New Issue
Block a user