Revert "Wan refactor"

This commit is contained in:
Zhongjie Duan
2025-06-11 17:29:27 +08:00
committed by GitHub
parent 8badd63a2d
commit 40760ab88b
216 changed files with 1332 additions and 4567 deletions

BIN
.msc

Binary file not shown.

1
.mv
View File

@@ -1 +0,0 @@
master

Binary file not shown.

Before

Width:  |  Height:  |  Size: 477 KiB

View File

@@ -1,62 +0,0 @@
import torch
import torch.nn as nn
class SimpleAdapter(nn.Module):
    """Adapter that turns per-frame feature maps into control features.

    Each frame is pixel-unshuffled (spatial /8, channels x64), projected by a
    strided convolution, then refined by a stack of residual blocks.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
        super().__init__()
        # Pixel unshuffle trades spatial resolution (/8) for channels (x64).
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
        # Non-overlapping strided convolution projects to the target width.
        self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
        # Residual refinement stack.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        # Fold the frame axis into the batch axis so 2D ops apply per frame.
        batch, channels, frames, height, width = x.size()
        frames_as_batch = x.permute(0, 2, 1, 3, 4).contiguous().view(batch * frames, channels, height, width)
        # Unshuffle -> project -> refine.
        features = self.residual_blocks(self.conv(self.pixel_unshuffle(frames_as_batch)))
        # Split batch back into (batch, frames) ...
        features = features.view(batch, frames, features.size(1), features.size(2), features.size(3))
        # ... and move channels back in front of the frame axis.
        return features.permute(0, 2, 1, 3, 4)
class ResidualBlock(nn.Module):
    """Two 3x3 same-padding convolutions with a ReLU between, plus identity skip."""

    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        # Keep the input for the skip connection.
        skip = x
        y = self.conv2(self.relu(self.conv1(x)))
        y += skip
        return y
# Example usage
# in_dim = 3
# out_dim = 64
# adapter = SimpleAdapterWithReshape(in_dim, out_dim)
# x = torch.randn(1, in_dim, 4, 64, 64) # e.g., batch size = 1, channels = 3, frames/features = 4
# output = adapter(x)
# print(output.shape) # Should reflect transformed dimensions

View File

@@ -1,174 +0,0 @@
import csv
import gc
import io
import json
import math
import os
import random
from random import shuffle
import albumentations
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from decord import VideoReader
from einops import rearrange
from packaging import version as pver
from PIL import Image
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset
class Camera(object):
    """Single camera entry: normalized intrinsics plus world/camera transforms.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    """

    def __init__(self, entry):
        # Intrinsics: entry[1:5] holds fx, fy, cx, cy.
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        # Extrinsics: entry[7:] is a row-major 3x4 world-to-camera matrix;
        # promote it to homogeneous 4x4 form.
        rotation_translation = np.array(entry[7:]).reshape(3, 4)
        w2c = np.eye(4)
        w2c[:3, :] = rotation_translation
        self.w2c_mat = w2c
        self.c2w_mat = np.linalg.inv(w2c)
def get_relative_pose(cam_params):
    """Re-express every camera pose relative to the first camera.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    The first camera is pinned to a canonical pose at the origin and all later
    camera-to-world matrices are mapped into that frame.
    """
    source_w2c = [cam.w2c_mat for cam in cam_params]
    source_c2w = [cam.c2w_mat for cam in cam_params]
    cam_to_origin = 0
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    # Transform that carries absolute c2w poses into the target frame.
    abs2rel = target_cam_c2w @ source_w2c[0]
    relative = [target_cam_c2w, ] + [abs2rel @ c2w for c2w in source_c2w[1:]]
    return np.array(relative, dtype=np.float32)
def ray_condition(K, c2w, H, W, device):
    """Compute per-pixel Plücker ray embeddings for a camera trajectory.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        K: intrinsics tensor of shape [B, V, 4] holding (fx, fy, cx, cy)
            in pixel units.
        c2w: camera-to-world matrices of shape [B, V, 4, 4].
        H, W: image height and width in pixels.
        device: device on which the pixel grid is created.

    Returns:
        Plücker embedding of shape [B, V, H, W, 6]: (o x d, d) per pixel.
    """
    B = K.shape[0]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # Shift to pixel centers.
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    # Back-project pixel centers to unit-depth camera-space directions.
    zs = torch.ones_like(i)  # [B, 1, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # [B, V, HW, 3]
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Rotate directions into world space; broadcast one origin per pixel.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # [B, V, HW, 3]
    rays_o = c2w[..., :3, 3]  # [B, V, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # [B, V, HW, 3]

    # BUGFIX: pass dim=-1 explicitly. Without it, torch.cross uses the first
    # dimension of size 3, silently computing the wrong product whenever
    # B, V, or H*W happens to equal 3 (the implicit default is deprecated).
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # [B, V, H, W, 6]
    return plucker
def custom_meshgrid(*args):
    """Version-agnostic torch.meshgrid with matrix ('ij') indexing.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    # torch >= 1.10 needs an explicit `indexing` argument to keep the
    # historical 'ij' behavior; older releases reject the keyword entirely.
    if pver.parse(torch.__version__) >= pver.parse('1.10'):
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
    """Load a CameraCtrl-style pose file and build per-frame Plücker embeddings.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Each line after the header is a whitespace-separated camera entry
    (intrinsics plus a 3x4 world-to-camera matrix; see `Camera`).

    Args:
        pose_file_path: text file whose first line is a header (skipped).
        width, height: target sample resolution used to rescale intrinsics.
        original_pose_width, original_pose_height: resolution the poses were
            annotated at.
        device: device forwarded to `ray_condition`.
        return_poses: if True, return the raw per-frame float lists instead
            of a tensor.

    Returns:
        A [V, H, W, 6] Plücker-embedding tensor, or the raw parsed pose list
        when `return_poses` is True.
    """
    with open(pose_file_path, 'r') as f:
        poses = f.readlines()
    # Drop the header line, then parse each remaining line into floats.
    poses = [pose.strip().split(' ') for pose in poses[1:]]
    cam_params = [[float(x) for x in pose] for pose in poses]
    if return_poses:
        return cam_params
    else:
        cam_params = [Camera(cam_param) for cam_param in cam_params]
        # Rescale intrinsics to match the sample aspect ratio: whichever axis
        # the original annotation is relatively wider in gets its focal
        # length stretched accordingly.
        sample_wh_ratio = width / height
        pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed
        if pose_wh_ratio > sample_wh_ratio:
            resized_ori_w = height * pose_wh_ratio
            for cam_param in cam_params:
                cam_param.fx = resized_ori_w * cam_param.fx / width
        else:
            resized_ori_h = width / pose_wh_ratio
            for cam_param in cam_params:
                cam_param.fy = resized_ori_h * cam_param.fy / height
        # Denormalize intrinsics to pixel units: (fx*W, fy*H, cx*W, cy*H).
        intrinsic = np.asarray([[cam_param.fx * width,
                                 cam_param.fy * height,
                                 cam_param.cx * width,
                                 cam_param.cy * height]
                                for cam_param in cam_params], dtype=np.float32)
        K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
        c2ws = get_relative_pose(cam_params)  # poses relative to the first camera
        c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
        plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
        # NOTE(review): the next two lines permute back to [V, H, W, 6]; the
        # round-trip through [V, 6, H, W] looks redundant — confirm against
        # callers before simplifying.
        plucker_embedding = plucker_embedding[None]
        plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
        return plucker_embedding
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
    """Build per-frame Plücker embeddings from already-parsed camera entries.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Same pipeline as `process_pose_file`, but takes the parsed per-frame
    float lists directly instead of reading them from disk.

    Args:
        cam_params: list of per-frame entries accepted by `Camera`.
        width, height: target sample resolution used to rescale intrinsics.
        original_pose_width, original_pose_height: resolution the poses were
            annotated at.
        device: device forwarded to `ray_condition`.

    Returns:
        A [V, H, W, 6] Plücker-embedding tensor.
    """
    cam_params = [Camera(cam_param) for cam_param in cam_params]
    # Rescale intrinsics to match the sample aspect ratio (see
    # process_pose_file for the same logic).
    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height
    # Denormalize intrinsics to pixel units: (fx*W, fy*H, cx*W, cy*H).
    intrinsic = np.asarray([[cam_param.fx * width,
                             cam_param.fy * height,
                             cam_param.cx * width,
                             cam_param.cy * height]
                            for cam_param in cam_params], dtype=np.float32)
    K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
    c2ws = get_relative_pose(cam_params)  # poses relative to the first camera
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
    # NOTE(review): the next two lines permute back to [V, H, W, 6]; the
    # round-trip through [V, 6, H, W] looks redundant — confirm against
    # callers before simplifying.
    plucker_embedding = plucker_embedding[None]
    plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
    return plucker_embedding

