mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #602 from modelscope/revert-601-wan-refactor
Revert "Wan refactor"
This commit is contained in:
BIN
dchen/7.png
BIN
dchen/7.png
Binary file not shown.
|
Before Width: | Height: | Size: 477 KiB |
Binary file not shown.
Binary file not shown.
@@ -1,62 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class SimpleAdapter(nn.Module):
    """Convolutional adapter applied independently to every frame of a 5-D tensor.

    Each frame is pixel-unshuffled, projected to ``out_dim`` channels by an
    unpadded strided convolution, and then passed through residual blocks.

    Args:
        in_dim: Number of channels of the input frames.
        out_dim: Number of channels produced by the adapter.
        kernel_size: Kernel size of the projection convolution.
        stride: Stride of the projection convolution.
        num_residual_blocks: Number of ``ResidualBlock`` modules applied after
            the projection.
        downscale_factor: Pixel-unshuffle factor. Defaults to 8, the value that
            used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, downscale_factor=8):
        super().__init__()

        # Pixel unshuffle moves each r x r spatial patch into channels, so the
        # convolution below sees in_dim * r^2 input channels.
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)

        # Unpadded projection; the extra spatial reduction is governed by
        # kernel_size and stride.
        self.conv = nn.Conv2d(
            in_dim * downscale_factor ** 2,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )

        # Residual blocks for feature extraction.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        """Run the adapter on ``x``.

        Args:
            x: Tensor of shape ``(batch, channels, frames, height, width)``.

        Returns:
            Tensor of shape ``(batch, out_dim, frames, height', width')`` where
            the spatial size is determined by the unshuffle factor and the
            convolution.
        """
        bs, c, f, h, w = x.size()

        # Fold the frame axis into the batch axis so 2-D layers can be used.
        frames = x.permute(0, 2, 1, 3, 4).reshape(bs * f, c, h, w)

        features = self.pixel_unshuffle(frames)
        features = self.conv(features)
        features = self.residual_blocks(features)

        # Restore the separate batch and frame axes, then put channels back in
        # front of frames to mirror the input layout.
        features = features.reshape(bs, f, features.size(1), features.size(2), features.size(3))
        return features.permute(0, 2, 1, 3, 4)
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU in between and an identity skip connection.

    Both convolutions use ``padding=1`` and keep ``dim`` channels, so the
    output has the same shape as the input.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``x + conv2(relu(conv1(x)))``."""
        out = self.conv2(self.relu(self.conv1(x)))
        # Out-of-place add: neither the input nor the conv output is mutated.
        return out + x
