mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support wan2.2 A14B I2V&T2V
This commit is contained in:
@@ -141,6 +141,8 @@ model_loader_configs = [
|
||||
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
|
||||
@@ -352,6 +352,7 @@ class WanModel(torch.nn.Module):
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
fused_y: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
@@ -365,6 +366,8 @@ class WanModel(torch.nn.Module):
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = self.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
if fused_y is not None:
|
||||
x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
|
||||
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
|
||||
@@ -673,6 +676,7 @@ class WanModelStateDictConverter:
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
|
||||
# Wan-AI/Wan2.2-TI2V-5B
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
@@ -687,6 +691,21 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6,
|
||||
"seperated_timestep": True,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
|
||||
# Wan-AI/Wan2.2-I2V-A14B
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
@@ -226,10 +226,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.text_encoder: WanTextEncoder = None
|
||||
self.image_encoder: WanImageEncoder = None
|
||||
self.dit: WanModel = None
|
||||
self.dit2: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.vace: VaceWanModel = None
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace")
|
||||
self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace")
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
@@ -238,6 +239,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoUnit_PromptEmbedder(),
|
||||
WanVideoUnit_ImageEmbedder(),
|
||||
WanVideoUnit_ImageVaeEmbedder(),
|
||||
WanVideoUnit_ImageEmbedderNoClip(),
|
||||
WanVideoUnit_FunControl(),
|
||||
WanVideoUnit_FunReference(),
|
||||
WanVideoUnit_FunCameraControl(),
|
||||
@@ -329,6 +331,37 @@ class WanVideoPipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.dit2 is not None:
|
||||
dtype = next(iter(self.dit2.parameters())).dtype
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
enable_vram_management(
|
||||
self.dit2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae is not None:
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
enable_vram_management(
|
||||
@@ -427,6 +460,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
for block in self.dit.blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
|
||||
if self.dit2 is not None:
|
||||
for block in self.dit2.blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
||||
self.sp_size = get_sequence_parallel_world_size()
|
||||
self.use_unified_sequence_parallel = True
|
||||
|
||||
@@ -473,6 +510,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Load models
|
||||
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
||||
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
||||
num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"])
|
||||
if num_dits == 2:
|
||||
pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1]
|
||||
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
@@ -523,6 +563,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Classifier-free guidance
|
||||
cfg_scale: Optional[float] = 5.0,
|
||||
cfg_merge: Optional[bool] = False,
|
||||
# Boundary
|
||||
boundary: Optional[float] = 0.875,
|
||||
# Scheduler
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
sigma_shift: Optional[float] = 5.0,
|
||||
@@ -575,8 +617,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
# switch high_noise DiT to low_noise DiT
|
||||
if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps:
|
||||
print("switching to low noise DiT")
|
||||
self.load_models_to_device(["dit2", "motion_controller", "vace"])
|
||||
models["dit"] = models.pop("dit2")
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
|
||||
if cfg_scale != 1.0:
|
||||
@@ -737,7 +783,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
|
||||
if input_image is None or pipe.dit.seperated_timestep:
|
||||
if input_image is None or pipe.image_encoder is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||
@@ -767,6 +813,9 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
||||
|
||||
|
||||
class WanVideoUnit_ImageVaeEmbedder(PipelineUnit):
|
||||
"""
|
||||
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
@@ -815,6 +864,42 @@ class WanVideoUnit_ImageVaeEmbedder(PipelineUnit):
|
||||
return out1, out2
|
||||
|
||||
|
||||
class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit):
|
||||
"""
|
||||
Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
|
||||
if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
|
||||
msk[:, 1:] = 0
|
||||
if end_image is not None:
|
||||
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
|
||||
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
|
||||
msk[:, -1:] = 1
|
||||
else:
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"fused_y": y}
|
||||
|
||||
|
||||
class WanVideoUnit_FunControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -1116,6 +1201,7 @@ def model_fn_wan_video(
|
||||
context: torch.Tensor = None,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
fused_y: Optional[torch.Tensor] = None,
|
||||
reference_latents = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
@@ -1181,11 +1267,13 @@ def model_fn_wan_video(
|
||||
x = torch.concat([x] * context.shape[0], dim=0)
|
||||
if timestep.shape[0] != context.shape[0]:
|
||||
timestep = torch.concat([timestep] * context.shape[0], dim=0)
|
||||
|
||||
|
||||
if dit.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
if fused_y is not None:
|
||||
x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
|
||||
|
||||
# Add camera control
|
||||
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
|
||||
|
||||
32
examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py
Normal file
32
examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480))
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
27
examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py
Normal file
27
examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
|
||||
snapshot_download("Wan-AI/Wan2.2-T2V-A14B", local_dir="models/Wan-AI/Wan2.2-T2V-A14B")
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
@@ -1,17 +1,15 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
from diffsynth.models.utils import load_state_dict, hash_state_dict_keys
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="model_shards/model-*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user