View File

@@ -1,82 +0,0 @@
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.018518518518518517 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.037037037037037035 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.05555555555555555 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.07407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.09259259259259259 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.1111111111111111 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.12962962962962962 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.14814814814814814 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.16666666666666666 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.18518518518518517 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.24074074074074073 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.25925925925925924 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2777777777777778 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.31481481481481477 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.3333333333333333 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.35185185185185186 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.37037037037037035 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.38888888888888884 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.42592592592592593 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4629629629629629 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.48148148148148145 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5185185185185185 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.537037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5555555555555556 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5740740740740741 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5925925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.611111111111111 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6296296296296295 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6481481481481481 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6666666666666666 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6851851851851851 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7592592592592593 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7777777777777777 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8148148148148148 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8333333333333334 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8518518518518519 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8703703703703705 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8888888888888888 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9259259259259258 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9629629629629629 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9814814814814815 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0185185185185186 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0555555555555556 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0925925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1111111111111112 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1296296296296298 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1481481481481481 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1666666666666667 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1851851851851851 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.259259259259259 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2777777777777777 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3148148148148149 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3333333333333333 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3518518518518519 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3703703703703702 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3888888888888888 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.425925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.462962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4814814814814814 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0

View File

@@ -1,33 +0,0 @@
Metadata-Version: 2.2
Name: diffsynth
Version: 1.1.7
Summary: Enjoy the magic of Diffusion models!
Author: Artiprocher
License: UNKNOWN
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: torchvision
Requires-Dist: cupy-cuda12x
Requires-Dist: transformers
Requires-Dist: controlnet-aux==0.0.7
Requires-Dist: imageio
Requires-Dist: imageio[ffmpeg]
Requires-Dist: safetensors
Requires-Dist: einops
Requires-Dist: sentencepiece
Requires-Dist: protobuf
Requires-Dist: modelscope
Requires-Dist: ftfy
Requires-Dist: pynvml
Dynamic: author
Dynamic: classifier
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary
UNKNOWN

