import os

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from packaging import version as pver


class SimpleAdapter(nn.Module):
    """Encode a per-frame condition video into a lower-resolution feature map.

    Each frame is pixel-unshuffled by a factor of 8, projected by a
    convolution, and refined by a stack of ``ResidualBlock`` layers.

    Args:
        in_dim: Number of channels of the input condition.
        out_dim: Number of channels of the produced features.
        kernel_size: Kernel size of the projection convolution.
        stride: Stride of the projection convolution (no padding). With
            ``kernel_size == stride`` the patches do not overlap.
        num_residual_blocks: Number of ``ResidualBlock`` layers.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
        super().__init__()
        # Pixel unshuffle trades each 8x8 spatial patch for 64x more channels,
        # which is why the projection conv below expects ``in_dim * 64`` inputs.
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
        # Further spatial reduction controlled by kernel_size / stride.
        self.conv = nn.Conv2d(
            in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0
        )
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        """Map ``x`` of shape (bs, c, f, h, w) to (bs, out_dim, f, h', w').

        ``h`` and ``w`` must be divisible by 8 (required by PixelUnshuffle).
        """
        bs, c, f, h, w = x.size()
        # Fold frames into the batch so the 2D layers treat every frame independently.
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
        x_unshuffled = self.pixel_unshuffle(x)
        x_conv = self.conv(x_unshuffled)
        out = self.residual_blocks(x_conv)
        # Unfold the batch back into (bs, f, ...) and move channels before frames.
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
        out = out.permute(0, 2, 1, 3, 4)
        return out


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU between them plus an identity skip.

    Channel count and spatial size are preserved (``dim -> dim``, padding=1).
    """

    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        # In-place add is safe: ``out`` is a fresh tensor produced by conv2.
        out += residual
        return out


class Camera:
    """Intrinsics and extrinsics parsed from one row of a pose file.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    ``entry`` is a flat sequence of floats: ``entry[1:5]`` are ``fx, fy, cx, cy``
    (``process_pose_file`` multiplies them by the target width/height, i.e. they
    are expected to be normalized) and ``entry[7:]`` holds the 12 values of a
    3x4 world-to-camera matrix in row-major order. ``entry[0]`` and
    ``entry[5:7]`` are not used here.
    """

    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        # Promote to a homogeneous 4x4 so it can be inverted and composed.
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)


def get_relative_pose(cam_params):
    """Express every camera pose relative to the first camera.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        cam_params: Non-empty sequence of ``Camera`` objects.

    Returns:
        float32 array of shape (n, 4, 4) holding camera-to-world matrices in a
        frame where the first camera's pose is ``target_cam_c2w`` (the identity,
        since ``cam_to_origin`` is 0).
    """
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    cam_to_origin = 0
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ])
    # Re-anchors world coordinates so that camera 0 lands at target_cam_c2w.
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    return np.array(ret_poses, dtype=np.float32)


def custom_meshgrid(*args):
    """``torch.meshgrid`` with 'ij' indexing on every supported torch version.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html
    # Before 1.10 there is no ``indexing`` kwarg and 'ij' is the only behavior.
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    return torch.meshgrid(*args, indexing='ij')


def ray_condition(K, c2w, H, W, device):
    """Compute a per-pixel Plücker ray embedding for every view.

    Adapted from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        K: (B, V, 4) intrinsics ``fx, fy, cx, cy`` in pixel units (as built by
            ``process_pose_file``).
        c2w: (B, V, 4, 4) camera-to-world matrices.
        H, W: Output resolution.
        device: Device on which the pixel grid is created; ``K`` and ``c2w``
            must live on the same device.

    Returns:
        (B, V, H, W, 6) tensor ``cat(o x d, d)`` where ``o`` is the camera
        center and ``d`` the unit ray direction, both in world coordinates.
    """
    B = K.shape[0]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # +0.5 samples rays through pixel centres.
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # B, 1, HW
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # B, 1, HW

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each B, V, 1

    zs = torch.ones_like(i)  # B, 1, HW
    xs = (i - cx) / fx * zs  # B, V, HW
    ys = (j - cy) / fy * zs  # B, V, HW
    zs = zs.expand_as(ys)  # B, V, HW

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Row-vector form of R @ d: rotate camera-space directions into world space.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # dim=-1 is essential: without it torch.cross uses the FIRST dimension of
    # size 3, which is the view (or batch) axis whenever V == 3 (or B == 3).
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)  # B, V, HW, 6
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker


def _read_pose_rows(pose_file_path):
    """Return the data rows of a pose file (or raw pose text) as lists of floats.

    The first line is treated as a header and skipped; blank lines (e.g. a
    trailing newline) are ignored, and values may be separated by any whitespace.
    """
    if os.path.isfile(pose_file_path):
        with open(pose_file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    else:
        lines = pose_file_path.splitlines()
    return [[float(x) for x in line.split()] for line in lines[1:] if line.strip()]


def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280,
                      original_pose_height=720, device='cpu', return_poses=False):
    """Build Plücker ray embeddings from a camera-pose file.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        pose_file_path: Path to a pose file, or the pose text itself. The first
            line is skipped as a header; each other non-blank line describes one
            camera (layout documented on ``Camera``).
        width, height: Resolution of the produced embedding.
        original_pose_width, original_pose_height: Resolution the poses were
            recorded at; used to compensate the focal length for the change of
            aspect ratio.
        device: Device for the intermediate and returned tensors.
        return_poses: If True, return the parsed rows (list of float lists)
            instead of computing embeddings.

    Returns:
        float32 tensor of shape (n_frame, height, width, 6), or the parsed rows
        when ``return_poses`` is True.

    Raises:
        ValueError: If no pose rows are found and ``return_poses`` is False.
    """
    cam_params = _read_pose_rows(pose_file_path)
    if return_poses:
        return cam_params
    if not cam_params:
        raise ValueError("no camera poses found in pose input")
    cam_params = [Camera(cam_param) for cam_param in cam_params]

    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height

    # Rescale the normalized focal length along one axis to account for the
    # aspect-ratio difference between the pose source and the sample
    # (CameraCtrl convention).
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height

    # Normalized intrinsics -> pixel units at the target resolution.
    intrinsic = np.asarray(
        [[cam_param.fx * width,
          cam_param.fy * height,
          cam_param.cx * width,
          cam_param.cy * height]
         for cam_param in cam_params],
        dtype=np.float32,
    )

    # Both tensors must share ``device`` with the pixel grid built in ray_condition.
    K = torch.as_tensor(intrinsic, device=device)[None]  # 1, n_frame, 4
    c2ws = get_relative_pose(cam_params)
    c2ws = torch.as_tensor(c2ws, device=device)[None]  # 1, n_frame, 4, 4
    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0]  # n_frame, H, W, 6
    return plucker_embedding