|
||||
|
||||
# Example usage
|
||||
# in_dim = 3
|
||||
# out_dim = 64
|
||||
# adapter = SimpleAdapter(in_dim, out_dim, kernel_size=2, stride=2)
|
||||
# x = torch.randn(1, in_dim, 4, 64, 64) # e.g., batch size = 1, channels = 3, frames/features = 4
|
||||
# output = adapter(x)
|
||||
# print(output.shape) # Should reflect transformed dimensions
|
||||
@@ -1,174 +0,0 @@
|
||||
import csv
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from random import shuffle
|
||||
|
||||
import albumentations
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from decord import VideoReader
|
||||
from einops import rearrange
|
||||
from packaging import version as pver
|
||||
from PIL import Image
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
class Camera:
    """Camera intrinsics and extrinsics parsed from one row of a pose file.

    Adapted from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        entry: Sequence of 19 numbers. ``entry[1:5]`` holds ``fx, fy, cx, cy``
            and ``entry[7:]`` holds the 12 values of the 3x4 world-to-camera
            matrix in row-major order. ``entry[0]`` and ``entry[5:7]`` are not
            read here.

    Raises:
        ValueError: If the intrinsics slice does not hold four values or the
            extrinsics slice does not hold exactly 12 values.
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]

        extrinsic = np.asarray(entry[7:])
        if extrinsic.size != 12:
            # Same exception type reshape() would raise, with a usable message.
            raise ValueError(
                f"expected 12 world-to-camera values in entry[7:], got {extrinsic.size}"
            )

        # Promote the 3x4 [R | t] matrix to a homogeneous 4x4 matrix.
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = extrinsic.reshape(3, 4)
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
||||
|
||||
def get_relative_pose(cam_params):
    """Express camera-to-world poses relative to the first camera.

    Adapted from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        cam_params: Non-empty iterable of objects exposing 4x4 ``w2c_mat`` and
            ``c2w_mat`` arrays.

    Returns:
        ``float32`` array of shape ``(len(cam_params), 4, 4)``. Entry 0 is the
        target pose (the identity, since ``cam_to_origin`` is 0); entry ``k``
        is ``target @ w2c[0] @ c2w[k]``.

    Raises:
        IndexError: If ``cam_params`` is empty.
    """
    cam_params = list(cam_params)
    if not cam_params:
        raise IndexError("get_relative_pose() requires at least one camera")

    cam_to_origin = 0
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])

    # Only the first world-to-camera matrix is needed: it re-expresses every
    # later camera-to-world pose in the first camera's frame.
    abs2rel = target_cam_c2w @ cam_params[0].w2c_mat
    ret_poses = [target_cam_c2w] + [abs2rel @ cam_param.c2w_mat for cam_param in cam_params[1:]]
    return np.array(ret_poses, dtype=np.float32)
|
||||
|
||||
def ray_condition(K, c2w, H, W, device):
    """Build per-pixel Plucker ray embeddings for a sequence of cameras.

    Adapted from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        K: Intrinsics tensor of shape ``(B, V, 4)`` holding ``fx, fy, cx, cy``
            along the last axis, in the same units as the ``H x W`` pixel grid.
        c2w: Camera-to-world matrices of shape ``(B, V, 4, 4)``.
        H: Grid height in pixels.
        W: Grid width in pixels.
        device: Device on which the pixel grid is created. ``K`` and ``c2w``
            are moved there as well so that all operands agree.

    Returns:
        Tensor of shape ``(B, V, H, W, 6)`` whose last axis is
        ``cross(ray_origin, ray_direction)`` followed by ``ray_direction``.
    """
    # The grid below lives on `device`; inputs on another device would make the
    # arithmetic fail, so align them first (a no-op when they already match).
    if device is not None:
        K = K.to(device=device)
        c2w = c2w.to(device=device)

    B = K.shape[0]
    V = c2w.shape[1]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # Flatten the grid and shift by 0.5 to sample pixel centres.
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    zs = torch.ones_like(i)  # [B, 1, HW]
    xs = (i - cx) / fx * zs  # [B, V, HW]
    ys = (j - cy) / fy * zs  # [B, V, HW]
    zs = zs.expand_as(ys)  # [B, V, HW]

    directions = torch.stack((xs, ys, zs), dim=-1)  # [B, V, HW, 3]
    directions = directions / directions.norm(dim=-1, keepdim=True)  # [B, V, HW, 3]

    # Rotate the camera-space directions into world space.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # [B, V, HW, 3]
    rays_o = c2w[..., :3, 3]  # [B, V, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # [B, V, HW, 3]

    # dim must be explicit: without it torch.cross uses the first axis of size
    # 3, which is the wrong axis whenever B, V or H*W happens to equal 3.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # [B, V, HW, 3]
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)  # [B, V, HW, 6]
    plucker = plucker.reshape(B, V, H, W, 6)  # [B, V, H, W, 6]
    return plucker
|
||||
|
||||
def custom_meshgrid(*args):
    """Call ``torch.meshgrid`` with 'ij' indexing across torch versions.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    The ``indexing`` keyword is only passed when the installed torch is at
    least 1.10; older versions are called without it.
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    accepts_indexing = pver.parse(torch.__version__) >= pver.parse('1.10')
    extra_kwargs = {'indexing': 'ij'} if accepts_indexing else {}
    return torch.meshgrid(*args, **extra_kwargs)