View File

@@ -1,226 +0,0 @@
LICENSE
README.md
setup.py
diffsynth/__init__.py
diffsynth.egg-info/PKG-INFO
diffsynth.egg-info/SOURCES.txt
diffsynth.egg-info/dependency_links.txt
diffsynth.egg-info/requires.txt
diffsynth.egg-info/top_level.txt
diffsynth/configs/__init__.py
diffsynth/configs/model_config.py
diffsynth/controlnets/__init__.py
diffsynth/controlnets/controlnet_unit.py
diffsynth/controlnets/processors.py
diffsynth/data/__init__.py
diffsynth/data/simple_text_image.py
diffsynth/data/video.py
diffsynth/distributed/__init__.py
diffsynth/distributed/xdit_context_parallel.py
diffsynth/extensions/__init__.py
diffsynth/extensions/ESRGAN/__init__.py
diffsynth/extensions/FastBlend/__init__.py
diffsynth/extensions/FastBlend/api.py
diffsynth/extensions/FastBlend/cupy_kernels.py
diffsynth/extensions/FastBlend/data.py
diffsynth/extensions/FastBlend/patch_match.py
diffsynth/extensions/FastBlend/runners/__init__.py
diffsynth/extensions/FastBlend/runners/accurate.py
diffsynth/extensions/FastBlend/runners/balanced.py
diffsynth/extensions/FastBlend/runners/fast.py
diffsynth/extensions/FastBlend/runners/interpolation.py
diffsynth/extensions/ImageQualityMetric/__init__.py
diffsynth/extensions/ImageQualityMetric/aesthetic.py
diffsynth/extensions/ImageQualityMetric/clip.py
diffsynth/extensions/ImageQualityMetric/config.py
diffsynth/extensions/ImageQualityMetric/hps.py
diffsynth/extensions/ImageQualityMetric/imagereward.py
diffsynth/extensions/ImageQualityMetric/mps.py
diffsynth/extensions/ImageQualityMetric/pickscore.py
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
diffsynth/extensions/ImageQualityMetric/BLIP/med.py
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py
diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
diffsynth/extensions/ImageQualityMetric/open_clip/model.py
diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
diffsynth/extensions/ImageQualityMetric/open_clip/version.py
diffsynth/extensions/ImageQualityMetric/trainer/__init__.py
diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py
diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py
diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py
diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py
diffsynth/extensions/RIFE/__init__.py
diffsynth/lora/__init__.py
diffsynth/models/__init__.py
diffsynth/models/attention.py
diffsynth/models/cog_dit.py
diffsynth/models/cog_vae.py
diffsynth/models/downloader.py
diffsynth/models/flux_controlnet.py
diffsynth/models/flux_dit.py
diffsynth/models/flux_infiniteyou.py
diffsynth/models/flux_ipadapter.py
diffsynth/models/flux_text_encoder.py
diffsynth/models/flux_vae.py
diffsynth/models/hunyuan_dit.py
diffsynth/models/hunyuan_dit_text_encoder.py
diffsynth/models/hunyuan_video_dit.py
diffsynth/models/hunyuan_video_text_encoder.py
diffsynth/models/hunyuan_video_vae_decoder.py
diffsynth/models/hunyuan_video_vae_encoder.py
diffsynth/models/kolors_text_encoder.py
diffsynth/models/lora.py
diffsynth/models/model_manager.py
diffsynth/models/omnigen.py
diffsynth/models/qwenvl.py
diffsynth/models/sd3_dit.py
diffsynth/models/sd3_text_encoder.py
diffsynth/models/sd3_vae_decoder.py
diffsynth/models/sd3_vae_encoder.py
diffsynth/models/sd_controlnet.py
diffsynth/models/sd_ipadapter.py
diffsynth/models/sd_motion.py
diffsynth/models/sd_text_encoder.py
diffsynth/models/sd_unet.py
diffsynth/models/sd_vae_decoder.py
diffsynth/models/sd_vae_encoder.py
diffsynth/models/sdxl_controlnet.py
diffsynth/models/sdxl_ipadapter.py
diffsynth/models/sdxl_motion.py
diffsynth/models/sdxl_text_encoder.py
diffsynth/models/sdxl_unet.py
diffsynth/models/sdxl_vae_decoder.py
diffsynth/models/sdxl_vae_encoder.py
diffsynth/models/step1x_connector.py
diffsynth/models/stepvideo_dit.py
diffsynth/models/stepvideo_text_encoder.py
diffsynth/models/stepvideo_vae.py
diffsynth/models/svd_image_encoder.py
diffsynth/models/svd_unet.py
diffsynth/models/svd_vae_decoder.py
diffsynth/models/svd_vae_encoder.py
diffsynth/models/tiler.py
diffsynth/models/utils.py
diffsynth/models/wan_video_dit.py
diffsynth/models/wan_video_image_encoder.py
diffsynth/models/wan_video_motion_controller.py
diffsynth/models/wan_video_text_encoder.py
diffsynth/models/wan_video_vace.py
diffsynth/models/wan_video_vae.py
diffsynth/pipelines/__init__.py
diffsynth/pipelines/base.py
diffsynth/pipelines/cog_video.py
diffsynth/pipelines/dancer.py
diffsynth/pipelines/flux_image.py
diffsynth/pipelines/hunyuan_image.py
diffsynth/pipelines/hunyuan_video.py
diffsynth/pipelines/omnigen_image.py
diffsynth/pipelines/pipeline_runner.py
diffsynth/pipelines/sd3_image.py
diffsynth/pipelines/sd_image.py
diffsynth/pipelines/sd_video.py
diffsynth/pipelines/sdxl_image.py
diffsynth/pipelines/sdxl_video.py
diffsynth/pipelines/step_video.py
diffsynth/pipelines/svd_video.py
diffsynth/pipelines/wan_video.py
diffsynth/pipelines/wan_video_new.py
diffsynth/processors/FastBlend.py
diffsynth/processors/PILEditor.py
diffsynth/processors/RIFE.py
diffsynth/processors/__init__.py
diffsynth/processors/base.py
diffsynth/processors/sequencial_processor.py
diffsynth/prompters/__init__.py
diffsynth/prompters/base_prompter.py
diffsynth/prompters/cog_prompter.py
diffsynth/prompters/flux_prompter.py
diffsynth/prompters/hunyuan_dit_prompter.py
diffsynth/prompters/hunyuan_video_prompter.py
diffsynth/prompters/kolors_prompter.py
diffsynth/prompters/omnigen_prompter.py
diffsynth/prompters/omost.py
diffsynth/prompters/prompt_refiners.py
diffsynth/prompters/sd3_prompter.py
diffsynth/prompters/sd_prompter.py
diffsynth/prompters/sdxl_prompter.py
diffsynth/prompters/stepvideo_prompter.py
diffsynth/prompters/wan_prompter.py
diffsynth/schedulers/__init__.py
diffsynth/schedulers/continuous_ode.py
diffsynth/schedulers/ddim.py
diffsynth/schedulers/flow_match.py
diffsynth/tokenizer_configs/__init__.py
diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json
diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json
diffsynth/tokenizer_configs/cog/tokenizer/spiece.model
diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json
diffsynth/tokenizer_configs/flux/tokenizer_1/merges.txt
diffsynth/tokenizer_configs/flux/tokenizer_1/special_tokens_map.json
diffsynth/tokenizer_configs/flux/tokenizer_1/tokenizer_config.json
diffsynth/tokenizer_configs/flux/tokenizer_1/vocab.json
diffsynth/tokenizer_configs/flux/tokenizer_2/special_tokens_map.json
diffsynth/tokenizer_configs/flux/tokenizer_2/spiece.model
diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer.json
diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer_config.json
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/special_tokens_map.json
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/tokenizer_config.json
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/special_tokens_map.json
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer_config.json
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt
diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json
diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json
diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json
diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt
diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json
diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json
diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json
diffsynth/trainers/__init__.py
diffsynth/trainers/text_to_image.py
diffsynth/trainers/utils.py
diffsynth/vram_management/__init__.py
diffsynth/vram_management/layers.py

