mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
wan-refactor
This commit is contained in:
@@ -10,6 +10,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
@@ -208,9 +209,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_PromptEmbedder(),
|
||||
WanVideoUnit_ImageEmbedder(),
|
||||
WanVideoUnit_FunCamera(),
|
||||
WanVideoUnit_FunControl(),
|
||||
WanVideoUnit_FunReference(),
|
||||
WanVideoUnit_FunCameraControl(),
|
||||
WanVideoUnit_SpeedControl(),
|
||||
WanVideoUnit_VACE(),
|
||||
WanVideoUnit_UnifiedSequenceParallel(),
|
||||
@@ -472,6 +473,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
# ControlNet
|
||||
control_video: Optional[list[Image.Image]] = None,
|
||||
reference_image: Optional[Image.Image] = None,
|
||||
# Camera control
|
||||
camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
|
||||
camera_control_speed: Optional[float] = 1/54,
|
||||
camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
|
||||
# VACE
|
||||
vace_video: Optional[list[Image.Image]] = None,
|
||||
vace_video_mask: Optional[Image.Image] = None,
|
||||
@@ -504,8 +509,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
tea_cache_model_id: Optional[str] = "",
|
||||
# progress_bar
|
||||
progress_bar_cmd=tqdm,
|
||||
# Camera control
|
||||
control_camera_video: Optional[torch.Tensor] = None
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
@@ -524,7 +527,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
"end_image": end_image,
|
||||
"input_video": input_video, "denoising_strength": denoising_strength,
|
||||
"control_video": control_video, "reference_image": reference_image,
|
||||
"control_camera_video": control_camera_video,
|
||||
"camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin,
|
||||
"vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"height": height, "width": width, "num_frames": num_frames,
|
||||
@@ -724,37 +727,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
class WanVideoUnit_FunCamera(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("control_camera_video", "cfg_merge", "num_frames", "height", "width", "input_image", "latents"),
|
||||
onload_model_names=("vae")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, control_camera_video, cfg_merge, num_frames, height, width, input_image, latents):
|
||||
if control_camera_video is None:
|
||||
return {}
|
||||
control_camera_video = control_camera_video[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
|
||||
control_camera_latents = torch.concat(
|
||||
[
|
||||
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
||||
control_camera_video[:, :, 1:]
|
||||
], dim=2
|
||||
).transpose(1, 2)
|
||||
b, f, c, h, w = control_camera_latents.shape
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
|
||||
input_image = input_image.resize((width, height))
|
||||
input_latents = pipe.preprocess_video([input_image])
|
||||
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
||||
y = torch.zeros_like(latents).to(pipe.device)
|
||||
if latents.size()[2] != 1:
|
||||
y[:, :, :1] = input_latents
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
return {"control_camera_latents": control_camera_latents, "control_camera_latents_input": control_camera_latents_input, "y":y}
|
||||
|
||||
|
||||
class WanVideoUnit_FunControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
@@ -800,6 +773,40 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_FunCameraControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
|
||||
if camera_control_direction is None:
|
||||
return {}
|
||||
camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
|
||||
camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
|
||||
|
||||
control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
|
||||
control_camera_latents = torch.concat(
|
||||
[
|
||||
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
||||
control_camera_video[:, :, 1:]
|
||||
], dim=2
|
||||
).transpose(1, 2)
|
||||
b, f, c, h, w = control_camera_latents.shape
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
|
||||
input_image = input_image.resize((width, height))
|
||||
input_latents = pipe.preprocess_video([input_image])
|
||||
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
||||
y = torch.zeros_like(latents).to(pipe.device)
|
||||
y[:, :, :1] = input_latents
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"control_camera_latents_input": control_camera_latents_input, "y": y}
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_SpeedControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("motion_bucket_id",))
|
||||
|
||||
Reference in New Issue
Block a user