update wan2.2-fun

commit 0d81626fe7 (parent 46bd05b54d)
Author: lzw478614@alibaba-inc.com
Date: 2025-08-21 20:08:49 +08:00

2 changed files with 18 additions and 13 deletions

@@ -294,6 +294,7 @@ class WanModel(torch.nn.Module):
     ):
         super().__init__()
         self.dim = dim
+        self.in_dim = in_dim
         self.freq_dim = freq_dim
         self.has_image_input = has_image_input
         self.patch_size = patch_size
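Note: exposing self.in_dim lets the pipeline derive conditioning channel widths from the model itself instead of hard-coding them per variant. A minimal sketch of the invariant the pipeline relies on; the channel counts below are illustrative assumptions, not values taken from this commit:

# Illustrative channel arithmetic (assumed numbers): the DiT's input
# channels are the denoised latent channels plus all conditioning
# channels concatenated onto them.
latent_channels = 16    # VAE latent channels (assumed)
control_channels = 16   # Fun-Control latents (assumed)
y_channels = 16         # first-frame / mask conditioning (assumed)

in_dim = latent_channels + control_channels + y_channels  # what the DiT accepts
y_dim = in_dim - control_channels - latent_channels       # what the pipeline recovers
assert y_dim == y_channels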

@@ -663,25 +663,23 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
 class WanVideoUnit_FunControl(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
+            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"),
             onload_model_names=("vae",)
         )

-    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
+    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):
         if control_video is None:
             return {}
         pipe.load_models_to_device(self.onload_model_names)
         control_video = pipe.preprocess_video(control_video)
         control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
         control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
+        y_dim = pipe.dit.in_dim - control_latents.shape[1] - latents.shape[1]
         if clip_feature is None or y is None:
             clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
-            y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
-            if pipe.dit2 is not None:
-                y = torch.zeros((1, 20, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
+            y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
         else:
-            if pipe.dit2 is None:
-                y = y[:, -16:]
+            y = y[:, -y_dim:]
         y = torch.concat([control_latents, y], dim=1)
         return {"clip_feature": clip_feature, "y": y}
@@ -735,12 +733,18 @@ class WanVideoUnit_FunCameraControl(PipelineUnit):
         control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
         control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)

-        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
-        vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
-        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+        input_image = input_image.resize((width, height))
+        input_latents = pipe.preprocess_video([input_image])
+        input_latents = pipe.vae.encode(input_latents, device=pipe.device)
+        y = torch.zeros_like(latents).to(pipe.device)
+        y[:, :, :1] = input_latents
         y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
-        if pipe.dit2 is not None:
+        if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:
+            image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+            y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
             msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
             msk[:, 1:] = 0
             msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
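Note: the camera-control unit now builds y by writing the VAE-encoded first frame into a zero tensor shaped like the denoised latents, and only falls back to the old path (encoding the image padded with zero frames, plus the mask) when the resulting channel count does not match pipe.dit.in_dim - latents.shape[1]. A minimal sketch of the fast path; the helper name first_frame_y and the tensor shapes are illustrative assumptions:

import torch

def first_frame_y(latents, first_frame_latents):
    # Zero conditioning everywhere except the first latent frame, which
    # receives the encoded input image.
    y = torch.zeros_like(latents)
    y[:, :, :1] = first_frame_latents
    return y

latents = torch.zeros(1, 16, 21, 60, 104)             # denoised latents (assumed shape)
first_frame_latents = torch.zeros(1, 16, 1, 60, 104)  # encoded input image (assumed shape)
y = first_frame_y(latents, first_frame_latents)
print(y.shape)  # torch.Size([1, 16, 21, 60, 104])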