View File

@@ -1 +0,0 @@

View File

@@ -1,14 +0,0 @@
torch>=2.0.0
torchvision
cupy-cuda12x
transformers
controlnet-aux==0.0.7
imageio
imageio[ffmpeg]
safetensors
einops
sentencepiece
protobuf
modelscope
ftfy
pynvml

View File

@@ -1 +0,0 @@
diffsynth

View File

@@ -131,10 +131,6 @@ model_loader_configs = [
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),

View File

@@ -1,45 +0,0 @@
import torch
class GeneralLoRALoader:
    """Merges LoRA low-rank weight pairs (lora_A / lora_B) into a model's
    base weights in place."""

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype

    def get_name_dict(self, lora_state_dict):
        """Map each target module name to its (lora_B key, lora_A key) pair."""
        name_map = {}
        for key in lora_state_dict:
            if ".lora_B." not in key:
                continue
            parts = key.split(".")
            # Drop an optional adapter-name segment right after "lora_B",
            # then the "lora_B" segment itself.
            if len(parts) > parts.index("lora_B") + 2:
                parts.pop(parts.index("lora_B") + 1)
            parts.pop(parts.index("lora_B"))
            # Strip the "diffusion_model" prefix and the trailing
            # parameter name ("weight") to get the module path.
            if parts[0] == "diffusion_model":
                parts.pop(0)
            parts.pop(-1)
            name_map[".".join(parts)] = (key, key.replace(".lora_B.", ".lora_A."))
        return name_map

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Add alpha * (B @ A) to every matching module's weight in place."""
        lora_pairs = self.get_name_dict(state_dict_lora)
        updated_num = 0
        for name, module in model.named_modules():
            if name not in lora_pairs:
                continue
            up_key, down_key = lora_pairs[name]
            up = state_dict_lora[up_key].to(device=self.device, dtype=self.torch_dtype)
            down = state_dict_lora[down_key].to(device=self.device, dtype=self.torch_dtype)
            if len(up.shape) == 4:
                # 1x1-conv LoRA: collapse spatial dims, multiply, restore.
                up = up.squeeze(3).squeeze(2)
                down = down.squeeze(3).squeeze(2)
                delta = alpha * torch.mm(up, down).unsqueeze(2).unsqueeze(3)
            else:
                delta = alpha * torch.mm(up, down)
            state_dict = module.state_dict()
            state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + delta
            module.load_state_dict(state_dict)
            updated_num += 1
        print(f"{updated_num} tensors are updated by LoRA.")