|
||||
|
||||
|
||||
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
    """Read a camera pose file and return raw poses or their Plucker embedding.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    The first line of the file is skipped as a header. Every other non-blank
    line is parsed as whitespace-separated floats (one camera per line).

    Args:
        pose_file_path: Path of the text file to read.
        width: Target frame width, forwarded to ``process_pose_params``.
        height: Target frame height, forwarded to ``process_pose_params``.
        original_pose_width: Frame width the poses were recorded for.
        original_pose_height: Frame height the poses were recorded for.
        device: Device forwarded to ``process_pose_params``.
        return_poses: When true, return the parsed rows instead of an embedding.

    Returns:
        A list of per-camera float lists when ``return_poses`` is true,
        otherwise the result of ``process_pose_params`` for the parsed rows.

    Raises:
        ValueError: If a field cannot be parsed as a float.
    """
    with open(pose_file_path, 'r') as f:
        lines = f.readlines()

    # split() without an argument tolerates tabs and repeated spaces; blank
    # lines (e.g. a trailing newline) are skipped instead of failing in float('').
    cam_params = [
        [float(value) for value in line.split()]
        for line in lines[1:]
        if line.strip()
    ]
    if return_poses:
        return cam_params

    # The embedding path is identical to process_pose_params, so delegate
    # rather than keeping a second copy of that logic.
    return process_pose_params(
        cam_params,
        width=width,
        height=height,
        original_pose_width=original_pose_width,
        original_pose_height=original_pose_height,
        device=device,
    )
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
    """Turn parsed camera rows into a per-frame Plucker embedding.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        cam_params: Sequence of per-camera rows accepted by ``Camera``.
        width: Target frame width in pixels.
        height: Target frame height in pixels.
        original_pose_width: Frame width the poses were recorded for.
        original_pose_height: Frame height the poses were recorded for.
        device: Device on which the intrinsics, poses and ray grid are created.

    Returns:
        Tensor of shape ``(n_frame, height, width, 6)``.
    """
    cameras = [Camera(cam_param) for cam_param in cam_params]

    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height

    # fx/fy are multiplied by width/height below, i.e. they are treated as
    # fractions of the frame size. When the target aspect ratio differs from
    # the recorded one, rescale the focal length along the mismatched axis.
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for camera in cameras:
            camera.fx = resized_ori_w * camera.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for camera in cameras:
            camera.fy = resized_ori_h * camera.fy / height

    intrinsic = np.asarray(
        [
            [camera.fx * width, camera.fy * height, camera.cx * width, camera.cy * height]
            for camera in cameras
        ],
        dtype=np.float32,
    )

    # Create both tensors on `device`: ray_condition builds its pixel grid
    # there, and mixing devices would fail for anything other than the CPU.
    K = torch.as_tensor(intrinsic, device=device)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cameras)
    c2ws = torch.as_tensor(c2ws, device=device)[None]  # [1, n_frame, 4, 4]

    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
    plucker_embedding = plucker_embedding[None]  # 1, V, 6, H, W
    plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]  # V, H, W, 6
    return plucker_embedding
