From 0d81626fe78a311bf44d8dcafe80ed56fb6c6aa7 Mon Sep 17 00:00:00 2001
From: "lzw478614@alibaba-inc.com" <lzw478614@alibaba-inc.com>
Date: Thu, 21 Aug 2025 20:08:49 +0800
Subject: [PATCH] update wan2.2-fun

---
 diffsynth/models/wan_video_dit.py    |  1 +
 diffsynth/pipelines/wan_video_new.py | 30 ++++++++++++++++------------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index ba28dcb..9420736 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -294,6 +294,7 @@ class WanModel(torch.nn.Module):
     ):
         super().__init__()
         self.dim = dim
+        self.in_dim = in_dim
         self.freq_dim = freq_dim
         self.has_image_input = has_image_input
         self.patch_size = patch_size
diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
index 9e83282..53df7d9 100644
--- a/diffsynth/pipelines/wan_video_new.py
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -663,25 +663,23 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
 class WanVideoUnit_FunControl(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
+            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"),
             onload_model_names=("vae",)
         )
 
-    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
+    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):
         if control_video is None:
             return {}
         pipe.load_models_to_device(self.onload_model_names)
         control_video = pipe.preprocess_video(control_video)
         control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
         control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
+        y_dim = pipe.dit.in_dim - control_latents.shape[1] - latents.shape[1]
         if clip_feature is None or y is None:
             clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
-            y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
-            if pipe.dit2 is not None:
-                y = torch.zeros((1, 20, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
+            y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
         else:
-            if pipe.dit2 is None:
-                y = y[:, -16:]
+            y = y[:, -y_dim:]
         y = torch.concat([control_latents, y], dim=1)
         return {"clip_feature": clip_feature, "y": y}
 
@@ -734,13 +732,19 @@ class WanVideoUnit_FunCameraControl(PipelineUnit):
         control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
         control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
         control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
-
-        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
-        vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
-        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+
+        input_image = input_image.resize((width, height))
+        input_latents = pipe.preprocess_video([input_image])
+        input_latents = pipe.vae.encode(input_latents, device=pipe.device)
+        y = torch.zeros_like(latents).to(pipe.device)
+        y[:, :, :1] = input_latents
         y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
-        if pipe.dit2 is not None:
+        if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:
+            image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+            y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
             msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
             msk[:, 1:] = 0
             msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
@@ -1061,7 +1065,7 @@ def model_fn_wan_video(
     if clip_feature is not None and dit.require_clip_embedding:
         clip_embdding = dit.img_emb(clip_feature)
         context = torch.cat([clip_embdding, context], dim=1)
-    
+
     # Add camera control
     x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
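
Note on the change, with a minimal sketch. The core of this patch is replacing the hard-coded conditioning widths (16 channels with a single DiT, 20 when `pipe.dit2` is present) with a `y_dim` derived from the model's declared input width, which the first hunk records as `WanModel.in_dim`. The sketch below illustrates that arithmetic only; the concrete channel counts (`in_dim=48`, 16-channel latents, 16-channel control latents) are assumptions for illustration, not values taken from the repository.

```python
# Sketch of the y_dim arithmetic introduced above. All channel counts here
# are illustrative assumptions, not values read from the DiffSynth models.
import torch

in_dim = 48                                          # channels the DiT patch embedding expects (assumed)
latents = torch.zeros((1, 16, 21, 60, 104))          # noisy video latents (B, C, T, H/8, W/8)
control_latents = torch.zeros((1, 16, 21, 60, 104))  # VAE-encoded control video

# Whatever in_dim does not account for after the noisy latents and the
# control latents is the width of the conditioning tensor y.
y_dim = in_dim - control_latents.shape[1] - latents.shape[1]
y = torch.zeros((1, y_dim, 21, 60, 104))

# As in the patch, y is concatenated behind the control latents before
# being handed to the DiT alongside the noisy latents.
y = torch.concat([control_latents, y], dim=1)
assert latents.shape[1] + y.shape[1] == in_dim
```

Because `y_dim` now follows `in_dim`, the `pipe.dit2` branches disappear: any Fun variant with a different conditioning width is handled by the same code path as long as its `in_dim` is set correctly.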