View File

@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
return state_dict
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device=device) as f:
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):

View File

@@ -5,10 +5,6 @@ import math
from typing import Tuple, Optional
from einops import rearrange
from .utils import hash_state_dict_keys
from dchen.camera_adapter import SimpleAdapter
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
@@ -276,9 +272,6 @@ class WanModel(torch.nn.Module):
num_layers: int,
has_image_input: bool,
has_image_pos_emb: bool = False,
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
):
super().__init__()
self.dim = dim
@@ -310,22 +303,10 @@ class WanModel(torch.nn.Module):
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if has_ref_conv:
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
self.has_image_pos_emb = has_image_pos_emb
self.has_ref_conv = has_ref_conv
if add_control_adapter:
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
else:
self.control_adapter = None
def patchify(self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None):
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)
x = [u + v for u, v in zip(x, y_camera)]
x = x[0].unsqueeze(0)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
@@ -551,7 +532,6 @@ class WanModelStateDictConverter:
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
# 1.3B PAI control
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
@@ -566,7 +546,6 @@ class WanModelStateDictConverter:
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
# 14B PAI control
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
@@ -595,74 +574,6 @@ class WanModelStateDictConverter:
"eps": 1e-6,
"has_image_pos_emb": True
}
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
# 1.3B PAI control v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
"has_ref_conv": True
}
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
# 14B PAI control v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_ref_conv": True
}
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
# 1.3B PAI control-camera v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 32,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
"has_ref_conv": False,
"add_control_adapter": True,
"in_dim_control_adapter": 24,
}
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
# 14B PAI control-camera v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 32,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_ref_conv": False,
"add_control_adapter": True,
"in_dim_control_adapter": 24,
}
else:
config = {}
return state_dict, config

View File

@@ -774,11 +774,18 @@ class WanVideoVAE(nn.Module):
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
if tiled:
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_states, device)
return video
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
videos = torch.stack(videos)
return videos
@staticmethod

Some files were not shown because too many files have changed in this diff Show More