From 2dad9a319c49c865b0cfb211372c55d3a770403c Mon Sep 17 00:00:00 2001
From: "lzw478614@alibaba-inc.com" <lzw478614@alibaba-inc.com>
Date: Wed, 20 Aug 2025 20:17:41 +0800
Subject: [PATCH] update wan2.2-fun

---
 diffsynth/configs/model_config.py    |  3 ++-
 diffsynth/models/wan_video_dit.py    | 36 ++++++++++++++++++++++++++++
 diffsynth/pipelines/wan_video_new.py | 31 +++++++++++++++++-------
 3 files changed, 60 insertions(+), 10 deletions(-)

diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index 9de5bfb..4731fee 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -150,6 +150,8 @@ model_loader_configs = [
     (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
     (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
     (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
@@ -169,7 +171,6 @@ model_loader_configs = [
     (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
     (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
     (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
-    (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index 419f8cf..ba28dcb 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -713,6 +713,42 @@ class WanModelStateDictConverter:
                 "eps": 1e-6,
                 "require_clip_embedding": False,
             }
+        elif hash_state_dict_keys(state_dict) == "2267d489f0ceb9f21836532952852ee5":
+            # Wan2.2-Fun-A14B-Control
+            config = {
+                "has_image_input": False,
+                "patch_size": [1, 2, 2],
+                "in_dim": 52,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6,
+                "has_ref_conv": True,
+                "require_clip_embedding": False,
+            }
+        elif hash_state_dict_keys(state_dict) == "47dbeab5e560db3180adf51dc0232fb1":
+            # Wan2.2-Fun-A14B-Control-Camera
+            config = {
+                "has_image_input": False,
+                "patch_size": [1, 2, 2],
+                "in_dim": 36,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6,
+                "has_ref_conv": False,
+                "add_control_adapter": True,
+                "in_dim_control_adapter": 24,
+                "require_clip_embedding": False,
+            }
         else:
             config = {}
         return state_dict, config
diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
index 89adbdf..9e83282 100644
--- a/diffsynth/pipelines/wan_video_new.py
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -677,8 +677,11 @@ class WanVideoUnit_FunControl(PipelineUnit):
         if clip_feature is None or y is None:
             clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
             y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
+            if pipe.dit2 is not None:
+                y = torch.zeros((1, 20, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
         else:
-            y = y[:, -16:]
+            if pipe.dit2 is None:
+                y = y[:, -16:]
         y = torch.concat([control_latents, y], dim=1)
         return {"clip_feature": clip_feature, "y": y}
 
@@ -698,6 +701,8 @@ class WanVideoUnit_FunReference(PipelineUnit):
         reference_image = reference_image.resize((width, height))
         reference_latents = pipe.preprocess_video([reference_image])
         reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
+        if pipe.image_encoder is None:
+            return {"reference_latents": reference_latents}
         clip_feature = pipe.preprocess_image(reference_image)
         clip_feature = pipe.image_encoder.encode_image([clip_feature])
         return {"reference_latents": reference_latents, "clip_feature": clip_feature}
@@ -707,13 +712,14 @@ class WanVideoUnit_FunCameraControl(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"),
+            input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"),
             onload_model_names=("vae",)
         )
 
-    def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
+    def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride):
         if camera_control_direction is None:
             return {}
+
         pipe.load_models_to_device(self.onload_model_names)
         camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
             camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
@@ -729,13 +735,20 @@
         control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
         control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
 
-        input_image = input_image.resize((width, height))
-        input_latents = pipe.preprocess_video([input_image])
-        pipe.load_models_to_device(self.onload_model_names)
-        input_latents = pipe.vae.encode(input_latents, device=pipe.device)
-        y = torch.zeros_like(latents).to(pipe.device)
-        y[:, :, :1] = input_latents
+        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+        vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
         y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+
+        if pipe.dit2 is not None:
+            msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+            msk[:, 1:] = 0
+            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+            msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+            msk = msk.transpose(1, 2)[0]
+            y = torch.cat([msk,y])
+            y = y.unsqueeze(0)
+            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
         return {"control_camera_latents_input": control_camera_latents_input, "y": y}
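
Below is a minimal usage sketch (not part of the patch) of how the Wan2.2-Fun-A14B-Control-Camera path added above might be exercised. The keyword arguments mirror the pipeline-unit input_params touched by this patch (input_image, camera_control_direction, camera_control_speed, tiled); the ModelConfig entries, file paths, direction/speed values, and save_video call are assumptions based on the surrounding DiffSynth-Studio API rather than anything this diff confirms.

# Hypothetical sketch -- model paths and loader arguments are placeholders.
import torch
from PIL import Image
from diffsynth import ModelConfig, save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(path="models/wan2.2_fun_a14b_control_camera_high_noise.safetensors"),  # presumably loaded as pipe.dit
        ModelConfig(path="models/wan2.2_fun_a14b_control_camera_low_noise.safetensors"),   # presumably loaded as pipe.dit2
        ModelConfig(path="models/models_t5_umt5-xxl-enc-bf16.pth"),                        # text encoder
        ModelConfig(path="models/Wan2.1_VAE.pth"),                                         # VAE
    ],
)

video = pipe(
    prompt="a car driving along a coastal road at sunset",
    input_image=Image.open("first_frame.png"),
    # Inputs consumed by WanVideoUnit_FunCameraControl in this patch;
    # the direction string and speed are illustrative values.
    camera_control_direction="Left",
    camera_control_speed=0.02,
    num_frames=81, height=480, width=832,
    tiled=True,
)
save_video(video, "video.mp4", fps=16)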