wan-refactor

This commit is contained in:
Artiprocher
2025-06-13 13:04:35 +08:00
parent 7e6a3c7897
commit 8dd24169cc
47 changed files with 810 additions and 310 deletions

View File

@@ -2,8 +2,9 @@ import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from packaging import version as pver
import os
from typing_extensions import Literal
class SimpleAdapter(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
super(SimpleAdapter, self).__init__()
@@ -42,6 +43,22 @@ class SimpleAdapter(nn.Module):
return out
def process_camera_coordinates(
self,
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
length: int,
height: int,
width: int,
speed: float = 1/54,
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
):
if origin is None:
origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
coordinates = generate_camera_coordinates(direction, length, speed, origin)
plucker_embedding = process_pose_file(coordinates, width, height)
return plucker_embedding
class ResidualBlock(nn.Module):
def __init__(self, dim):
@@ -90,13 +107,8 @@ def get_relative_pose(cam_params):
return ret_poses
def custom_meshgrid(*args):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
if pver.parse(torch.__version__) < pver.parse('1.10'):
return torch.meshgrid(*args)
else:
return torch.meshgrid(*args, indexing='ij')
# torch>=2.0.0 only
return torch.meshgrid(*args, indexing='ij')
def ray_condition(K, c2w, H, W, device):
@@ -128,23 +140,14 @@ def ray_condition(K, c2w, H, W, device):
rays_o = c2w[..., :3, 3] # B, V, 3
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
# c2w @ dirctions
rays_dxo = torch.cross(rays_o, rays_d)
rays_dxo = torch.linalg.cross(rays_o, rays_d)
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
# plucker = plucker.permute(0, 1, 4, 2, 3)
return plucker
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
if os.path.isfile(pose_file_path):
with open(pose_file_path, 'r') as f:
poses = f.readlines()
else:
poses = pose_file_path.splitlines()
poses = [pose.strip().split(' ') for pose in poses[1:]]
cam_params = [[float(x) for x in pose] for pose in poses]
def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
if return_poses:
return cam_params
else:
@@ -175,3 +178,25 @@ def process_pose_file(pose_file_path, width=672, height=384, original_pose_width
plucker_embedding = plucker_embedding[None]
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
return plucker_embedding
def generate_camera_coordinates(
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
length: int,
speed: float = 1/54,
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
):
coordinates = [list(origin)]
while len(coordinates) < length:
coor = coordinates[-1].copy()
if "Left" in direction:
coor[9] += speed
if "Right" in direction:
coor[9] -= speed
if "Up" in direction:
coor[13] += speed
if "Down" in direction:
coor[13] -= speed
coordinates.append(coor)
return coordinates

View File

@@ -10,6 +10,7 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from typing_extensions import Literal
from ..models import ModelManager, load_state_dict
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
@@ -208,9 +209,9 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
WanVideoUnit_FunCamera(),
WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(),
WanVideoUnit_FunCameraControl(),
WanVideoUnit_SpeedControl(),
WanVideoUnit_VACE(),
WanVideoUnit_UnifiedSequenceParallel(),
@@ -472,6 +473,10 @@ class WanVideoPipeline(BasePipeline):
# ControlNet
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
# Camera control
camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
camera_control_speed: Optional[float] = 1/54,
camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
# VACE
vace_video: Optional[list[Image.Image]] = None,
vace_video_mask: Optional[Image.Image] = None,
@@ -504,8 +509,6 @@ class WanVideoPipeline(BasePipeline):
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
# Camera control
control_camera_video: Optional[torch.Tensor] = None
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -524,7 +527,7 @@ class WanVideoPipeline(BasePipeline):
"end_image": end_image,
"input_video": input_video, "denoising_strength": denoising_strength,
"control_video": control_video, "reference_image": reference_image,
"control_camera_video": control_camera_video,
"camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin,
"vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
@@ -724,37 +727,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context, "y": y}
class WanVideoUnit_FunCamera(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("control_camera_video", "cfg_merge", "num_frames", "height", "width", "input_image", "latents"),
onload_model_names=("vae")
)
def process(self, pipe: WanVideoPipeline, control_camera_video, cfg_merge, num_frames, height, width, input_image, latents):
if control_camera_video is None:
return {}
control_camera_video = control_camera_video[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
control_camera_latents = torch.concat(
[
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
control_camera_video[:, :, 1:]
], dim=2
).transpose(1, 2)
b, f, c, h, w = control_camera_latents.shape
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
input_image = input_image.resize((width, height))
input_latents = pipe.preprocess_video([input_image])
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
y = torch.zeros_like(latents).to(pipe.device)
if latents.size()[2] != 1:
y[:, :, :1] = input_latents
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"control_camera_latents": control_camera_latents, "control_camera_latents_input": control_camera_latents_input, "y":y}
class WanVideoUnit_FunControl(PipelineUnit):
def __init__(self):
@@ -800,6 +773,40 @@ class WanVideoUnit_FunReference(PipelineUnit):
class WanVideoUnit_FunCameraControl(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image")
)
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
if camera_control_direction is None:
return {}
camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
control_camera_latents = torch.concat(
[
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
control_camera_video[:, :, 1:]
], dim=2
).transpose(1, 2)
b, f, c, h, w = control_camera_latents.shape
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
input_image = input_image.resize((width, height))
input_latents = pipe.preprocess_video([input_image])
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
y = torch.zeros_like(latents).to(pipe.device)
y[:, :, :1] = input_latents
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"control_camera_latents_input": control_camera_latents_input, "y": y}
class WanVideoUnit_SpeedControl(PipelineUnit):
def __init__(self):
super().__init__(input_params=("motion_bucket_id",))

View File

@@ -228,30 +228,24 @@ def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_pat
def wan_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the Dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the Dataset.")
parser.add_argument("--height", type=int, default=None, help="Image or video height. Leave `height` and `width` None to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Image or video width. Leave `height` and `width` None to enable dynamic resolution.")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in each video. The frames are sampled from the prefix.")
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.")
parser.add_argument("--model_paths", type=str, default=None, help="Model paths to be loaded. JSON format.")
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separated by commas.")
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Path to the metadata file of the dataset.")
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--output_path", type=str, default="./models", help="Save path.")
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
parser.add_argument("--trainable_models", type=str, default=None, help="Trainable models, e.g., dit, vae, text_encoder.")
parser.add_argument("--lora_base_model", type=str, default=None, help="Add LoRA on which model.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Add LoRA on which layer.")
parser.add_argument("--lora_rank", type=int, default=32, help="LoRA rank.")
parser.add_argument("--input_contains_input_image", default=False, action="store_true", help="Model input contains 'input_image'.")
parser.add_argument("--input_contains_end_image", default=False, action="store_true", help="Model input contains 'end_image'.")
parser.add_argument("--input_contains_control_video", default=False, action="store_true", help="Model input contains 'control_video'.")
parser.add_argument("--input_contains_reference_image", default=False, action="store_true", help="Model input contains 'reference_image'.")
parser.add_argument("--input_contains_vace_video", default=False, action="store_true", help="Model input contains 'vace_video'.")
parser.add_argument("--input_contains_vace_reference_image", default=False, action="store_true", help="Model input contains 'vace_reference_image'.")
parser.add_argument("--input_contains_motion_bucket_id", default=False, action="store_true", help="Model input contains 'motion_bucket_id'.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Offload gradient checkpointing to RAM.")
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
return parser