mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
wan-refactor
This commit is contained in:
@@ -2,8 +2,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from packaging import version as pver
|
||||
import os
|
||||
from typing_extensions import Literal
|
||||
|
||||
class SimpleAdapter(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
|
||||
super(SimpleAdapter, self).__init__()
|
||||
@@ -42,6 +43,22 @@ class SimpleAdapter(nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
def process_camera_coordinates(
|
||||
self,
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||
length: int,
|
||||
height: int,
|
||||
width: int,
|
||||
speed: float = 1/54,
|
||||
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
):
|
||||
if origin is None:
|
||||
origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
coordinates = generate_camera_coordinates(direction, length, speed, origin)
|
||||
plucker_embedding = process_pose_file(coordinates, width, height)
|
||||
return plucker_embedding
|
||||
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
@@ -90,13 +107,8 @@ def get_relative_pose(cam_params):
|
||||
return ret_poses
|
||||
|
||||
def custom_meshgrid(*args):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
||||
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
||||
return torch.meshgrid(*args)
|
||||
else:
|
||||
return torch.meshgrid(*args, indexing='ij')
|
||||
# torch>=2.0.0 only
|
||||
return torch.meshgrid(*args, indexing='ij')
|
||||
|
||||
|
||||
def ray_condition(K, c2w, H, W, device):
|
||||
@@ -128,23 +140,14 @@ def ray_condition(K, c2w, H, W, device):
|
||||
rays_o = c2w[..., :3, 3] # B, V, 3
|
||||
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
||||
# c2w @ dirctions
|
||||
rays_dxo = torch.cross(rays_o, rays_d)
|
||||
rays_dxo = torch.linalg.cross(rays_o, rays_d)
|
||||
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
||||
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
||||
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
||||
return plucker
|
||||
|
||||
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
||||
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
if os.path.isfile(pose_file_path):
|
||||
with open(pose_file_path, 'r') as f:
|
||||
poses = f.readlines()
|
||||
else:
|
||||
poses = pose_file_path.splitlines()
|
||||
|
||||
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
||||
cam_params = [[float(x) for x in pose] for pose in poses]
|
||||
def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
||||
if return_poses:
|
||||
return cam_params
|
||||
else:
|
||||
@@ -175,3 +178,25 @@ def process_pose_file(pose_file_path, width=672, height=384, original_pose_width
|
||||
plucker_embedding = plucker_embedding[None]
|
||||
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
||||
return plucker_embedding
|
||||
|
||||
|
||||
|
||||
def generate_camera_coordinates(
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||
length: int,
|
||||
speed: float = 1/54,
|
||||
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
):
|
||||
coordinates = [list(origin)]
|
||||
while len(coordinates) < length:
|
||||
coor = coordinates[-1].copy()
|
||||
if "Left" in direction:
|
||||
coor[9] += speed
|
||||
if "Right" in direction:
|
||||
coor[9] -= speed
|
||||
if "Up" in direction:
|
||||
coor[13] += speed
|
||||
if "Down" in direction:
|
||||
coor[13] -= speed
|
||||
coordinates.append(coor)
|
||||
return coordinates
|
||||
|
||||
Reference in New Issue
Block a user