|
||||
@@ -1,82 +0,0 @@
|
||||
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.018518518518518517 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.037037037037037035 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.05555555555555555 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.07407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.09259259259259259 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.1111111111111111 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.12962962962962962 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.14814814814814814 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.16666666666666666 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.18518518518518517 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.24074074074074073 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.25925925925925924 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2777777777777778 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.31481481481481477 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.3333333333333333 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.35185185185185186 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.37037037037037035 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.38888888888888884 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.42592592592592593 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4629629629629629 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.48148148148148145 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5185185185185185 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.537037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5555555555555556 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5740740740740741 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5925925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.611111111111111 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6296296296296295 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6481481481481481 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6666666666666666 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6851851851851851 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7592592592592593 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7777777777777777 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8148148148148148 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8333333333333334 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8518518518518519 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8703703703703705 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8888888888888888 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9259259259259258 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9629629629629629 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9814814814814815 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0185185185185186 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0555555555555556 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0925925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1111111111111112 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1296296296296298 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1481481481481481 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1666666666666667 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1851851851851851 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.259259259259259 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2777777777777777 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3148148148148149 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3333333333333333 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3518518518518519 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3703703703703702 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3888888888888888 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.425925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.462962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4814814814814814 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
@@ -1,33 +0,0 @@
|
||||
Metadata-Version: 2.2
|
||||
Name: diffsynth
|
||||
Version: 1.1.7
|
||||
Summary: Enjoy the magic of Diffusion models!
|
||||
Author: Artiprocher
|
||||
License: UNKNOWN
|
||||
Platform: UNKNOWN
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: License :: OSI Approved :: Apache Software License
|
||||
Classifier: Operating System :: OS Independent
|
||||
Requires-Python: >=3.6
|
||||
License-File: LICENSE
|
||||
Requires-Dist: torch>=2.0.0
|
||||
Requires-Dist: torchvision
|
||||
Requires-Dist: cupy-cuda12x
|
||||
Requires-Dist: transformers
|
||||
Requires-Dist: controlnet-aux==0.0.7
|
||||
Requires-Dist: imageio
|
||||
Requires-Dist: imageio[ffmpeg]
|
||||
Requires-Dist: safetensors
|
||||
Requires-Dist: einops
|
||||
Requires-Dist: sentencepiece
|
||||
Requires-Dist: protobuf
|
||||
Requires-Dist: modelscope
|
||||
Requires-Dist: ftfy
|
||||
Requires-Dist: pynvml
|
||||
Dynamic: author
|
||||
Dynamic: classifier
|
||||
Dynamic: requires-dist
|
||||
Dynamic: requires-python
|
||||
Dynamic: summary
|
||||
|
||||
UNKNOWN
|
||||
@@ -1,226 +0,0 @@
|
||||
LICENSE
|
||||
README.md
|
||||
setup.py
|
||||
diffsynth/__init__.py
|
||||
diffsynth.egg-info/PKG-INFO
|
||||
diffsynth.egg-info/SOURCES.txt
|
||||
diffsynth.egg-info/dependency_links.txt
|
||||
diffsynth.egg-info/requires.txt
|
||||
diffsynth.egg-info/top_level.txt
|
||||
diffsynth/configs/__init__.py
|
||||
diffsynth/configs/model_config.py
|
||||
diffsynth/controlnets/__init__.py
|
||||
diffsynth/controlnets/controlnet_unit.py
|
||||
diffsynth/controlnets/processors.py
|
||||
diffsynth/data/__init__.py
|
||||
diffsynth/data/simple_text_image.py
|
||||
diffsynth/data/video.py
|
||||
diffsynth/distributed/__init__.py
|
||||
diffsynth/distributed/xdit_context_parallel.py
|
||||
diffsynth/extensions/__init__.py
|
||||
diffsynth/extensions/ESRGAN/__init__.py
|
||||
diffsynth/extensions/FastBlend/__init__.py
|
||||
diffsynth/extensions/FastBlend/api.py
|
||||
diffsynth/extensions/FastBlend/cupy_kernels.py
|
||||
diffsynth/extensions/FastBlend/data.py
|
||||
diffsynth/extensions/FastBlend/patch_match.py
|
||||
diffsynth/extensions/FastBlend/runners/__init__.py
|
||||
diffsynth/extensions/FastBlend/runners/accurate.py
|
||||
diffsynth/extensions/FastBlend/runners/balanced.py
|
||||
diffsynth/extensions/FastBlend/runners/fast.py
|
||||
diffsynth/extensions/FastBlend/runners/interpolation.py
|
||||
diffsynth/extensions/ImageQualityMetric/__init__.py
|
||||
diffsynth/extensions/ImageQualityMetric/aesthetic.py
|
||||
diffsynth/extensions/ImageQualityMetric/clip.py
|
||||
diffsynth/extensions/ImageQualityMetric/config.py
|
||||
diffsynth/extensions/ImageQualityMetric/hps.py
|
||||
diffsynth/extensions/ImageQualityMetric/imagereward.py
|
||||
diffsynth/extensions/ImageQualityMetric/mps.py
|
||||
diffsynth/extensions/ImageQualityMetric/pickscore.py
|
||||
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
|
||||
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
|
||||
diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
|
||||
diffsynth/extensions/ImageQualityMetric/BLIP/med.py
|
||||
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/model.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
|
||||
diffsynth/extensions/ImageQualityMetric/open_clip/version.py
|
||||
diffsynth/extensions/ImageQualityMetric/trainer/__init__.py
|
||||
diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py
|
||||
diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py
|
||||
diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py
|
||||
diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py
|
||||
diffsynth/extensions/RIFE/__init__.py
|
||||
diffsynth/lora/__init__.py
|
||||
diffsynth/models/__init__.py
|
||||
diffsynth/models/attention.py
|
||||
diffsynth/models/cog_dit.py
|
||||
diffsynth/models/cog_vae.py
|
||||
diffsynth/models/downloader.py
|
||||
diffsynth/models/flux_controlnet.py
|
||||
diffsynth/models/flux_dit.py
|
||||
diffsynth/models/flux_infiniteyou.py
|
||||
diffsynth/models/flux_ipadapter.py
|
||||
diffsynth/models/flux_text_encoder.py
|
||||
diffsynth/models/flux_vae.py
|
||||
diffsynth/models/hunyuan_dit.py
|
||||
diffsynth/models/hunyuan_dit_text_encoder.py
|
||||
diffsynth/models/hunyuan_video_dit.py
|
||||
diffsynth/models/hunyuan_video_text_encoder.py
|
||||
diffsynth/models/hunyuan_video_vae_decoder.py
|
||||
diffsynth/models/hunyuan_video_vae_encoder.py
|
||||
diffsynth/models/kolors_text_encoder.py
|
||||
diffsynth/models/lora.py
|
||||
diffsynth/models/model_manager.py
|
||||
diffsynth/models/omnigen.py
|
||||
diffsynth/models/qwenvl.py
|
||||
diffsynth/models/sd3_dit.py
|
||||
diffsynth/models/sd3_text_encoder.py
|
||||
diffsynth/models/sd3_vae_decoder.py
|
||||
diffsynth/models/sd3_vae_encoder.py
|
||||
diffsynth/models/sd_controlnet.py
|
||||
diffsynth/models/sd_ipadapter.py
|
||||
diffsynth/models/sd_motion.py
|
||||
diffsynth/models/sd_text_encoder.py
|
||||
diffsynth/models/sd_unet.py
|
||||
diffsynth/models/sd_vae_decoder.py
|
||||
diffsynth/models/sd_vae_encoder.py
|
||||
diffsynth/models/sdxl_controlnet.py
|
||||
diffsynth/models/sdxl_ipadapter.py
|
||||
diffsynth/models/sdxl_motion.py
|
||||
diffsynth/models/sdxl_text_encoder.py
|
||||
diffsynth/models/sdxl_unet.py
|
||||
diffsynth/models/sdxl_vae_decoder.py
|
||||
diffsynth/models/sdxl_vae_encoder.py
|
||||
diffsynth/models/step1x_connector.py
|
||||
diffsynth/models/stepvideo_dit.py
|
||||
diffsynth/models/stepvideo_text_encoder.py
|
||||
diffsynth/models/stepvideo_vae.py
|
||||
diffsynth/models/svd_image_encoder.py
|
||||
diffsynth/models/svd_unet.py
|
||||
diffsynth/models/svd_vae_decoder.py
|
||||
diffsynth/models/svd_vae_encoder.py
|
||||
diffsynth/models/tiler.py
|
||||
diffsynth/models/utils.py
|
||||
diffsynth/models/wan_video_dit.py
|
||||
diffsynth/models/wan_video_image_encoder.py
|
||||
diffsynth/models/wan_video_motion_controller.py
|
||||
diffsynth/models/wan_video_text_encoder.py
|
||||
diffsynth/models/wan_video_vace.py
|
||||
diffsynth/models/wan_video_vae.py
|
||||
diffsynth/pipelines/__init__.py
|
||||
diffsynth/pipelines/base.py
|
||||
diffsynth/pipelines/cog_video.py
|
||||
diffsynth/pipelines/dancer.py
|
||||
diffsynth/pipelines/flux_image.py
|
||||
diffsynth/pipelines/hunyuan_image.py
|
||||
diffsynth/pipelines/hunyuan_video.py
|
||||
diffsynth/pipelines/omnigen_image.py
|
||||
diffsynth/pipelines/pipeline_runner.py
|
||||
diffsynth/pipelines/sd3_image.py
|
||||
diffsynth/pipelines/sd_image.py
|
||||
diffsynth/pipelines/sd_video.py
|
||||
diffsynth/pipelines/sdxl_image.py
|
||||
diffsynth/pipelines/sdxl_video.py
|
||||
diffsynth/pipelines/step_video.py
|
||||
diffsynth/pipelines/svd_video.py
|
||||
diffsynth/pipelines/wan_video.py
|
||||
diffsynth/pipelines/wan_video_new.py
|
||||
diffsynth/processors/FastBlend.py
|
||||
diffsynth/processors/PILEditor.py
|
||||
diffsynth/processors/RIFE.py
|
||||
diffsynth/processors/__init__.py
|
||||
diffsynth/processors/base.py
|
||||
diffsynth/processors/sequencial_processor.py
|
||||
diffsynth/prompters/__init__.py
|
||||
diffsynth/prompters/base_prompter.py
|
||||
diffsynth/prompters/cog_prompter.py
|
||||
diffsynth/prompters/flux_prompter.py
|
||||
diffsynth/prompters/hunyuan_dit_prompter.py
|
||||
diffsynth/prompters/hunyuan_video_prompter.py
|
||||
diffsynth/prompters/kolors_prompter.py
|
||||
diffsynth/prompters/omnigen_prompter.py
|
||||
diffsynth/prompters/omost.py
|
||||
diffsynth/prompters/prompt_refiners.py
|
||||
diffsynth/prompters/sd3_prompter.py
|
||||
diffsynth/prompters/sd_prompter.py
|
||||
diffsynth/prompters/sdxl_prompter.py
|
||||
diffsynth/prompters/stepvideo_prompter.py
|
||||
diffsynth/prompters/wan_prompter.py
|
||||
diffsynth/schedulers/__init__.py
|
||||
diffsynth/schedulers/continuous_ode.py
|
||||
diffsynth/schedulers/ddim.py
|
||||
diffsynth/schedulers/flow_match.py
|
||||
diffsynth/tokenizer_configs/__init__.py
|
||||
diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json
|
||||
diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/cog/tokenizer/spiece.model
|
||||
diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/flux/tokenizer_1/merges.txt
|
||||
diffsynth/tokenizer_configs/flux/tokenizer_1/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/flux/tokenizer_1/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/flux/tokenizer_1/vocab.json
|
||||
diffsynth/tokenizer_configs/flux/tokenizer_2/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/flux/tokenizer_2/spiece.model
|
||||
diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer.json
|
||||
diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt
|
||||
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt
|
||||
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json
|
||||
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model
|
||||
diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
|
||||
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
|
||||
diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt
|
||||
diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt
|
||||
diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json
|
||||
diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json
|
||||
diffsynth/trainers/__init__.py
|
||||
diffsynth/trainers/text_to_image.py
|
||||
diffsynth/trainers/utils.py
|
||||
diffsynth/vram_management/__init__.py
|
||||
diffsynth/vram_management/layers.py
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
torch>=2.0.0
|
||||
torchvision
|
||||
cupy-cuda12x
|
||||
transformers
|
||||
controlnet-aux==0.0.7
|
||||
imageio
|
||||
imageio[ffmpeg]
|
||||
safetensors
|
||||
einops
|
||||
sentencepiece
|
||||
protobuf
|
||||
modelscope
|
||||
ftfy
|
||||
pynvml
|
||||
@@ -1 +0,0 @@
|
||||
diffsynth
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -131,10 +131,6 @@ model_loader_configs = [
|
||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,45 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class GeneralLoRALoader:
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
keys.pop(-1)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
updated_num = 0
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name, module in model.named_modules():
|
||||
if name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
state_dict = module.state_dict()
|
||||
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||
module.load_state_dict(state_dict)
|
||||
updated_num += 1
|
||||
print(f"{updated_num} tensors are updated by LoRA.")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device=device) as f:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -5,10 +5,6 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
from dchen.camera_adapter import SimpleAdapter
|
||||
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
@@ -276,9 +272,6 @@ class WanModel(torch.nn.Module):
|
||||
num_layers: int,
|
||||
has_image_input: bool,
|
||||
has_image_pos_emb: bool = False,
|
||||
has_ref_conv: bool = False,
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -310,22 +303,10 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
||||
if has_ref_conv:
|
||||
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
|
||||
self.has_image_pos_emb = has_image_pos_emb
|
||||
self.has_ref_conv = has_ref_conv
|
||||
|
||||
if add_control_adapter:
|
||||
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||
else:
|
||||
self.control_adapter = None
|
||||
|
||||
def patchify(self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None):
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
if self.control_adapter is not None and control_camera_latents_input is not None:
|
||||
y_camera = self.control_adapter(control_camera_latents_input)
|
||||
x = [u + v for u, v in zip(x, y_camera)]
|
||||
x = x[0].unsqueeze(0)
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
@@ -551,7 +532,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
# 1.3B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
@@ -566,7 +546,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||
# 14B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
@@ -595,74 +574,6 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6,
|
||||
"has_image_pos_emb": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
|
||||
# 1.3B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
|
||||
# 14B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
|
||||
# 1.3B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
|
||||
# 14B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
@@ -774,11 +774,18 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_states, device)
|
||||
return video
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
videos = torch.stack(videos)
|
||||
return videos
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user