This commit is contained in:
CD22104
2025-06-11 17:24:09 +08:00
parent 6e977e1181
commit b1afff1728
133 changed files with 954 additions and 9 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -24,7 +24,6 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
from ..lora import GeneralLoRALoader
class BasePipeline(torch.nn.Module):
def __init__(
@@ -208,6 +207,7 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
WanVideoUnit_FunCamera(),
WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(),
WanVideoUnit_SpeedControl(),
@@ -473,6 +473,8 @@ class WanVideoPipeline(BasePipeline):
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
# Camera control
control_camera_video: Optional[torch.Tensor] = None
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -491,6 +493,7 @@ class WanVideoPipeline(BasePipeline):
"end_image": end_image,
"input_video": input_video, "denoising_strength": denoising_strength,
"control_video": control_video, "reference_image": reference_image,
"control_camera_video": control_camera_video,
"vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
@@ -653,15 +656,17 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit):
class WanVideoUnit_ImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "control_camera_video","latents"),
onload_model_names=("image_encoder", "vae")
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride, control_camera_video,latents):
if input_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
clip_context = pipe.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, 1:] = 0
@@ -673,14 +678,13 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
@@ -730,6 +734,37 @@ class WanVideoUnit_FunReference(PipelineUnit):
clip_feature = pipe.image_encoder.encode_image([clip_feature])
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
class WanVideoUnit_FunCamera(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("control_camera_video", "cfg_merge", "num_frames", "height", "width", "input_image", "latents"),
onload_model_names=("vae")
)
def process(self, pipe: WanVideoPipeline, control_camera_video, cfg_merge, num_frames, height, width, input_image, latents):
if control_camera_video is None:
return {}
control_camera_video = control_camera_video[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
control_camera_latents = torch.concat(
[
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
control_camera_video[:, :, 1:]
], dim=2
).transpose(1, 2)
b, f, c, h, w = control_camera_latents.shape
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
input_image = input_image.resize((width, height))
input_latents = pipe.preprocess_video([input_image])
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
y = torch.zeros_like(latents).to(pipe.device)
if latents.size()[2] != 1:
y[:, :, :1] = input_latents
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"control_camera_latents": control_camera_latents, "control_camera_latents_input": control_camera_latents_input, "y":y}
class WanVideoUnit_SpeedControl(PipelineUnit):
@@ -954,6 +989,8 @@ def model_fn_wan_video(
cfg_merge: bool = False,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
control_camera_latents = None,
control_camera_latents_input = None,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1000,13 +1037,14 @@ def model_fn_wan_video(
x = torch.concat([x] * context.shape[0], dim=0)
if timestep.shape[0] != context.shape[0]:
timestep = torch.concat([timestep] * context.shape[0], dim=0)
if dit.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = dit.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
x, (f, h, w) = dit.patchify(x)
# Add camera control
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
# Reference image
if reference_latents is not None: