diff --git a/.msc b/.msc deleted file mode 100644 index eb82657..0000000 Binary files a/.msc and /dev/null differ diff --git a/.mv b/.mv deleted file mode 100644 index 8b25206..0000000 --- a/.mv +++ /dev/null @@ -1 +0,0 @@ -master \ No newline at end of file diff --git a/dchen/7.png b/dchen/7.png deleted file mode 100644 index 9a107af..0000000 Binary files a/dchen/7.png and /dev/null differ diff --git a/dchen/__pycache__/camera_adapter.cpython-310.pyc b/dchen/__pycache__/camera_adapter.cpython-310.pyc deleted file mode 100644 index c0af964..0000000 Binary files a/dchen/__pycache__/camera_adapter.cpython-310.pyc and /dev/null differ diff --git a/dchen/__pycache__/camera_compute.cpython-310.pyc b/dchen/__pycache__/camera_compute.cpython-310.pyc deleted file mode 100644 index 53d6145..0000000 Binary files a/dchen/__pycache__/camera_compute.cpython-310.pyc and /dev/null differ diff --git a/dchen/camera_adapter.py b/dchen/camera_adapter.py deleted file mode 100644 index a22c1a9..0000000 --- a/dchen/camera_adapter.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -import torch.nn as nn - -class SimpleAdapter(nn.Module): - def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1): - super(SimpleAdapter, self).__init__() - - # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 - self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8) - - # Convolution: reduce spatial dimensions by a factor - # of 2 (without overlap) - self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0) - - # Residual blocks for feature extraction - self.residual_blocks = nn.Sequential( - *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)] - ) - - def forward(self, x): - # Reshape to merge the frame dimension into batch - bs, c, f, h, w = x.size() - x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) - - # Pixel Unshuffle operation - x_unshuffled = self.pixel_unshuffle(x) - - # Convolution operation - x_conv = 
self.conv(x_unshuffled) - - # Feature extraction with residual blocks - out = self.residual_blocks(x_conv) - - # Reshape to restore original bf dimension - out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) - - # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames - out = out.permute(0, 2, 1, 3, 4) - - return out - -class ResidualBlock(nn.Module): - def __init__(self, dim): - super(ResidualBlock, self).__init__() - self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) - - def forward(self, x): - residual = x - out = self.relu(self.conv1(x)) - out = self.conv2(out) - out += residual - return out - -# Example usage -# in_dim = 3 -# out_dim = 64 -# adapter = SimpleAdapterWithReshape(in_dim, out_dim) -# x = torch.randn(1, in_dim, 4, 64, 64) # e.g., batch size = 1, channels = 3, frames/features = 4 -# output = adapter(x) -# print(output.shape) # Should reflect transformed dimensions diff --git a/dchen/camera_compute.py b/dchen/camera_compute.py deleted file mode 100644 index cec1830..0000000 --- a/dchen/camera_compute.py +++ /dev/null @@ -1,174 +0,0 @@ -import csv -import gc -import io -import json -import math -import os -import random -from random import shuffle - -import albumentations -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -import torchvision.transforms as transforms -from decord import VideoReader -from einops import rearrange -from packaging import version as pver -from PIL import Image -from torch.utils.data import BatchSampler, Sampler -from torch.utils.data.dataset import Dataset - -class Camera(object): - """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py - """ - def __init__(self, entry): - fx, fy, cx, cy = entry[1:5] - self.fx = fx - self.fy = fy - self.cx = cx - self.cy = cy - w2c_mat = np.array(entry[7:]).reshape(3, 4) - w2c_mat_4x4 = np.eye(4) - 
w2c_mat_4x4[:3, :] = w2c_mat - self.w2c_mat = w2c_mat_4x4 - self.c2w_mat = np.linalg.inv(w2c_mat_4x4) - -def get_relative_pose(cam_params): - """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py - """ - abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] - abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] - cam_to_origin = 0 - target_cam_c2w = np.array([ - [1, 0, 0, 0], - [0, 1, 0, -cam_to_origin], - [0, 0, 1, 0], - [0, 0, 0, 1] - ]) - abs2rel = target_cam_c2w @ abs_w2cs[0] - ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] - ret_poses = np.array(ret_poses, dtype=np.float32) - return ret_poses - -def ray_condition(K, c2w, H, W, device): - """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py - """ - # c2w: B, V, 4, 4 - # K: B, V, 4 - - B = K.shape[0] - - j, i = custom_meshgrid( - torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), - torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), - ) - i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] - j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] - - fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 - - zs = torch.ones_like(i) # [B, HxW] - xs = (i - cx) / fx * zs - ys = (j - cy) / fy * zs - zs = zs.expand_as(ys) - - directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 - directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 - - rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW - rays_o = c2w[..., :3, 3] # B, V, 3 - rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW - # c2w @ dirctions - rays_dxo = torch.cross(rays_o, rays_d) - plucker = torch.cat([rays_dxo, rays_d], dim=-1) - plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 - # plucker = plucker.permute(0, 1, 4, 2, 3) - return plucker - -def custom_meshgrid(*args): - """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py - """ - # 
ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid - if pver.parse(torch.__version__) < pver.parse('1.10'): - return torch.meshgrid(*args) - else: - return torch.meshgrid(*args, indexing='ij') - - -def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False): - """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py - """ - with open(pose_file_path, 'r') as f: - poses = f.readlines() - - poses = [pose.strip().split(' ') for pose in poses[1:]] - cam_params = [[float(x) for x in pose] for pose in poses] - if return_poses: - return cam_params - else: - cam_params = [Camera(cam_param) for cam_param in cam_params] - - sample_wh_ratio = width / height - pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed - - if pose_wh_ratio > sample_wh_ratio: - resized_ori_w = height * pose_wh_ratio - for cam_param in cam_params: - cam_param.fx = resized_ori_w * cam_param.fx / width - else: - resized_ori_h = width / pose_wh_ratio - for cam_param in cam_params: - cam_param.fy = resized_ori_h * cam_param.fy / height - - intrinsic = np.asarray([[cam_param.fx * width, - cam_param.fy * height, - cam_param.cx * width, - cam_param.cy * height] - for cam_param in cam_params], dtype=np.float32) - - K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] - c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere - c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] - plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W - plucker_embedding = plucker_embedding[None] - plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] - return plucker_embedding - - - - - - - -def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, 
device='cpu'): - """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py - """ - cam_params = [Camera(cam_param) for cam_param in cam_params] - - sample_wh_ratio = width / height - pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed - - if pose_wh_ratio > sample_wh_ratio: - resized_ori_w = height * pose_wh_ratio - for cam_param in cam_params: - cam_param.fx = resized_ori_w * cam_param.fx / width - else: - resized_ori_h = width / pose_wh_ratio - for cam_param in cam_params: - cam_param.fy = resized_ori_h * cam_param.fy / height - - intrinsic = np.asarray([[cam_param.fx * width, - cam_param.fy * height, - cam_param.cx * width, - cam_param.cy * height] - for cam_param in cam_params], dtype=np.float32) - - K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] - c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere - c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] - plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W - plucker_embedding = plucker_embedding[None] - plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] - return plucker_embedding \ No newline at end of file diff --git a/dchen/camera_information.txt b/dchen/camera_information.txt deleted file mode 100644 index 3e277a9..0000000 --- a/dchen/camera_information.txt +++ /dev/null @@ -1,82 +0,0 @@ - -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.018518518518518517 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.037037037037037035 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.05555555555555555 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.07407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 
0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.09259259259259259 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.1111111111111111 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.12962962962962962 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.14814814814814814 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.16666666666666666 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.18518518518518517 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.24074074074074073 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.25925925925925924 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2777777777777778 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.31481481481481477 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.3333333333333333 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.35185185185185186 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.37037037037037035 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.38888888888888884 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.42592592592592593 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 
0.4444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4629629629629629 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.48148148148148145 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5185185185185185 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.537037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5555555555555556 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5740740740740741 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5925925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.611111111111111 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6296296296296295 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6481481481481481 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6666666666666666 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6851851851851851 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7592592592592593 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7777777777777777 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 
0.5 0.5 0 0 1.0 0.0 0.0 0.8148148148148148 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8333333333333334 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8518518518518519 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8703703703703705 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8888888888888888 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9259259259259258 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9629629629629629 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9814814814814815 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0185185185185186 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0555555555555556 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0925925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1111111111111112 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1296296296296298 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1481481481481481 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1666666666666667 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 
0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1851851851851851 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.259259259259259 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2777777777777777 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3148148148148149 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3333333333333333 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3518518518518519 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3703703703703702 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3888888888888888 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.425925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.462962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4814814814814814 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 diff --git a/diffsynth.egg-info/PKG-INFO b/diffsynth.egg-info/PKG-INFO deleted file mode 100644 index 9e1c1c4..0000000 --- a/diffsynth.egg-info/PKG-INFO +++ /dev/null @@ -1,33 +0,0 @@ -Metadata-Version: 2.2 -Name: diffsynth -Version: 1.1.7 -Summary: Enjoy the magic of 
Diffusion models! -Author: Artiprocher -License: UNKNOWN -Platform: UNKNOWN -Classifier: Programming Language :: Python :: 3 -Classifier: License :: OSI Approved :: Apache Software License -Classifier: Operating System :: OS Independent -Requires-Python: >=3.6 -License-File: LICENSE -Requires-Dist: torch>=2.0.0 -Requires-Dist: torchvision -Requires-Dist: cupy-cuda12x -Requires-Dist: transformers -Requires-Dist: controlnet-aux==0.0.7 -Requires-Dist: imageio -Requires-Dist: imageio[ffmpeg] -Requires-Dist: safetensors -Requires-Dist: einops -Requires-Dist: sentencepiece -Requires-Dist: protobuf -Requires-Dist: modelscope -Requires-Dist: ftfy -Requires-Dist: pynvml -Dynamic: author -Dynamic: classifier -Dynamic: requires-dist -Dynamic: requires-python -Dynamic: summary - -UNKNOWN diff --git a/diffsynth.egg-info/SOURCES.txt b/diffsynth.egg-info/SOURCES.txt deleted file mode 100644 index 9d308cc..0000000 --- a/diffsynth.egg-info/SOURCES.txt +++ /dev/null @@ -1,226 +0,0 @@ -LICENSE -README.md -setup.py -diffsynth/__init__.py -diffsynth.egg-info/PKG-INFO -diffsynth.egg-info/SOURCES.txt -diffsynth.egg-info/dependency_links.txt -diffsynth.egg-info/requires.txt -diffsynth.egg-info/top_level.txt -diffsynth/configs/__init__.py -diffsynth/configs/model_config.py -diffsynth/controlnets/__init__.py -diffsynth/controlnets/controlnet_unit.py -diffsynth/controlnets/processors.py -diffsynth/data/__init__.py -diffsynth/data/simple_text_image.py -diffsynth/data/video.py -diffsynth/distributed/__init__.py -diffsynth/distributed/xdit_context_parallel.py -diffsynth/extensions/__init__.py -diffsynth/extensions/ESRGAN/__init__.py -diffsynth/extensions/FastBlend/__init__.py -diffsynth/extensions/FastBlend/api.py -diffsynth/extensions/FastBlend/cupy_kernels.py -diffsynth/extensions/FastBlend/data.py -diffsynth/extensions/FastBlend/patch_match.py -diffsynth/extensions/FastBlend/runners/__init__.py -diffsynth/extensions/FastBlend/runners/accurate.py 
-diffsynth/extensions/FastBlend/runners/balanced.py -diffsynth/extensions/FastBlend/runners/fast.py -diffsynth/extensions/FastBlend/runners/interpolation.py -diffsynth/extensions/ImageQualityMetric/__init__.py -diffsynth/extensions/ImageQualityMetric/aesthetic.py -diffsynth/extensions/ImageQualityMetric/clip.py -diffsynth/extensions/ImageQualityMetric/config.py -diffsynth/extensions/ImageQualityMetric/hps.py -diffsynth/extensions/ImageQualityMetric/imagereward.py -diffsynth/extensions/ImageQualityMetric/mps.py -diffsynth/extensions/ImageQualityMetric/pickscore.py -diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py -diffsynth/extensions/ImageQualityMetric/BLIP/blip.py -diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py -diffsynth/extensions/ImageQualityMetric/BLIP/med.py -diffsynth/extensions/ImageQualityMetric/BLIP/vit.py -diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py -diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py -diffsynth/extensions/ImageQualityMetric/open_clip/constants.py -diffsynth/extensions/ImageQualityMetric/open_clip/factory.py -diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py -diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py -diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py -diffsynth/extensions/ImageQualityMetric/open_clip/loss.py -diffsynth/extensions/ImageQualityMetric/open_clip/model.py -diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py -diffsynth/extensions/ImageQualityMetric/open_clip/openai.py -diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py -diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py -diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py -diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py -diffsynth/extensions/ImageQualityMetric/open_clip/transform.py -diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py 
-diffsynth/extensions/ImageQualityMetric/open_clip/utils.py -diffsynth/extensions/ImageQualityMetric/open_clip/version.py -diffsynth/extensions/ImageQualityMetric/trainer/__init__.py -diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py -diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py -diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py -diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py -diffsynth/extensions/RIFE/__init__.py -diffsynth/lora/__init__.py -diffsynth/models/__init__.py -diffsynth/models/attention.py -diffsynth/models/cog_dit.py -diffsynth/models/cog_vae.py -diffsynth/models/downloader.py -diffsynth/models/flux_controlnet.py -diffsynth/models/flux_dit.py -diffsynth/models/flux_infiniteyou.py -diffsynth/models/flux_ipadapter.py -diffsynth/models/flux_text_encoder.py -diffsynth/models/flux_vae.py -diffsynth/models/hunyuan_dit.py -diffsynth/models/hunyuan_dit_text_encoder.py -diffsynth/models/hunyuan_video_dit.py -diffsynth/models/hunyuan_video_text_encoder.py -diffsynth/models/hunyuan_video_vae_decoder.py -diffsynth/models/hunyuan_video_vae_encoder.py -diffsynth/models/kolors_text_encoder.py -diffsynth/models/lora.py -diffsynth/models/model_manager.py -diffsynth/models/omnigen.py -diffsynth/models/qwenvl.py -diffsynth/models/sd3_dit.py -diffsynth/models/sd3_text_encoder.py -diffsynth/models/sd3_vae_decoder.py -diffsynth/models/sd3_vae_encoder.py -diffsynth/models/sd_controlnet.py -diffsynth/models/sd_ipadapter.py -diffsynth/models/sd_motion.py -diffsynth/models/sd_text_encoder.py -diffsynth/models/sd_unet.py -diffsynth/models/sd_vae_decoder.py -diffsynth/models/sd_vae_encoder.py -diffsynth/models/sdxl_controlnet.py -diffsynth/models/sdxl_ipadapter.py -diffsynth/models/sdxl_motion.py -diffsynth/models/sdxl_text_encoder.py -diffsynth/models/sdxl_unet.py -diffsynth/models/sdxl_vae_decoder.py -diffsynth/models/sdxl_vae_encoder.py -diffsynth/models/step1x_connector.py 
-diffsynth/models/stepvideo_dit.py -diffsynth/models/stepvideo_text_encoder.py -diffsynth/models/stepvideo_vae.py -diffsynth/models/svd_image_encoder.py -diffsynth/models/svd_unet.py -diffsynth/models/svd_vae_decoder.py -diffsynth/models/svd_vae_encoder.py -diffsynth/models/tiler.py -diffsynth/models/utils.py -diffsynth/models/wan_video_dit.py -diffsynth/models/wan_video_image_encoder.py -diffsynth/models/wan_video_motion_controller.py -diffsynth/models/wan_video_text_encoder.py -diffsynth/models/wan_video_vace.py -diffsynth/models/wan_video_vae.py -diffsynth/pipelines/__init__.py -diffsynth/pipelines/base.py -diffsynth/pipelines/cog_video.py -diffsynth/pipelines/dancer.py -diffsynth/pipelines/flux_image.py -diffsynth/pipelines/hunyuan_image.py -diffsynth/pipelines/hunyuan_video.py -diffsynth/pipelines/omnigen_image.py -diffsynth/pipelines/pipeline_runner.py -diffsynth/pipelines/sd3_image.py -diffsynth/pipelines/sd_image.py -diffsynth/pipelines/sd_video.py -diffsynth/pipelines/sdxl_image.py -diffsynth/pipelines/sdxl_video.py -diffsynth/pipelines/step_video.py -diffsynth/pipelines/svd_video.py -diffsynth/pipelines/wan_video.py -diffsynth/pipelines/wan_video_new.py -diffsynth/processors/FastBlend.py -diffsynth/processors/PILEditor.py -diffsynth/processors/RIFE.py -diffsynth/processors/__init__.py -diffsynth/processors/base.py -diffsynth/processors/sequencial_processor.py -diffsynth/prompters/__init__.py -diffsynth/prompters/base_prompter.py -diffsynth/prompters/cog_prompter.py -diffsynth/prompters/flux_prompter.py -diffsynth/prompters/hunyuan_dit_prompter.py -diffsynth/prompters/hunyuan_video_prompter.py -diffsynth/prompters/kolors_prompter.py -diffsynth/prompters/omnigen_prompter.py -diffsynth/prompters/omost.py -diffsynth/prompters/prompt_refiners.py -diffsynth/prompters/sd3_prompter.py -diffsynth/prompters/sd_prompter.py -diffsynth/prompters/sdxl_prompter.py -diffsynth/prompters/stepvideo_prompter.py -diffsynth/prompters/wan_prompter.py 
-diffsynth/schedulers/__init__.py -diffsynth/schedulers/continuous_ode.py -diffsynth/schedulers/ddim.py -diffsynth/schedulers/flow_match.py -diffsynth/tokenizer_configs/__init__.py -diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json -diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json -diffsynth/tokenizer_configs/cog/tokenizer/spiece.model -diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json -diffsynth/tokenizer_configs/flux/tokenizer_1/merges.txt -diffsynth/tokenizer_configs/flux/tokenizer_1/special_tokens_map.json -diffsynth/tokenizer_configs/flux/tokenizer_1/tokenizer_config.json -diffsynth/tokenizer_configs/flux/tokenizer_1/vocab.json -diffsynth/tokenizer_configs/flux/tokenizer_2/special_tokens_map.json -diffsynth/tokenizer_configs/flux/tokenizer_2/spiece.model -diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer.json -diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer_config.json -diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json -diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json -diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt -diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt -diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json -diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json -diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model -diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json -diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt -diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/special_tokens_map.json -diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/tokenizer_config.json -diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json -diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json -diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/special_tokens_map.json 
-diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json -diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer_config.json -diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model -diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json -diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt -diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt -diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json -diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json -diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json -diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json -diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt -diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json -diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json -diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json -diffsynth/trainers/__init__.py -diffsynth/trainers/text_to_image.py -diffsynth/trainers/utils.py 
-diffsynth/vram_management/__init__.py -diffsynth/vram_management/layers.py \ No newline at end of file diff --git a/diffsynth.egg-info/dependency_links.txt b/diffsynth.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/diffsynth.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/diffsynth.egg-info/requires.txt b/diffsynth.egg-info/requires.txt deleted file mode 100644 index 92d8b48..0000000 --- a/diffsynth.egg-info/requires.txt +++ /dev/null @@ -1,14 +0,0 @@ -torch>=2.0.0 -torchvision -cupy-cuda12x -transformers -controlnet-aux==0.0.7 -imageio -imageio[ffmpeg] -safetensors -einops -sentencepiece -protobuf -modelscope -ftfy -pynvml diff --git a/diffsynth.egg-info/top_level.txt b/diffsynth.egg-info/top_level.txt deleted file mode 100644 index a4c845e..0000000 --- a/diffsynth.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -diffsynth diff --git a/diffsynth/__pycache__/__init__.cpython-310.pyc b/diffsynth/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index aaeba36..0000000 Binary files a/diffsynth/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/configs/__pycache__/__init__.cpython-310.pyc b/diffsynth/configs/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 9100616..0000000 Binary files a/diffsynth/configs/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/configs/__pycache__/model_config.cpython-310.pyc b/diffsynth/configs/__pycache__/model_config.cpython-310.pyc deleted file mode 100644 index d2e9340..0000000 Binary files a/diffsynth/configs/__pycache__/model_config.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index ff485f0..6bb9350 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -131,10 +131,6 @@ model_loader_configs = [ (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], 
"civitai"), (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"), - (None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"), - (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"), - (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), - (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), diff --git a/diffsynth/controlnets/__pycache__/__init__.cpython-310.pyc b/diffsynth/controlnets/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 9ae16d4..0000000 Binary files a/diffsynth/controlnets/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-310.pyc b/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-310.pyc deleted file mode 100644 index 4468d93..0000000 Binary files a/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/controlnets/__pycache__/processors.cpython-310.pyc b/diffsynth/controlnets/__pycache__/processors.cpython-310.pyc deleted file mode 100644 index 6a2babc..0000000 Binary files a/diffsynth/controlnets/__pycache__/processors.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/data/__pycache__/__init__.cpython-310.pyc b/diffsynth/data/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index a73cba5..0000000 Binary files a/diffsynth/data/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git 
a/diffsynth/data/__pycache__/video.cpython-310.pyc b/diffsynth/data/__pycache__/video.cpython-310.pyc deleted file mode 100644 index d77c90a..0000000 Binary files a/diffsynth/data/__pycache__/video.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-310.pyc b/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 4a2585d..0000000 Binary files a/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-310.pyc b/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index caea266..0000000 Binary files a/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/extensions/__pycache__/__init__.cpython-310.pyc b/diffsynth/extensions/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index b8d8be0..0000000 Binary files a/diffsynth/extensions/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/lora/__init__.py b/diffsynth/lora/__init__.py deleted file mode 100644 index 33bd89c..0000000 --- a/diffsynth/lora/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch - - - -class GeneralLoRALoader: - def __init__(self, device="cpu", torch_dtype=torch.float32): - self.device = device - self.torch_dtype = torch_dtype - - - def get_name_dict(self, lora_state_dict): - lora_name_dict = {} - for key in lora_state_dict: - if ".lora_B." 
not in key: - continue - keys = key.split(".") - if len(keys) > keys.index("lora_B") + 2: - keys.pop(keys.index("lora_B") + 1) - keys.pop(keys.index("lora_B")) - if keys[0] == "diffusion_model": - keys.pop(0) - keys.pop(-1) - target_name = ".".join(keys) - lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A.")) - return lora_name_dict - - - def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): - updated_num = 0 - lora_name_dict = self.get_name_dict(state_dict_lora) - for name, module in model.named_modules(): - if name in lora_name_dict: - weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype) - weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype) - if len(weight_up.shape) == 4: - weight_up = weight_up.squeeze(3).squeeze(2) - weight_down = weight_down.squeeze(3).squeeze(2) - weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - else: - weight_lora = alpha * torch.mm(weight_up, weight_down) - state_dict = module.state_dict() - state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora - module.load_state_dict(state_dict) - updated_num += 1 - print(f"{updated_num} tensors are updated by LoRA.") diff --git a/diffsynth/lora/__pycache__/__init__.cpython-310.pyc b/diffsynth/lora/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 76d6821..0000000 Binary files a/diffsynth/lora/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/__init__.cpython-310.pyc b/diffsynth/models/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index a785d62..0000000 Binary files a/diffsynth/models/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/attention.cpython-310.pyc b/diffsynth/models/__pycache__/attention.cpython-310.pyc deleted file mode 100644 index 7f32e01..0000000 
Binary files a/diffsynth/models/__pycache__/attention.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/cog_dit.cpython-310.pyc b/diffsynth/models/__pycache__/cog_dit.cpython-310.pyc deleted file mode 100644 index 014f774..0000000 Binary files a/diffsynth/models/__pycache__/cog_dit.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/cog_vae.cpython-310.pyc b/diffsynth/models/__pycache__/cog_vae.cpython-310.pyc deleted file mode 100644 index afdb288..0000000 Binary files a/diffsynth/models/__pycache__/cog_vae.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/downloader.cpython-310.pyc b/diffsynth/models/__pycache__/downloader.cpython-310.pyc deleted file mode 100644 index 63a8f91..0000000 Binary files a/diffsynth/models/__pycache__/downloader.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/flux_controlnet.cpython-310.pyc b/diffsynth/models/__pycache__/flux_controlnet.cpython-310.pyc deleted file mode 100644 index 3b8a12b..0000000 Binary files a/diffsynth/models/__pycache__/flux_controlnet.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/flux_dit.cpython-310.pyc b/diffsynth/models/__pycache__/flux_dit.cpython-310.pyc deleted file mode 100644 index 453f82a..0000000 Binary files a/diffsynth/models/__pycache__/flux_dit.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/flux_infiniteyou.cpython-310.pyc b/diffsynth/models/__pycache__/flux_infiniteyou.cpython-310.pyc deleted file mode 100644 index b70580c..0000000 Binary files a/diffsynth/models/__pycache__/flux_infiniteyou.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/flux_ipadapter.cpython-310.pyc b/diffsynth/models/__pycache__/flux_ipadapter.cpython-310.pyc deleted file mode 100644 index fa2ca0e..0000000 Binary files a/diffsynth/models/__pycache__/flux_ipadapter.cpython-310.pyc and /dev/null differ diff --git 
a/diffsynth/models/__pycache__/flux_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/flux_text_encoder.cpython-310.pyc deleted file mode 100644 index 7c4c9fd..0000000 Binary files a/diffsynth/models/__pycache__/flux_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/flux_vae.cpython-310.pyc b/diffsynth/models/__pycache__/flux_vae.cpython-310.pyc deleted file mode 100644 index 8ef63de..0000000 Binary files a/diffsynth/models/__pycache__/flux_vae.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/hunyuan_dit.cpython-310.pyc b/diffsynth/models/__pycache__/hunyuan_dit.cpython-310.pyc deleted file mode 100644 index bd7c96e..0000000 Binary files a/diffsynth/models/__pycache__/hunyuan_dit.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-310.pyc deleted file mode 100644 index 91f1f7f..0000000 Binary files a/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-310.pyc b/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-310.pyc deleted file mode 100644 index a4c5b63..0000000 Binary files a/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-310.pyc deleted file mode 100644 index 2ec0513..0000000 Binary files a/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-310.pyc b/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-310.pyc deleted file mode 100644 index a23f94c..0000000 Binary files 
a/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-310.pyc deleted file mode 100644 index 5870a79..0000000 Binary files a/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/kolors_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/kolors_text_encoder.cpython-310.pyc deleted file mode 100644 index 207a660..0000000 Binary files a/diffsynth/models/__pycache__/kolors_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/lora.cpython-310.pyc b/diffsynth/models/__pycache__/lora.cpython-310.pyc deleted file mode 100644 index 8ccec57..0000000 Binary files a/diffsynth/models/__pycache__/lora.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/model_manager.cpython-310.pyc b/diffsynth/models/__pycache__/model_manager.cpython-310.pyc deleted file mode 100644 index dee3901..0000000 Binary files a/diffsynth/models/__pycache__/model_manager.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/omnigen.cpython-310.pyc b/diffsynth/models/__pycache__/omnigen.cpython-310.pyc deleted file mode 100644 index 9d3ec4c..0000000 Binary files a/diffsynth/models/__pycache__/omnigen.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd3_dit.cpython-310.pyc b/diffsynth/models/__pycache__/sd3_dit.cpython-310.pyc deleted file mode 100644 index 51cfbdc..0000000 Binary files a/diffsynth/models/__pycache__/sd3_dit.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc deleted file mode 100644 index 88687fb..0000000 Binary files 
a/diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-310.pyc b/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-310.pyc deleted file mode 100644 index 1e9b70b..0000000 Binary files a/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-310.pyc deleted file mode 100644 index 14d53bf..0000000 Binary files a/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd_controlnet.cpython-310.pyc b/diffsynth/models/__pycache__/sd_controlnet.cpython-310.pyc deleted file mode 100644 index 629d988..0000000 Binary files a/diffsynth/models/__pycache__/sd_controlnet.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd_ipadapter.cpython-310.pyc b/diffsynth/models/__pycache__/sd_ipadapter.cpython-310.pyc deleted file mode 100644 index 550c8a2..0000000 Binary files a/diffsynth/models/__pycache__/sd_ipadapter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd_motion.cpython-310.pyc b/diffsynth/models/__pycache__/sd_motion.cpython-310.pyc deleted file mode 100644 index 31d6980..0000000 Binary files a/diffsynth/models/__pycache__/sd_motion.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/sd_text_encoder.cpython-310.pyc deleted file mode 100644 index 84cb35c..0000000 Binary files a/diffsynth/models/__pycache__/sd_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd_unet.cpython-310.pyc b/diffsynth/models/__pycache__/sd_unet.cpython-310.pyc deleted file mode 100644 index 906ddae..0000000 Binary files a/diffsynth/models/__pycache__/sd_unet.cpython-310.pyc and /dev/null 
differ diff --git a/diffsynth/models/__pycache__/sd_vae_decoder.cpython-310.pyc b/diffsynth/models/__pycache__/sd_vae_decoder.cpython-310.pyc deleted file mode 100644 index 0ef2104..0000000 Binary files a/diffsynth/models/__pycache__/sd_vae_decoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sd_vae_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/sd_vae_encoder.cpython-310.pyc deleted file mode 100644 index 88a665a..0000000 Binary files a/diffsynth/models/__pycache__/sd_vae_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sdxl_controlnet.cpython-310.pyc b/diffsynth/models/__pycache__/sdxl_controlnet.cpython-310.pyc deleted file mode 100644 index 46b32b7..0000000 Binary files a/diffsynth/models/__pycache__/sdxl_controlnet.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-310.pyc b/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-310.pyc deleted file mode 100644 index 90ae52c..0000000 Binary files a/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sdxl_motion.cpython-310.pyc b/diffsynth/models/__pycache__/sdxl_motion.cpython-310.pyc deleted file mode 100644 index 2388220..0000000 Binary files a/diffsynth/models/__pycache__/sdxl_motion.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-310.pyc deleted file mode 100644 index c30896e..0000000 Binary files a/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc b/diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc deleted file mode 100644 index b507240..0000000 Binary files a/diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc and /dev/null differ diff --git 
a/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-310.pyc b/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-310.pyc deleted file mode 100644 index 83e43c1..0000000 Binary files a/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-310.pyc deleted file mode 100644 index f31abe6..0000000 Binary files a/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/step1x_connector.cpython-310.pyc b/diffsynth/models/__pycache__/step1x_connector.cpython-310.pyc deleted file mode 100644 index 9e00094..0000000 Binary files a/diffsynth/models/__pycache__/step1x_connector.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/stepvideo_dit.cpython-310.pyc b/diffsynth/models/__pycache__/stepvideo_dit.cpython-310.pyc deleted file mode 100644 index c5fd431..0000000 Binary files a/diffsynth/models/__pycache__/stepvideo_dit.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-310.pyc deleted file mode 100644 index d6d1b88..0000000 Binary files a/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/stepvideo_vae.cpython-310.pyc b/diffsynth/models/__pycache__/stepvideo_vae.cpython-310.pyc deleted file mode 100644 index 8b6264e..0000000 Binary files a/diffsynth/models/__pycache__/stepvideo_vae.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/svd_image_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/svd_image_encoder.cpython-310.pyc deleted file mode 100644 index c2a17d6..0000000 Binary files a/diffsynth/models/__pycache__/svd_image_encoder.cpython-310.pyc and /dev/null differ diff --git 
a/diffsynth/models/__pycache__/svd_unet.cpython-310.pyc b/diffsynth/models/__pycache__/svd_unet.cpython-310.pyc deleted file mode 100644 index 211e031..0000000 Binary files a/diffsynth/models/__pycache__/svd_unet.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/svd_vae_decoder.cpython-310.pyc b/diffsynth/models/__pycache__/svd_vae_decoder.cpython-310.pyc deleted file mode 100644 index aff5c83..0000000 Binary files a/diffsynth/models/__pycache__/svd_vae_decoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/svd_vae_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/svd_vae_encoder.cpython-310.pyc deleted file mode 100644 index df39468..0000000 Binary files a/diffsynth/models/__pycache__/svd_vae_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/tiler.cpython-310.pyc b/diffsynth/models/__pycache__/tiler.cpython-310.pyc deleted file mode 100644 index 2297791..0000000 Binary files a/diffsynth/models/__pycache__/tiler.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/utils.cpython-310.pyc b/diffsynth/models/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index c136265..0000000 Binary files a/diffsynth/models/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/wan_video_dit.cpython-310.pyc b/diffsynth/models/__pycache__/wan_video_dit.cpython-310.pyc deleted file mode 100644 index bf3b099..0000000 Binary files a/diffsynth/models/__pycache__/wan_video_dit.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-310.pyc deleted file mode 100644 index b5af93a..0000000 Binary files a/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-310.pyc 
b/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-310.pyc deleted file mode 100644 index 65bebb5..0000000 Binary files a/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-310.pyc b/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-310.pyc deleted file mode 100644 index 4f75543..0000000 Binary files a/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/wan_video_vace.cpython-310.pyc b/diffsynth/models/__pycache__/wan_video_vace.cpython-310.pyc deleted file mode 100644 index 9a47c1d..0000000 Binary files a/diffsynth/models/__pycache__/wan_video_vace.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/__pycache__/wan_video_vae.cpython-310.pyc b/diffsynth/models/__pycache__/wan_video_vae.cpython-310.pyc deleted file mode 100644 index 75921a4..0000000 Binary files a/diffsynth/models/__pycache__/wan_video_vae.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/models/utils.py b/diffsynth/models/utils.py index 0d58e4e..99f5dee 100644 --- a/diffsynth/models/utils.py +++ b/diffsynth/models/utils.py @@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None): return state_dict -def load_state_dict(file_path, torch_dtype=None, device="cpu"): +def load_state_dict(file_path, torch_dtype=None): if file_path.endswith(".safetensors"): - return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device) + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) else: - return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) -def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): +def load_state_dict_from_safetensors(file_path, torch_dtype=None): state_dict 
= {} - with safe_open(file_path, framework="pt", device=device) as f: + with safe_open(file_path, framework="pt", device="cpu") as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) if torch_dtype is not None: @@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): return state_dict -def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"): - state_dict = torch.load(file_path, map_location=device, weights_only=True) +def load_state_dict_from_bin(file_path, torch_dtype=None): + state_dict = torch.load(file_path, map_location="cpu", weights_only=True) if torch_dtype is not None: for i in state_dict: if isinstance(state_dict[i], torch.Tensor): diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index ab63bfb..d9be8ab 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -5,10 +5,6 @@ import math from typing import Tuple, Optional from einops import rearrange from .utils import hash_state_dict_keys - -from dchen.camera_adapter import SimpleAdapter - - try: import flash_attn_interface FLASH_ATTN_3_AVAILABLE = True @@ -276,9 +272,6 @@ class WanModel(torch.nn.Module): num_layers: int, has_image_input: bool, has_image_pos_emb: bool = False, - has_ref_conv: bool = False, - add_control_adapter: bool = False, - in_dim_control_adapter: int = 24, ): super().__init__() self.dim = dim @@ -310,22 +303,10 @@ class WanModel(torch.nn.Module): if has_image_input: self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 - if has_ref_conv: - self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) self.has_image_pos_emb = has_image_pos_emb - self.has_ref_conv = has_ref_conv - if add_control_adapter: - self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) - else: - self.control_adapter = None - - def patchify(self, x: torch.Tensor, 
control_camera_latents_input: torch.Tensor = None): + def patchify(self, x: torch.Tensor): x = self.patch_embedding(x) - if self.control_adapter is not None and control_camera_latents_input is not None: - y_camera = self.control_adapter(control_camera_latents_input) - x = [u + v for u, v in zip(x, y_camera)] - x = x[0].unsqueeze(0) grid_size = x.shape[2:] x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() return x, grid_size # x, grid_size: (f, h, w) @@ -551,7 +532,6 @@ class WanModelStateDictConverter: "eps": 1e-6 } elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677": - # 1.3B PAI control config = { "has_image_input": True, "patch_size": [1, 2, 2], @@ -566,7 +546,6 @@ class WanModelStateDictConverter: "eps": 1e-6 } elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c": - # 14B PAI control config = { "has_image_input": True, "patch_size": [1, 2, 2], @@ -595,74 +574,6 @@ class WanModelStateDictConverter: "eps": 1e-6, "has_image_pos_emb": True } - elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504": - # 1.3B PAI control v1.1 - config = { - "has_image_input": True, - "patch_size": [1, 2, 2], - "in_dim": 48, - "dim": 1536, - "ffn_dim": 8960, - "freq_dim": 256, - "text_dim": 4096, - "out_dim": 16, - "num_heads": 12, - "num_layers": 30, - "eps": 1e-6, - "has_ref_conv": True - } - elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b": - # 14B PAI control v1.1 - config = { - "has_image_input": True, - "patch_size": [1, 2, 2], - "in_dim": 48, - "dim": 5120, - "ffn_dim": 13824, - "freq_dim": 256, - "text_dim": 4096, - "out_dim": 16, - "num_heads": 40, - "num_layers": 40, - "eps": 1e-6, - "has_ref_conv": True - } - elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901": - # 1.3B PAI control-camera v1.1 - config = { - "has_image_input": True, - "patch_size": [1, 2, 2], - "in_dim": 32, - "dim": 1536, - "ffn_dim": 8960, - "freq_dim": 256, - "text_dim": 4096, - 
"out_dim": 16, - "num_heads": 12, - "num_layers": 30, - "eps": 1e-6, - "has_ref_conv": False, - "add_control_adapter": True, - "in_dim_control_adapter": 24, - } - elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae": - # 14B PAI control-camera v1.1 - config = { - "has_image_input": True, - "patch_size": [1, 2, 2], - "in_dim": 32, - "dim": 5120, - "ffn_dim": 13824, - "freq_dim": 256, - "text_dim": 4096, - "out_dim": 16, - "num_heads": 40, - "num_layers": 40, - "eps": 1e-6, - "has_ref_conv": False, - "add_control_adapter": True, - "in_dim_control_adapter": 24, - } else: config = {} return state_dict, config diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 137fd28..df23076 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -774,11 +774,18 @@ class WanVideoVAE(nn.Module): def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): - if tiled: - video = self.tiled_decode(hidden_states, device, tile_size, tile_stride) - else: - video = self.single_decode(hidden_states, device) - return video + hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] + videos = [] + for hidden_state in hidden_states: + hidden_state = hidden_state.unsqueeze(0) + if tiled: + video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_state, device) + video = video.squeeze(0) + videos.append(video) + videos = torch.stack(videos) + return videos @staticmethod diff --git a/diffsynth/pipelines/__pycache__/__init__.cpython-310.pyc b/diffsynth/pipelines/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 9c917b0..0000000 Binary files a/diffsynth/pipelines/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/base.cpython-310.pyc b/diffsynth/pipelines/__pycache__/base.cpython-310.pyc deleted file mode 100644 index 
f26b6b1..0000000 Binary files a/diffsynth/pipelines/__pycache__/base.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/cog_video.cpython-310.pyc b/diffsynth/pipelines/__pycache__/cog_video.cpython-310.pyc deleted file mode 100644 index d8b0ac9..0000000 Binary files a/diffsynth/pipelines/__pycache__/cog_video.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/dancer.cpython-310.pyc b/diffsynth/pipelines/__pycache__/dancer.cpython-310.pyc deleted file mode 100644 index 01d2488..0000000 Binary files a/diffsynth/pipelines/__pycache__/dancer.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/flux_image.cpython-310.pyc b/diffsynth/pipelines/__pycache__/flux_image.cpython-310.pyc deleted file mode 100644 index c3ce9b4..0000000 Binary files a/diffsynth/pipelines/__pycache__/flux_image.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-310.pyc b/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-310.pyc deleted file mode 100644 index 12f9f95..0000000 Binary files a/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-310.pyc b/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-310.pyc deleted file mode 100644 index 6a8ff5e..0000000 Binary files a/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/omnigen_image.cpython-310.pyc b/diffsynth/pipelines/__pycache__/omnigen_image.cpython-310.pyc deleted file mode 100644 index 95b912f..0000000 Binary files a/diffsynth/pipelines/__pycache__/omnigen_image.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-310.pyc b/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-310.pyc deleted file mode 100644 index cedc1e4..0000000 Binary files 
a/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/sd3_image.cpython-310.pyc b/diffsynth/pipelines/__pycache__/sd3_image.cpython-310.pyc deleted file mode 100644 index 3b15ff2..0000000 Binary files a/diffsynth/pipelines/__pycache__/sd3_image.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/sd_image.cpython-310.pyc b/diffsynth/pipelines/__pycache__/sd_image.cpython-310.pyc deleted file mode 100644 index aa89749..0000000 Binary files a/diffsynth/pipelines/__pycache__/sd_image.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/sd_video.cpython-310.pyc b/diffsynth/pipelines/__pycache__/sd_video.cpython-310.pyc deleted file mode 100644 index e57d7be..0000000 Binary files a/diffsynth/pipelines/__pycache__/sd_video.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/sdxl_image.cpython-310.pyc b/diffsynth/pipelines/__pycache__/sdxl_image.cpython-310.pyc deleted file mode 100644 index c61c4af..0000000 Binary files a/diffsynth/pipelines/__pycache__/sdxl_image.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/sdxl_video.cpython-310.pyc b/diffsynth/pipelines/__pycache__/sdxl_video.cpython-310.pyc deleted file mode 100644 index 9b445fb..0000000 Binary files a/diffsynth/pipelines/__pycache__/sdxl_video.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/step_video.cpython-310.pyc b/diffsynth/pipelines/__pycache__/step_video.cpython-310.pyc deleted file mode 100644 index 8c86b2b..0000000 Binary files a/diffsynth/pipelines/__pycache__/step_video.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/svd_video.cpython-310.pyc b/diffsynth/pipelines/__pycache__/svd_video.cpython-310.pyc deleted file mode 100644 index 6471e96..0000000 Binary files a/diffsynth/pipelines/__pycache__/svd_video.cpython-310.pyc and /dev/null 
differ diff --git a/diffsynth/pipelines/__pycache__/wan_video.cpython-310.pyc b/diffsynth/pipelines/__pycache__/wan_video.cpython-310.pyc deleted file mode 100644 index aeca1d6..0000000 Binary files a/diffsynth/pipelines/__pycache__/wan_video.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/__pycache__/wan_video_new.cpython-310.pyc b/diffsynth/pipelines/__pycache__/wan_video_new.cpython-310.pyc deleted file mode 100644 index 31d0868..0000000 Binary files a/diffsynth/pipelines/__pycache__/wan_video_new.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index b84b1b9..77835a4 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -68,7 +68,6 @@ class WanVideoPipeline(BasePipeline): torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, RMSNorm: AutoWrappedModule, - torch.nn.Conv2d: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, @@ -238,18 +237,6 @@ class WanVideoPipeline(BasePipeline): return latents - def prepare_reference_image(self, reference_image, height, width): - if reference_image is not None: - self.load_models_to_device(["vae"]) - reference_image = reference_image.resize((width, height)) - reference_image = self.preprocess_images([reference_image]) - reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device) - reference_latents = self.vae.encode(reference_image, device=self.device) - return {"reference_latents": reference_latents} - else: - return {} - - def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): if control_video is not None: control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) @@ -352,7 +339,6 @@ class WanVideoPipeline(BasePipeline): end_image=None, input_video=None, 
control_video=None, - reference_image=None, vace_video=None, vace_video_mask=None, vace_reference_image=None, @@ -412,9 +398,6 @@ class WanVideoPipeline(BasePipeline): else: image_emb = {} - # Reference image - reference_image_kwargs = self.prepare_reference_image(reference_image, height, width) - # ControlNet if control_video is not None: self.load_models_to_device(["image_encoder", "vae"]) @@ -452,14 +435,14 @@ class WanVideoPipeline(BasePipeline): self.dit, motion_controller=self.motion_controller, vace=self.vace, x=latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, - **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs, + **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, ) if cfg_scale != 1.0: noise_pred_nega = model_fn_wan_video( self.dit, motion_controller=self.motion_controller, vace=self.vace, x=latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, - **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs, + **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -543,7 +526,6 @@ def model_fn_wan_video( context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, - reference_latents = None, vace_context = None, vace_scale = 1.0, tea_cache: TeaCache = None, @@ -570,12 +552,6 @@ def model_fn_wan_video( x, (f, h, w) = dit.patchify(x) - # Reference image - if reference_latents is not None: - reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2) - x = torch.concat([reference_latents, x], dim=1) - f += 1 - freqs = torch.cat([ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), @@ -604,10 +580,6 @@ def model_fn_wan_video( x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale if tea_cache 
is not None: tea_cache.store(x) - - if reference_latents is not None: - x = x[:, reference_latents.shape[1]:] - f -= 1 x = dit.head(x, t) if use_unified_sequence_parallel: diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py deleted file mode 100644 index 8c146b9..0000000 --- a/diffsynth/pipelines/wan_video_new.py +++ /dev/null @@ -1,1114 +0,0 @@ -import torch, warnings, glob, os -import numpy as np -from PIL import Image -from einops import repeat, reduce -from typing import Optional, Union -from dataclasses import dataclass -from modelscope import snapshot_download -from einops import rearrange -import numpy as np -from PIL import Image -from tqdm import tqdm -from typing import Optional - -from ..models import ModelManager, load_state_dict -from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d -from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm -from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample -from ..models.wan_video_image_encoder import WanImageEncoder -from ..models.wan_video_vace import VaceWanModel -from ..models.wan_video_motion_controller import WanMotionControllerModel -from ..schedulers.flow_match import FlowMatchScheduler -from ..prompters import WanPrompter -from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm -from ..lora import GeneralLoRALoader - - -class BasePipeline(torch.nn.Module): - - def __init__( - self, - device="cuda", torch_dtype=torch.float16, - height_division_factor=64, width_division_factor=64, - time_division_factor=None, time_division_remainder=None, - ): - super().__init__() - # The device and torch_dtype is used for the storage of intermediate variables, not models. - self.device = device - self.torch_dtype = torch_dtype - # The following parameters are used for shape check. 
- self.height_division_factor = height_division_factor - self.width_division_factor = width_division_factor - self.time_division_factor = time_division_factor - self.time_division_remainder = time_division_remainder - self.vram_management_enabled = False - - - def to(self, *args, **kwargs): - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None: - self.device = device - if dtype is not None: - self.torch_dtype = dtype - super().to(*args, **kwargs) - return self - - - def check_resize_height_width(self, height, width, num_frames=None): - # Shape check - if height % self.height_division_factor != 0: - height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor - print(f"height % {self.height_division_factor} != 0. We round it up to {height}.") - if width % self.width_division_factor != 0: - width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor - print(f"width % {self.width_division_factor} != 0. We round it up to {width}.") - if num_frames is None: - return height, width - else: - if num_frames % self.time_division_factor != self.time_division_remainder: - num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder - print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. 
We round it up to {num_frames}.") - return height, width, num_frames - - - def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1): - # Transform a PIL.Image to torch.Tensor - image = torch.Tensor(np.array(image, dtype=np.float32)) - image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) - image = image * ((max_value - min_value) / 255) + min_value - image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})) - return image - - - def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1): - # Transform a list of PIL.Image to torch.Tensor - video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video] - video = torch.stack(video, dim=pattern.index("T") // 2) - return video - - - def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1): - # Transform a torch.Tensor to PIL.Image - if pattern != "H W C": - vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean") - image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255) - image = image.to(device="cpu", dtype=torch.uint8) - image = Image.fromarray(image.numpy()) - return image - - - def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1): - # Transform a torch.Tensor to list of PIL.Image - if pattern != "T H W C": - vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean") - video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] - return video - - - def load_models_to_device(self, model_names=[]): - if self.vram_management_enabled: - # offload models - for name, model in self.named_children(): - if name not in model_names: - if hasattr(model, "vram_management_enabled") and 
model.vram_management_enabled: - for module in model.modules(): - if hasattr(module, "offload"): - module.offload() - else: - model.cpu() - torch.cuda.empty_cache() - # onload models - for name, model in self.named_children(): - if name in model_names: - if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: - for module in model.modules(): - if hasattr(module, "onload"): - module.onload() - else: - model.to(self.device) - - - def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None): - # Initialize Gaussian noise - generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed) - noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype) - noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) - return noise - - - def enable_cpu_offload(self): - warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.") - self.vram_management_enabled = True - - - def get_free_vram(self): - total_memory = torch.cuda.get_device_properties(self.device).total_memory - allocated_memory = torch.cuda.device_memory_used(self.device) - return (total_memory - allocated_memory) / (1024 ** 3) - - - def freeze_except(self, model_names): - for name, model in self.named_children(): - if name in model_names: - model.train() - model.requires_grad_(True) - else: - model.eval() - model.requires_grad_(False) - - -@dataclass -class ModelConfig: - path: Union[str, list[str]] = None - model_id: str = None - origin_file_pattern: Union[str, list[str]] = None - download_resource: str = "ModelScope" - offload_device: Optional[Union[str, torch.device]] = None - offload_dtype: Optional[torch.dtype] = None - - def download_if_necessary(self, local_model_path="./models", skip_download=False): - if self.path is None: - if self.model_id is None or self.origin_file_pattern is None: - raise 
ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") - if not skip_download: - downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) - snapshot_download( - self.model_id, - local_dir=os.path.join(local_model_path, self.model_id), - allow_file_pattern=self.origin_file_pattern, - ignore_file_pattern=downloaded_files, - local_files_only=False - ) - self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) - if isinstance(self.path, list) and len(self.path) == 1: - self.path = self.path[0] - - -class WanVideoPipeline(BasePipeline): - - def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): - super().__init__( - device=device, torch_dtype=torch_dtype, - height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 - ) - self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) - self.prompter = WanPrompter(tokenizer_path=tokenizer_path) - self.text_encoder: WanTextEncoder = None - self.image_encoder: WanImageEncoder = None - self.dit: WanModel = None - self.vae: WanVideoVAE = None - self.motion_controller: WanMotionControllerModel = None - self.vace: VaceWanModel = None - self.in_iteration_models = ("dit", "motion_controller", "vace") - self.unit_runner = PipelineUnitRunner() - self.units = [ - WanVideoUnit_ShapeChecker(), - WanVideoUnit_NoiseInitializer(), - WanVideoUnit_InputVideoEmbedder(), - WanVideoUnit_PromptEmbedder(), - WanVideoUnit_ImageEmbedder(), - WanVideoUnit_FunCamera(), - WanVideoUnit_FunControl(), - WanVideoUnit_FunReference(), - WanVideoUnit_SpeedControl(), - WanVideoUnit_VACE(), - WanVideoUnit_TeaCache(), - WanVideoUnit_CfgMerger(), - ] - self.model_fn = model_fn_wan_video - - - def load_lora(self, module, path, alpha=1): - loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, 
device=self.device) - lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) - loader.load(module, lora, alpha=alpha) - - - def training_loss(self, **inputs): - timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) - timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) - - inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) - training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep) - - noise_pred = self.model_fn(**inputs, timestep=timestep) - - loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) - loss = loss * self.scheduler.training_weight(timestep) - return loss - - - def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): - self.vram_management_enabled = True - if num_persistent_param_in_dit is not None: - vram_limit = None - else: - if vram_limit is None: - vram_limit = self.get_free_vram() - vram_limit = vram_limit - vram_buffer - if self.text_encoder is not None: - dtype = next(iter(self.text_encoder.parameters())).dtype - enable_vram_management( - self.text_encoder, - module_map = { - torch.nn.Linear: AutoWrappedLinear, - torch.nn.Embedding: AutoWrappedModule, - T5RelativeEmbedding: AutoWrappedModule, - T5LayerNorm: AutoWrappedModule, - }, - module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device="cpu", - computation_dtype=self.torch_dtype, - computation_device=self.device, - ), - vram_limit=vram_limit, - ) - if self.dit is not None: - dtype = next(iter(self.dit.parameters())).dtype - device = "cpu" if vram_limit is not None else self.device - enable_vram_management( - self.dit, - module_map = { - torch.nn.Linear: AutoWrappedLinear, - torch.nn.Conv3d: AutoWrappedModule, - torch.nn.LayerNorm: WanAutoCastLayerNorm, - RMSNorm: AutoWrappedModule, - torch.nn.Conv2d: 
AutoWrappedModule, - }, - module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device=device, - computation_dtype=self.torch_dtype, - computation_device=self.device, - ), - max_num_param=num_persistent_param_in_dit, - overflow_module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device="cpu", - computation_dtype=self.torch_dtype, - computation_device=self.device, - ), - vram_limit=vram_limit, - ) - if self.vae is not None: - dtype = next(iter(self.vae.parameters())).dtype - enable_vram_management( - self.vae, - module_map = { - torch.nn.Linear: AutoWrappedLinear, - torch.nn.Conv2d: AutoWrappedModule, - RMS_norm: AutoWrappedModule, - CausalConv3d: AutoWrappedModule, - Upsample: AutoWrappedModule, - torch.nn.SiLU: AutoWrappedModule, - torch.nn.Dropout: AutoWrappedModule, - }, - module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device=self.device, - computation_dtype=self.torch_dtype, - computation_device=self.device, - ), - ) - if self.image_encoder is not None: - dtype = next(iter(self.image_encoder.parameters())).dtype - enable_vram_management( - self.image_encoder, - module_map = { - torch.nn.Linear: AutoWrappedLinear, - torch.nn.Conv2d: AutoWrappedModule, - torch.nn.LayerNorm: AutoWrappedModule, - }, - module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device="cpu", - computation_dtype=dtype, - computation_device=self.device, - ), - ) - if self.motion_controller is not None: - dtype = next(iter(self.motion_controller.parameters())).dtype - enable_vram_management( - self.motion_controller, - module_map = { - torch.nn.Linear: AutoWrappedLinear, - }, - module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device="cpu", - computation_dtype=dtype, - computation_device=self.device, - ), - ) - if self.vace is not None: - device = "cpu" 
if vram_limit is not None else self.device - enable_vram_management( - self.vace, - module_map = { - torch.nn.Linear: AutoWrappedLinear, - torch.nn.Conv3d: AutoWrappedModule, - torch.nn.LayerNorm: AutoWrappedModule, - RMSNorm: AutoWrappedModule, - }, - module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device=device, - computation_dtype=self.torch_dtype, - computation_device=self.device, - ), - vram_limit=vram_limit, - ) - - - @staticmethod - def from_pretrained( - torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", - model_configs: list[ModelConfig] = [], - tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), - local_model_path: str = "./models", - skip_download: bool = False, - redirect_common_files: bool = True, - ): - # Redirect model path - if redirect_common_files: - redirect_dict = { - "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B", - "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B", - "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P", - } - for model_config in model_configs: - if model_config.origin_file_pattern is None or model_config.model_id is None: - continue - if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]: - print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). 
You can use `redirect_common_files=False` to disable file redirection.") - model_config.model_id = redirect_dict[model_config.origin_file_pattern] - - # Download and load models - model_manager = ModelManager() - for model_config in model_configs: - model_config.download_if_necessary(local_model_path, skip_download=skip_download) - model_manager.load_model( - model_config.path, - device=model_config.offload_device or device, - torch_dtype=model_config.offload_dtype or torch_dtype - ) - - # Initialize pipeline - pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) - pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") - pipe.dit = model_manager.fetch_model("wan_video_dit") - pipe.vae = model_manager.fetch_model("wan_video_vae") - pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") - pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") - pipe.vace = model_manager.fetch_model("wan_video_vace") - - # Initialize tokenizer - tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download) - pipe.prompter.fetch_models(pipe.text_encoder) - pipe.prompter.fetch_tokenizer(tokenizer_config.path) - return pipe - - - @torch.no_grad() - def __call__( - self, - # Prompt - prompt: str, - negative_prompt: Optional[str] = "", - # Image-to-video - input_image: Optional[Image.Image] = None, - # First-last-frame-to-video - end_image: Optional[Image.Image] = None, - # Video-to-video - input_video: Optional[list[Image.Image]] = None, - denoising_strength: Optional[float] = 1.0, - # ControlNet - control_video: Optional[list[Image.Image]] = None, - reference_image: Optional[Image.Image] = None, - # VACE - vace_video: Optional[list[Image.Image]] = None, - vace_video_mask: Optional[Image.Image] = None, - vace_reference_image: Optional[Image.Image] = None, - vace_scale: Optional[float] = 1.0, - # Randomness - seed: Optional[int] = None, - rand_device: Optional[str] = "cpu", - # Shape - height: 
Optional[int] = 480, - width: Optional[int] = 832, - num_frames=81, - # Classifier-free guidance - cfg_scale: Optional[float] = 5.0, - cfg_merge: Optional[bool] = False, - # Scheduler - num_inference_steps: Optional[int] = 50, - sigma_shift: Optional[float] = 5.0, - # Speed control - motion_bucket_id: Optional[int] = None, - # VAE tiling - tiled: Optional[bool] = True, - tile_size: Optional[tuple[int, int]] = (30, 52), - tile_stride: Optional[tuple[int, int]] = (15, 26), - # Sliding window - sliding_window_size: Optional[int] = None, - sliding_window_stride: Optional[int] = None, - # Teacache - tea_cache_l1_thresh: Optional[float] = None, - tea_cache_model_id: Optional[str] = "", - # progress_bar - progress_bar_cmd=tqdm, - # Camera control - control_camera_video: Optional[torch.Tensor] = None - ): - # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) - - # Inputs - inputs_posi = { - "prompt": prompt, - "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, - } - inputs_nega = { - "negative_prompt": negative_prompt, - "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, - } - inputs_shared = { - "input_image": input_image, - "end_image": end_image, - "input_video": input_video, "denoising_strength": denoising_strength, - "control_video": control_video, "reference_image": reference_image, - "control_camera_video": control_camera_video, - "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, - "seed": seed, "rand_device": rand_device, - "height": height, "width": width, "num_frames": num_frames, - "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, - "num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift, - "motion_bucket_id": motion_bucket_id, - "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, - "sliding_window_size": 
sliding_window_size, "sliding_window_stride": sliding_window_stride, - } - for unit in self.units: - inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) - - # Denoise - self.load_models_to_device(self.in_iteration_models) - models = {name: getattr(self, name) for name in self.in_iteration_models} - for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - - # Inference - noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) - if cfg_scale != 1.0: - if cfg_merge: - noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) - else: - noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) - else: - noise_pred = noise_pred_posi - - # Scheduler - inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) - - # VACE (TODO: remove it) - if vace_reference_image is not None: - inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] - - # Decode - self.load_models_to_device(['vae']) - video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - video = self.vae_output_to_video(video) - self.load_models_to_device([]) - - return video - - - -class PipelineUnit: - def __init__( - self, - seperate_cfg: bool = False, - take_over: bool = False, - input_params: tuple[str] = None, - input_params_posi: dict[str, str] = None, - input_params_nega: dict[str, str] = None, - onload_model_names: tuple[str] = None - ): - self.seperate_cfg = seperate_cfg - self.take_over = take_over - self.input_params = input_params - self.input_params_posi = input_params_posi - self.input_params_nega = input_params_nega - 
self.onload_model_names = onload_model_names - - - def process(self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs) -> dict: - raise NotImplementedError("`process` is not implemented.") - - - -class PipelineUnitRunner: - def __init__(self): - pass - - def __call__(self, unit: PipelineUnit, pipe: WanVideoPipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: - if unit.take_over: - # Let the pipeline unit take over this function. - inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) - elif unit.seperate_cfg: - # Positive side - processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()} - processor_outputs = unit.process(pipe, **processor_inputs) - inputs_posi.update(processor_outputs) - # Negative side - if inputs_shared["cfg_scale"] != 1: - processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()} - processor_outputs = unit.process(pipe, **processor_inputs) - inputs_nega.update(processor_outputs) - else: - inputs_nega.update(processor_outputs) - else: - processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params} - processor_outputs = unit.process(pipe, **processor_inputs) - inputs_shared.update(processor_outputs) - return inputs_shared, inputs_posi, inputs_nega - - - -class WanVideoUnit_ShapeChecker(PipelineUnit): - def __init__(self): - super().__init__(input_params=("height", "width", "num_frames")) - - def process(self, pipe: WanVideoPipeline, height, width, num_frames): - height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) - return {"height": height, "width": width, "num_frames": num_frames} - - - -class WanVideoUnit_NoiseInitializer(PipelineUnit): - def __init__(self): - super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image")) - - def 
process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): - length = (num_frames - 1) // 4 + 1 - if vace_reference_image is not None: - length += 1 - noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device) - if vace_reference_image is not None: - noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) - return {"noise": noise} - - - -class WanVideoUnit_InputVideoEmbedder(PipelineUnit): - def __init__(self): - super().__init__( - input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"), - onload_model_names=("vae",) - ) - - def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength): - if input_video is None: - return {"latents": noise} - pipe.load_models_to_device(["vae"]) - input_video = pipe.preprocess_video(input_video) - input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - if pipe.scheduler.training: - return {"latents": noise, "input_latents": input_latents} - else: - latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) - return {"latents": latents} - - - -class WanVideoUnit_PromptEmbedder(PipelineUnit): - def __init__(self): - super().__init__( - seperate_cfg=True, - input_params_posi={"prompt": "prompt", "positive": "positive"}, - input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, - onload_model_names=("text_encoder",) - ) - - def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: - pipe.load_models_to_device(self.onload_model_names) - prompt_emb = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device) - return {"context": prompt_emb} - - - -class WanVideoUnit_ImageEmbedder(PipelineUnit): - def __init__(self): - super().__init__( - 
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "control_camera_video","latents"), - onload_model_names=("image_encoder", "vae") - ) - - def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride, control_camera_video,latents): - if input_image is None: - return {} - - pipe.load_models_to_device(self.onload_model_names) - image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) - - clip_context = pipe.image_encoder.encode_image([image]) - msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) - msk[:, 1:] = 0 - if end_image is not None: - end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) - vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) - if pipe.dit.has_image_pos_emb: - clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) - msk[:, -1:] = 1 - else: - vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) - - msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) - msk = msk.transpose(1, 2)[0] - - y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] - y = y.to(dtype=pipe.torch_dtype, device=pipe.device) - y = y.unsqueeze(0) - clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) - y = y.to(dtype=pipe.torch_dtype, device=pipe.device) - return {"clip_feature": clip_context, "y": y} - - - -class WanVideoUnit_FunControl(PipelineUnit): - def __init__(self): - super().__init__( - input_params=("control_video", "num_frames", "height", 
"width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), - onload_model_names=("vae") - ) - - def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): - if control_video is None: - return {} - pipe.load_models_to_device(self.onload_model_names) - control_video = pipe.preprocess_video(control_video) - control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) - if clip_feature is None or y is None: - clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) - y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) - else: - y = y[:, -16:] - y = torch.concat([control_latents, y], dim=1) - return {"clip_feature": clip_feature, "y": y} - - - -class WanVideoUnit_FunReference(PipelineUnit): - def __init__(self): - super().__init__( - input_params=("reference_image", "height", "width", "reference_image"), - onload_model_names=("vae") - ) - - def process(self, pipe: WanVideoPipeline, reference_image, height, width): - if reference_image is None: - return {} - pipe.load_models_to_device(["vae"]) - reference_image = reference_image.resize((width, height)) - reference_latents = pipe.preprocess_video([reference_image]) - reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) - clip_feature = pipe.preprocess_image(reference_image) - clip_feature = pipe.image_encoder.encode_image([clip_feature]) - return {"reference_latents": reference_latents, "clip_feature": clip_feature} - -class WanVideoUnit_FunCamera(PipelineUnit): - def __init__(self): - super().__init__( - input_params=("control_camera_video", "cfg_merge", "num_frames", "height", "width", "input_image", "latents"), - 
onload_model_names=("vae") - ) - - def process(self, pipe: WanVideoPipeline, control_camera_video, cfg_merge, num_frames, height, width, input_image, latents): - if control_camera_video is None: - return {} - control_camera_video = control_camera_video[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) - control_camera_latents = torch.concat( - [ - torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), - control_camera_video[:, :, 1:] - ], dim=2 - ).transpose(1, 2) - b, f, c, h, w = control_camera_latents.shape - control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) - control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) - control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) - - input_image = input_image.resize((width, height)) - input_latents = pipe.preprocess_video([input_image]) - input_latents = pipe.vae.encode(input_latents, device=pipe.device) - y = torch.zeros_like(latents).to(pipe.device) - if latents.size()[2] != 1: - y[:, :, :1] = input_latents - y = y.to(dtype=pipe.torch_dtype, device=pipe.device) - - return {"control_camera_latents": control_camera_latents, "control_camera_latents_input": control_camera_latents_input, "y":y} - - -class WanVideoUnit_SpeedControl(PipelineUnit): - def __init__(self): - super().__init__(input_params=("motion_bucket_id",)) - - def process(self, pipe: WanVideoPipeline, motion_bucket_id): - if motion_bucket_id is None: - return {} - motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) - return {"motion_bucket_id": motion_bucket_id} - - - -class WanVideoUnit_VACE(PipelineUnit): - def __init__(self): - super().__init__( - input_params=("vace_video", "vace_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae",) - ) - - def 
process( - self, - pipe: WanVideoPipeline, - vace_video, vace_mask, vace_reference_image, vace_scale, - height, width, num_frames, - tiled, tile_size, tile_stride - ): - if vace_video is not None or vace_mask is not None or vace_reference_image is not None: - pipe.load_models_to_device(["vae"]) - if vace_video is None: - vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) - else: - vace_video = pipe.preprocess_video(vace_video) - - if vace_mask is None: - vace_mask = torch.ones_like(vace_video) - else: - vace_mask = pipe.preprocess_video(vace_mask) - - inactive = vace_video * (1 - vace_mask) + 0 * vace_mask - reactive = vace_video * vace_mask + 0 * (1 - vace_mask) - inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - vace_video_latents = torch.concat((inactive, reactive), dim=1) - - vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) - vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') - - if vace_reference_image is None: - pass - else: - vace_reference_image = pipe.preprocess_video([vace_reference_image]) - vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) - vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) - vace_mask_latents = 
torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) - - vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) - return {"vace_context": vace_context, "vace_scale": vace_scale} - else: - return {"vace_context": None, "vace_scale": vace_scale} - - - -class WanVideoUnit_TeaCache(PipelineUnit): - def __init__(self): - super().__init__( - seperate_cfg=True, - input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, - input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, - ) - - def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): - if tea_cache_l1_thresh is None: - return {} - return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} - - - -class WanVideoUnit_CfgMerger(PipelineUnit): - def __init__(self): - super().__init__(take_over=True) - self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] - - def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): - if not inputs_shared["cfg_merge"]: - return inputs_shared, inputs_posi, inputs_nega - for name in self.concat_tensor_names: - tensor_posi = inputs_posi.get(name) - tensor_nega = inputs_nega.get(name) - tensor_shared = inputs_shared.get(name) - if tensor_posi is not None and tensor_nega is not None: - inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) - elif tensor_shared is not None: - inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) - inputs_posi.clear() - inputs_nega.clear() - return inputs_shared, inputs_posi, inputs_nega - - - -class TeaCache: - def __init__(self, num_inference_steps, rel_l1_thresh, model_id): - self.num_inference_steps = 
num_inference_steps - self.step = 0 - self.accumulated_rel_l1_distance = 0 - self.previous_modulated_input = None - self.rel_l1_thresh = rel_l1_thresh - self.previous_residual = None - self.previous_hidden_states = None - - self.coefficients_dict = { - "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], - "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], - "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], - "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], - } - if model_id not in self.coefficients_dict: - supported_model_ids = ", ".join([i for i in self.coefficients_dict]) - raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") - self.coefficients = self.coefficients_dict[model_id] - - def check(self, dit: WanModel, x, t_mod): - modulated_inp = t_mod.clone() - if self.step == 0 or self.step == self.num_inference_steps - 1: - should_calc = True - self.accumulated_rel_l1_distance = 0 - else: - coefficients = self.coefficients - rescale_func = np.poly1d(coefficients) - self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) - if self.accumulated_rel_l1_distance < self.rel_l1_thresh: - should_calc = False - else: - should_calc = True - self.accumulated_rel_l1_distance = 0 - self.previous_modulated_input = modulated_inp - self.step += 1 - if self.step == self.num_inference_steps: - self.step = 0 - if should_calc: - self.previous_hidden_states = x.clone() - return not should_calc - - def store(self, hidden_states): - self.previous_residual = hidden_states - self.previous_hidden_states - self.previous_hidden_states = None - - def update(self, hidden_states): - 
hidden_states = hidden_states + self.previous_residual - return hidden_states - - - -class TemporalTiler_BCTHW: - def __init__(self): - pass - - def build_1d_mask(self, length, left_bound, right_bound, border_width): - x = torch.ones((length,)) - if not left_bound: - x[:border_width] = (torch.arange(border_width) + 1) / border_width - if not right_bound: - x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) - return x - - def build_mask(self, data, is_bound, border_width): - _, _, T, _, _ = data.shape - t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) - mask = repeat(t, "T -> 1 1 T 1 1") - return mask - - def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): - tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] - tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} - B, C, T, H, W = tensor_dict[tensor_names[0]].shape - if batch_size is not None: - B *= batch_size - data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype - value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) - weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) - for t in range(0, T, sliding_window_stride): - if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: - continue - t_ = min(t + sliding_window_size, T) - model_kwargs.update({ - tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ - for tensor_name in tensor_names - }) - model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) - mask = self.build_mask( - model_output, - is_bound=(t == 0, t_ == T), - border_width=(sliding_window_size - sliding_window_stride,) - ).to(device=data_device, 
dtype=data_dtype) - value[:, :, t: t_, :, :] += model_output * mask - weight[:, :, t: t_, :, :] += mask - value /= weight - model_kwargs.update(tensor_dict) - return value - - - -def model_fn_wan_video( - dit: WanModel, - motion_controller: WanMotionControllerModel = None, - vace: VaceWanModel = None, - latents: torch.Tensor = None, - timestep: torch.Tensor = None, - context: torch.Tensor = None, - clip_feature: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, - reference_latents = None, - vace_context = None, - vace_scale = 1.0, - tea_cache: TeaCache = None, - use_unified_sequence_parallel: bool = False, - motion_bucket_id: Optional[torch.Tensor] = None, - sliding_window_size: Optional[int] = None, - sliding_window_stride: Optional[int] = None, - cfg_merge: bool = False, - use_gradient_checkpointing: bool = False, - use_gradient_checkpointing_offload: bool = False, - control_camera_latents = None, - control_camera_latents_input = None, - **kwargs, -): - if sliding_window_size is not None and sliding_window_stride is not None: - model_kwargs = dict( - dit=dit, - motion_controller=motion_controller, - vace=vace, - latents=latents, - timestep=timestep, - context=context, - clip_feature=clip_feature, - y=y, - reference_latents=reference_latents, - vace_context=vace_context, - vace_scale=vace_scale, - tea_cache=tea_cache, - use_unified_sequence_parallel=use_unified_sequence_parallel, - motion_bucket_id=motion_bucket_id, - ) - return TemporalTiler_BCTHW().run( - model_fn_wan_video, - sliding_window_size, sliding_window_stride, - latents.device, latents.dtype, - model_kwargs=model_kwargs, - tensor_names=["latents", "y"], - batch_size=2 if cfg_merge else 1 - ) - - if use_unified_sequence_parallel: - import torch.distributed as dist - from xfuser.core.distributed import (get_sequence_parallel_rank, - get_sequence_parallel_world_size, - get_sp_group) - - t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) - t_mod = 
dit.time_projection(t).unflatten(1, (6, dit.dim)) - if motion_bucket_id is not None and motion_controller is not None: - t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) - context = dit.text_embedding(context) - - x = latents - # Merged cfg - if x.shape[0] != context.shape[0]: - x = torch.concat([x] * context.shape[0], dim=0) - if timestep.shape[0] != context.shape[0]: - timestep = torch.concat([timestep] * context.shape[0], dim=0) - - if dit.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) - clip_embdding = dit.img_emb(clip_feature) - context = torch.cat([clip_embdding, context], dim=1) - - # Add camera control - x, (f, h, w) = dit.patchify(x, control_camera_latents_input) - - # Reference image - if reference_latents is not None: - if len(reference_latents.shape) == 5: - reference_latents = reference_latents[:, :, 0] - reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) - x = torch.concat([reference_latents, x], dim=1) - f += 1 - - freqs = torch.cat([ - dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), - dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), - dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) - - # TeaCache - if tea_cache is not None: - tea_cache_update = tea_cache.check(dit, x, t_mod) - else: - tea_cache_update = False - - if vace_context is not None: - vace_hints = vace(x, vace_context, context, t_mod, freqs) - - # blocks - if use_unified_sequence_parallel: - if dist.is_initialized() and dist.get_world_size() > 1: - x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] - if tea_cache_update: - x = tea_cache.update(x) - else: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - - for block_id, block in enumerate(dit.blocks): - if use_gradient_checkpointing_offload: - with 
torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, freqs) - if vace_context is not None and block_id in vace.vace_layers_mapping: - x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale - if tea_cache is not None: - tea_cache.store(x) - - if reference_latents is not None: - x = x[:, reference_latents.shape[1]:] - f -= 1 - - x = dit.head(x, t) - if use_unified_sequence_parallel: - if dist.is_initialized() and dist.get_world_size() > 1: - x = get_sp_group().all_gather(x, dim=1) - x = dit.unpatchify(x, (f, h, w)) - return x diff --git a/diffsynth/processors/__pycache__/__init__.cpython-310.pyc b/diffsynth/processors/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index cb65091..0000000 Binary files a/diffsynth/processors/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/processors/__pycache__/base.cpython-310.pyc b/diffsynth/processors/__pycache__/base.cpython-310.pyc deleted file mode 100644 index dc65819..0000000 Binary files a/diffsynth/processors/__pycache__/base.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/processors/__pycache__/sequencial_processor.cpython-310.pyc b/diffsynth/processors/__pycache__/sequencial_processor.cpython-310.pyc deleted file mode 100644 index 321942d..0000000 Binary files a/diffsynth/processors/__pycache__/sequencial_processor.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/__init__.cpython-310.pyc b/diffsynth/prompters/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 4154d84..0000000 Binary files a/diffsynth/prompters/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git 
a/diffsynth/prompters/__pycache__/base_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/base_prompter.cpython-310.pyc deleted file mode 100644 index 902acfc..0000000 Binary files a/diffsynth/prompters/__pycache__/base_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/cog_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/cog_prompter.cpython-310.pyc deleted file mode 100644 index d5f7376..0000000 Binary files a/diffsynth/prompters/__pycache__/cog_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/flux_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/flux_prompter.cpython-310.pyc deleted file mode 100644 index 68e9b53..0000000 Binary files a/diffsynth/prompters/__pycache__/flux_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-310.pyc deleted file mode 100644 index 2a72d65..0000000 Binary files a/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-310.pyc deleted file mode 100644 index 536be36..0000000 Binary files a/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/kolors_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/kolors_prompter.cpython-310.pyc deleted file mode 100644 index c2efea5..0000000 Binary files a/diffsynth/prompters/__pycache__/kolors_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-310.pyc deleted file mode 100644 index 32e8e19..0000000 Binary files 
a/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/omost.cpython-310.pyc b/diffsynth/prompters/__pycache__/omost.cpython-310.pyc deleted file mode 100644 index 51e8053..0000000 Binary files a/diffsynth/prompters/__pycache__/omost.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/prompt_refiners.cpython-310.pyc b/diffsynth/prompters/__pycache__/prompt_refiners.cpython-310.pyc deleted file mode 100644 index bec1243..0000000 Binary files a/diffsynth/prompters/__pycache__/prompt_refiners.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/sd3_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/sd3_prompter.cpython-310.pyc deleted file mode 100644 index 27bf0f1..0000000 Binary files a/diffsynth/prompters/__pycache__/sd3_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/sd_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/sd_prompter.cpython-310.pyc deleted file mode 100644 index 6b0f45f..0000000 Binary files a/diffsynth/prompters/__pycache__/sd_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-310.pyc deleted file mode 100644 index d2f0752..0000000 Binary files a/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-310.pyc deleted file mode 100644 index 035fc47..0000000 Binary files a/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/prompters/__pycache__/wan_prompter.cpython-310.pyc b/diffsynth/prompters/__pycache__/wan_prompter.cpython-310.pyc deleted file mode 100644 index d10035c..0000000 Binary files 
a/diffsynth/prompters/__pycache__/wan_prompter.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/schedulers/__pycache__/__init__.cpython-310.pyc b/diffsynth/schedulers/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 0c9872a..0000000 Binary files a/diffsynth/schedulers/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/schedulers/__pycache__/continuous_ode.cpython-310.pyc b/diffsynth/schedulers/__pycache__/continuous_ode.cpython-310.pyc deleted file mode 100644 index e21f510..0000000 Binary files a/diffsynth/schedulers/__pycache__/continuous_ode.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/schedulers/__pycache__/ddim.cpython-310.pyc b/diffsynth/schedulers/__pycache__/ddim.cpython-310.pyc deleted file mode 100644 index ea9a353..0000000 Binary files a/diffsynth/schedulers/__pycache__/ddim.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/schedulers/__pycache__/flow_match.cpython-310.pyc b/diffsynth/schedulers/__pycache__/flow_match.cpython-310.pyc deleted file mode 100644 index 884c581..0000000 Binary files a/diffsynth/schedulers/__pycache__/flow_match.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index 9754b98..d6d0219 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -35,9 +35,6 @@ class FlowMatchScheduler(): y_shifted = y - y.min() bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) self.linear_timesteps_weights = bsmntw_weighing - self.training = True - else: - self.training = False def step(self, model_output, timestep, sample, to_final=False, **kwargs): diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py deleted file mode 100644 index 0a056a6..0000000 --- a/diffsynth/trainers/utils.py +++ /dev/null @@ -1,257 +0,0 @@ -import imageio, os, torch, warnings, torchvision, argparse -from peft import LoraConfig, 
inject_adapter_in_model -from PIL import Image -import pandas as pd -from tqdm import tqdm -from accelerate import Accelerator - - - -class VideoDataset(torch.utils.data.Dataset): - def __init__( - self, - base_path=None, metadata_path=None, - frame_interval=1, num_frames=81, - dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None, - height_division_factor=16, width_division_factor=16, - data_file_keys=("video",), - image_file_extension=("jpg", "jpeg", "png", "webp"), - video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), - repeat=1, - args=None, - ): - if args is not None: - base_path = args.dataset_base_path - metadata_path = args.dataset_metadata_path - height = args.height - width = args.width - num_frames = args.num_frames - data_file_keys = args.data_file_keys.split(",") - repeat = args.dataset_repeat - - metadata = pd.read_csv(metadata_path) - self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] - - self.base_path = base_path - self.frame_interval = frame_interval - self.num_frames = num_frames - self.dynamic_resolution = dynamic_resolution - self.max_pixels = max_pixels - self.height = height - self.width = width - self.height_division_factor = height_division_factor - self.width_division_factor = width_division_factor - self.data_file_keys = data_file_keys - self.image_file_extension = image_file_extension - self.video_file_extension = video_file_extension - self.repeat = repeat - - if height is not None and width is not None and dynamic_resolution == True: - print("Height and width are fixed. 
Setting `dynamic_resolution` to False.") - self.dynamic_resolution = False - - - def crop_and_resize(self, image, target_height, target_width): - width, height = image.size - scale = max(target_width / width, target_height / height) - image = torchvision.transforms.functional.resize( - image, - (round(height*scale), round(width*scale)), - interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) - image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) - return image - - - def get_height_width(self, image): - if self.dynamic_resolution: - width, height = image.size - if width * height > self.max_pixels: - scale = (width * height / self.max_pixels) ** 0.5 - height, width = int(height / scale), int(width / scale) - height = height // self.height_division_factor * self.height_division_factor - width = width // self.width_division_factor * self.width_division_factor - else: - height, width = self.height, self.width - return height, width - - - def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames): - reader = imageio.get_reader(file_path) - if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: - reader.close() - return None - frames = [] - for frame_id in range(num_frames): - frame = reader.get_data(start_frame_id + frame_id * interval) - frame = Image.fromarray(frame) - frame = self.crop_and_resize(frame, *self.get_height_width(frame)) - frames.append(frame) - reader.close() - return frames - - - def load_image(self, file_path): - image = Image.open(file_path).convert("RGB") - image = self.crop_and_resize(image, *self.get_height_width(image)) - return image - - - def load_video(self, file_path): - frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames) - return frames - - - def is_image(self, file_path): - file_ext_name = file_path.split(".")[-1] - return file_ext_name.lower() in self.image_file_extension - - - def is_video(self, 
file_path): - file_ext_name = file_path.split(".")[-1] - return file_ext_name.lower() in self.video_file_extension - - - def load_data(self, file_path): - if self.is_image(file_path): - return self.load_image(file_path) - elif self.is_video(file_path): - return self.load_video(file_path) - else: - return None - - - def __getitem__(self, data_id): - data = self.data[data_id % len(self.data)].copy() - for key in self.data_file_keys: - if key in data: - path = os.path.join(self.base_path, data[key]) - data[key] = self.load_data(path) - if data[key] is None: - warnings.warn(f"cannot load file {data[key]}.") - return None - return data - - - def __len__(self): - return len(self.data) * self.repeat - - - -class DiffusionTrainingModule(torch.nn.Module): - def __init__(self): - super().__init__() - - - def to(self, *args, **kwargs): - for name, model in self.named_children(): - model.to(*args, **kwargs) - return self - - - def trainable_modules(self): - trainable_modules = filter(lambda p: p.requires_grad, self.parameters()) - return trainable_modules - - - def trainable_param_names(self): - trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters())) - trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) - return trainable_param_names - - - def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None): - if lora_alpha is None: - lora_alpha = lora_rank - lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) - model = inject_adapter_in_model(lora_config, model) - return model - - - def export_trainable_state_dict(self, state_dict, remove_prefix=None): - trainable_param_names = self.trainable_param_names() - state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names} - if remove_prefix is not None: - state_dict_ = {} - for name, param in state_dict.items(): - if name.startswith(remove_prefix): - name 
= name[len(remove_prefix):] - state_dict_[name] = param - state_dict = state_dict_ - return state_dict - - - -def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=1e-4, num_epochs=1, output_path="./models", remove_prefix_in_ckpt=None, args=None): - if args is not None: - learning_rate = args.learning_rate - num_epochs = args.num_epochs - output_path = args.output_path - remove_prefix_in_ckpt = args.remove_prefix_in_ckpt - dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0]) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate) - scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) - - accelerator = Accelerator(gradient_accumulation_steps=1) - model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) - - for epoch in range(num_epochs): - for data in tqdm(dataloader): - with accelerator.accumulate(model): - optimizer.zero_grad() - loss = model(data) - accelerator.backward(loss) - optimizer.step() - scheduler.step() - accelerator.wait_for_everyone() - if accelerator.is_main_process: - state_dict = accelerator.get_state_dict(model) - state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt) - os.makedirs(output_path, exist_ok=True) - path = os.path.join(output_path, f"epoch-{epoch}.safetensors") - accelerator.save(state_dict, path, safe_serialization=True) - - - -def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"): - dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0]) - accelerator = Accelerator() - model, dataloader = accelerator.prepare(model, dataloader) - os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True) - for data_id, data in enumerate(tqdm(dataloader)): - with torch.no_grad(): - inputs = model.forward_preprocess(data) - inputs = {key: inputs[key] for key in 
model.model_input_keys if key in inputs} - torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth")) - - - -def wan_parser(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the Dataset.") - parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the Dataset.") - parser.add_argument("--height", type=int, default=None, help="Image or video height. Leave `height` and `width` None to enable dynamic resolution.") - parser.add_argument("--width", type=int, default=None, help="Image or video width. Leave `height` and `width` None to enable dynamic resolution.") - parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in each video. The frames are sampled from the prefix.") - parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.") - parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.") - parser.add_argument("--model_paths", type=str, default=None, help="Model paths to be loaded. JSON format.") - parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. 
Separated by commas.") - parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") - parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") - parser.add_argument("--output_path", type=str, default="./models", help="Save path.") - parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.") - parser.add_argument("--trainable_models", type=str, default=None, help="Trainable models, e.g., dit, vae, text_encoder.") - parser.add_argument("--lora_base_model", type=str, default=None, help="Add LoRA on which model.") - parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Add LoRA on which layer.") - parser.add_argument("--lora_rank", type=int, default=32, help="LoRA rank.") - parser.add_argument("--input_contains_input_image", default=False, action="store_true", help="Model input contains 'input_image'.") - parser.add_argument("--input_contains_end_image", default=False, action="store_true", help="Model input contains 'end_image'.") - parser.add_argument("--input_contains_control_video", default=False, action="store_true", help="Model input contains 'control_video'.") - parser.add_argument("--input_contains_reference_image", default=False, action="store_true", help="Model input contains 'reference_image'.") - parser.add_argument("--input_contains_vace_video", default=False, action="store_true", help="Model input contains 'vace_video'.") - parser.add_argument("--input_contains_vace_reference_image", default=False, action="store_true", help="Model input contains 'vace_reference_image'.") - parser.add_argument("--input_contains_motion_bucket_id", default=False, action="store_true", help="Model input contains 'motion_bucket_id'.") - parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Offload gradient checkpointing to RAM.") - return parser - diff --git 
a/diffsynth/vram_management/__pycache__/__init__.cpython-310.pyc b/diffsynth/vram_management/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 74c52a7..0000000 Binary files a/diffsynth/vram_management/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/vram_management/__pycache__/layers.cpython-310.pyc b/diffsynth/vram_management/__pycache__/layers.cpython-310.pyc deleted file mode 100644 index 62ff84a..0000000 Binary files a/diffsynth/vram_management/__pycache__/layers.cpython-310.pyc and /dev/null differ diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index dd4a245..a9df39e 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -8,32 +8,8 @@ def cast_to(weight, dtype, device): return r -class AutoTorchModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def check_free_vram(self): - used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3) - return used_memory < self.vram_limit - - def offload(self): - if self.state != 0: - self.to(dtype=self.offload_dtype, device=self.offload_device) - self.state = 0 - - def onload(self): - if self.state != 1: - self.to(dtype=self.onload_dtype, device=self.onload_device) - self.state = 1 - - def keep(self): - if self.state != 2: - self.to(dtype=self.computation_dtype, device=self.computation_device) - self.state = 2 - - -class AutoWrappedModule(AutoTorchModule): - def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs): +class AutoWrappedModule(torch.nn.Module): + def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): super().__init__() self.module = module.to(dtype=offload_dtype, device=offload_device) self.offload_dtype = offload_dtype @@ -42,57 +18,28 @@ class 
AutoWrappedModule(AutoTorchModule): self.onload_device = onload_device self.computation_dtype = computation_dtype self.computation_device = computation_device - self.vram_limit = vram_limit self.state = 0 + def offload(self): + if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.module.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.module.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + def forward(self, *args, **kwargs): - if self.state == 2: + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: module = self.module else: - if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: - module = self.module - elif self.vram_limit is not None and self.check_free_vram(): - self.keep() - module = self.module - else: - module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) + module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) return module(*args, **kwargs) -class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule): - def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs): - with init_weights_on_device(device=torch.device("meta")): - super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) - self.weight = module.weight - self.bias = module.bias - self.offload_dtype = offload_dtype - self.offload_device = offload_device - self.onload_dtype = onload_dtype - self.onload_device = onload_device - 
self.computation_dtype = computation_dtype - self.computation_device = computation_device - self.vram_limit = vram_limit - self.state = 0 - - def forward(self, x, *args, **kwargs): - if self.state == 2: - weight, bias = self.weight, self.bias - else: - if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: - weight, bias = self.weight, self.bias - elif self.vram_limit is not None and self.check_free_vram(): - self.keep() - weight, bias = self.weight, self.bias - else: - weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device) - bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) - with torch.amp.autocast(device_type=x.device.type): - x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x) - return x - - -class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): - def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs): +class AutoWrappedLinear(torch.nn.Linear): + def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): with init_weights_on_device(device=torch.device("meta")): super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) self.weight = module.weight @@ -103,28 +50,29 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): self.onload_device = onload_device self.computation_dtype = computation_dtype self.computation_device = computation_device - self.vram_limit = vram_limit self.state = 0 - self.name = name + + def offload(self): + if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + 
self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 def forward(self, x, *args, **kwargs): - if self.state == 2: + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: weight, bias = self.weight, self.bias else: - if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: - weight, bias = self.weight, self.bias - elif self.vram_limit is not None and self.check_free_vram(): - self.keep() - weight, bias = self.weight, self.bias - else: - weight = cast_to(self.weight, self.computation_dtype, self.computation_device) - bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) + weight = cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) return torch.nn.functional.linear(x, weight, bias) -def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""): +def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): for name, module in model.named_children(): - layer_name = name if name_prefix == "" else name_prefix + "." 
+ name for source_module, target_module in module_map.items(): if isinstance(module, source_module): num_param = sum(p.numel() for p in module.parameters()) @@ -132,16 +80,16 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config_ = overflow_module_config else: module_config_ = module_config - module_ = target_module(module, **module_config_, vram_limit=vram_limit, name=layer_name) + module_ = target_module(module, **module_config_) setattr(model, name, module_) total_num_param += num_param break else: - total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=layer_name) + total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param) return total_num_param -def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None): - enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit) +def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None): + enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0) model.vram_management_enabled = True diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 46c9670..92c3c59 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -1,32 +1,276 @@ +# Wan-Video + +Wan-Video is a collection of video synthesis models open-sourced by Alibaba. + +Before using this model, please install DiffSynth-Studio from **source code**. 
+ +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +## Model Zoo + +|Developer|Name|Link|Scripts| +|-|-|-|-| +|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)| +|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)| +|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)| +|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)| +|Wan Team|14B first-last-frame-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|[wan_14B_flf2v.py](./wan_14B_flf2v.py)| +|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).| +|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).| +|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).| +|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)| +|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)| +|PAI Team|14B 
InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)| +|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)| +|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)| +|IIC Team|1.3B VACE|[Link](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|[wan_1.3b_vace.py](./wan_1.3b_vace.py)| + +Base model features + +||Text-to-video|Image-to-video|End frame|Control|Reference image| +|-|-|-|-|-|-| +|1.3B text-to-video|✅||||| +|14B text-to-video|✅||||| +|14B image-to-video 480P||✅|||| +|14B image-to-video 720P||✅|||| +|14B first-last-frame-to-video 720P||✅|✅||| +|1.3B InP||✅|✅||| +|14B InP||✅|✅||| +|1.3B Control||||✅|| +|14B Control||||✅|| +|1.3B VACE||||✅|✅| + +Adapter model compatibility + +||1.3B text-to-video|1.3B InP|1.3B VACE| +|-|-|-|-| +|1.3B aesthetics LoRA|✅||✅| +|1.3B Highres-fix LoRA|✅||✅| +|1.3B ExVideo LoRA|✅||✅| +|1.3B Speed Control adapter|✅|✅|✅| + +## VRAM Usage + +* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py). + +* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!). + +We present a detailed table here. The model (14B text-to-video) is tested on a single A100. 
+ +|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting| +|-|-|-|-|-| +|torch.bfloat16|None (unlimited)|18.5s/it|48G|| +|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G|| +|torch.bfloat16|0|23.4s/it|10G|| +|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes| +|torch.float8_e4m3fn|0|24.0s/it|10G|| + +**We found that the 14B image-to-video model is more sensitive to precision, so if the generated video shows issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.** + +## Efficient Attention Implementation + +DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend using the default torch SDPA. + +* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention) +* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention) +* [Sage Attention](https://github.com/thu-ml/SageAttention) +* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.) + +## Acceleration + +We support multiple acceleration solutions: +* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py). + +* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py). + +```bash +pip install "xfuser>=0.4.3" +torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py +``` + +* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py). + +## Gallery + +1.3B text-to-video. + +https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8 + +Put sunglasses on the dog. 
+ +https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb + +14B text-to-video. + +https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f + +14B image-to-video. + +https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75 + +14B first-last-frame-to-video + +|First frame|Last frame|Video| +|-|-|-| +|![Image](https://github.com/user-attachments/assets/b0d8225b-aee0-4129-b8e5-58c8523221a6)|![Image](https://github.com/user-attachments/assets/2f0c9bc5-07e2-45fa-8320-53d63a4fd203)|https://github.com/user-attachments/assets/2a6a2681-622c-4512-b852-5f22e73830b1| + +## Train + +We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA: + +https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9 + +Step 1: Install additional packages + +``` +pip install peft lightning pandas +``` + +Step 2: Prepare your dataset + +You need to manage the training videos as follows: + +``` +data/example_dataset/ +├── metadata.csv +└── train + ├── video_00001.mp4 + └── image_00002.jpg +``` + +`metadata.csv`: + +``` +file_name,text +video_00001.mp4,"video description" +image_00002.jpg,"video description" +``` + +We support both images and videos. An image is treated as a single frame of video. + +Step 3: Data process + +```shell +CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ + --task data_process \ + --dataset_path data/example_dataset \ + --output_path ./models \ + --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" \ + --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" \ + --tiled \ + --num_frames 81 \ + --height 480 \ + --width 832 +``` + +After that, some cached files will be stored in the dataset folder. 
+ +``` +data/example_dataset/ +├── metadata.csv +└── train + ├── video_00001.mp4 + ├── video_00001.mp4.tensors.pth + ├── image_00002.jpg + └── image_00002.jpg.tensors.pth +``` + +Step 4: Train + +LoRA training: + +```shell +CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ + --task train \ + --train_architecture lora \ + --dataset_path data/example_dataset \ + --output_path ./models \ + --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \ + --steps_per_epoch 500 \ + --max_epochs 10 \ + --learning_rate 1e-4 \ + --lora_rank 16 \ + --lora_alpha 16 \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --accumulate_grad_batches 1 \ + --use_gradient_checkpointing +``` + +Full training: + +```shell +CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ + --task train \ + --train_architecture full \ + --dataset_path data/example_dataset \ + --output_path ./models \ + --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \ + --steps_per_epoch 500 \ + --max_epochs 10 \ + --learning_rate 1e-4 \ + --accumulate_grad_batches 1 \ + --use_gradient_checkpointing +``` + +If you wish to train the 14B model, please provide all of the safetensors files, separated by commas. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`. + +If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`. 
+ +For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`. + +Step 5: Test + +Test LoRA: + +```python +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -* dataset - * `--dataset_base_path`: Base path of the Dataset. - * `--dataset_metadata_path`: Metadata path of the Dataset. - * `--height`: Image or video height. Leave `height` and `width` None to enable dynamic resolution. - * `--width`: Image or video width. Leave `height` and `width` None to enable dynamic resolution. - * `--num_frames`: Number of frames in each video. The frames are sampled from the prefix. - * `--data_file_keys`: Data file keys in metadata. Separated by commas. - * `--dataset_repeat`: Number of times the dataset is repeated in each epoch. -* Model - * `--model_paths`: Model paths to be loaded. JSON format. - * `--model_id_with_origin_paths`: Model ID with original path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separated by commas. -* Training - * `--learning_rate`: Learning rate. - * `--num_epochs`: Number of epochs. - * `--output_path`: Save path. - * `--remove_prefix_in_ckpt`: Remove prefix in ckpt. -* Trainable module - * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder. - * `--lora_base_model`: Add LoRA on which model. - * `--lora_target_modules`: Add LoRA on which layer. - * `--lora_rank`: LoRA rank. -* Extra model input - * `--input_contains_input_image`: Model input contains `input_image` - * `--input_contains_end_image`: Model input contains `end_image`. - * `--input_contains_control_video`: Model input contains `control_video`. - * `--input_contains_reference_image`: Model input contains `reference_image`. 
- * `--input_contains_vace_video`: Model input contains `vace_video`. - * `--input_contains_vace_reference_image`: Model input contains `vace_reference_image`. - * `--input_contains_motion_bucket_id`: Model input contains `motion_bucket_id`. +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") +model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", +]) +model_manager.load_lora("models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0) +pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) +video = pipe( + prompt="...", + negative_prompt="...", + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video.mp4", fps=30, quality=5) +``` + +Test fine-tuned base model: + +```python +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData + + +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") +model_manager.load_models([ + "models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", +]) +pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) + +video = pipe( + prompt="...", + negative_prompt="...", + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video.mp4", fps=30, quality=5) +``` diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md deleted file mode 100644 index ee53785..0000000 --- a/examples/wanvideo/README_zh.md +++ /dev/null @@ -1,357 +0,0 @@ -# 通义万相 2.1(Wan 2.1) - -|模型 ID|类型|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| -|-|-|-|-|-|-|-|-| 
-|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|基础模型||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)| -|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|基础模型||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)| -|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|基础模型|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| -|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|基础模型|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-720P.py)|[code](./model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| -|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|基础模型|`input_image`, `end_image`|[code](./model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| -|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|基础模型|`input_image`, 
`end_image`|[code](./model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| -|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|基础模型|`control_video`|[code](./model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| -|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|基础模型|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| -|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|基础模型|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| -|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|基础模型|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| 
-|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|基础模型|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| -|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|基础模型|`input_image`, `end_image`|||||| -|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|基础模型|`input_image`, `end_image`|||||| -|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|基础模型||||||| -|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|基础模型||||||| -|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|适配器|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/VACE-Wan2.1-1.3B-Preview.sh)|[code](./model_training/validate_full/VACE-Wan2.1-1.3B-Preview.py)|[code](./model_training/lora/VACE-Wan2.1-1.3B-Preview.sh)|[code](./model_training/validate_lora/VACE-Wan2.1-1.3B-Preview.py)| -|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|适配器|`vace_control_video`, `vace_reference_image`|||||| -|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|适配器|`vace_control_video`, `vace_reference_image`|||||| 
-|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|适配器|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| - -## 模型推理 - -以下部分将会帮助您理解我们的功能并编写推理代码。 - - -
- -加载模型 - -模型通过 `from_pretrained` 加载: - -```python -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), - ], -) -``` - -其中 `torch_dtype` 和 `device` 是计算精度和计算设备。`model_configs` 可通过多种方式配置模型路径: - -* 从[魔搭社区](https://modelscope.cn/)下载模型并加载。此时需要填写 `model_id` 和 `origin_file_pattern`,例如 - -```python -ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors") -``` - -* 从本地文件路径加载模型。此时需要填写 `path`,例如 - -```python -ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") -``` - -对于从多个文件加载的单一模型,使用列表即可,例如 - -```python -ModelConfig(path=[ - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", -]) -``` - -`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为: - -* `tokenizer_config`: Wan 模型的 tokenizer 路径,默认值为 `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`。 -* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。 -* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。 -* `redirect_common_files`: 是否重定向重复模型文件,默认值为 `True`。由于 Wan 系列模型包括多个基础模型,每个基础模型的 text encoder 等模块都是相同的,为避免重复下载,我们会对模型路径进行重定向。 - -
- - -
- -显存管理 - -DiffSynth-Studio 为 Wan 模型提供了细粒度的显存管理,让模型能够在低显存设备上进行推理,可通过以下代码开启 offload 功能,在显存有限的设备上将部分模块 offload 到内存中。 - -```python -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() -``` - -FP8 量化功能也是支持的: - -```python -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_dtype=torch.float8_e4m3fn), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_dtype=torch.float8_e4m3fn), - ], -) -pipe.enable_vram_management() -``` - -FP8 量化和 offload 可同时开启: - -```python -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), - ], -) -pipe.enable_vram_management() -``` - -FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在 FP8 量化下会出现精度不足导致的画面模糊、撕裂、失真问题,请谨慎使用 FP8 量化。 - 
-`enable_vram_management` 函数提供了以下参数,用于控制显存使用情况: - -* `vram_limit`: 显存占用量(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。 -* `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。 -* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。 - -
- - -
- -输入参数 - -Pipeline 在推理阶段能够接收以下输入参数: - -* `prompt`: 提示词,描述画面中出现的内容。 -* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 -* `input_image`: 输入图片,适用于图生视频模型,例如 [`Wan-AI/Wan2.1-I2V-14B-480P`](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)、[`PAI/Wan2.1-Fun-1.3B-InP`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP),以及首尾帧模型,例如 [`Wan-AI/Wan2.1-FLF2V-14B-720P`](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)。 -* `end_image`: 结尾帧,适用于首尾帧模型,例如 [`Wan-AI/Wan2.1-FLF2V-14B-720P`](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)。 -* `input_video`: 输入视频,用于视频生视频,适用于任意 Wan 系列模型,需与参数 `denoising_strength` 配合使用。 -* `denoising_strength`: 去噪强度,范围为 [0, 1]。数值越小,生成的视频越接近 `input_video`。 -* `control_video`: 控制视频,适用于带控制能力的 Wan 系列模型,例如 [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)。 -* `reference_image`: 参考图片,适用于带参考图能力的 Wan 系列模型,例如 [`PAI/Wan2.1-Fun-V1.1-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)。 -* `vace_video`: VACE 模型的输入视频,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。 -* `vace_video_mask`: VACE 模型的 mask 视频,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。 -* `vace_reference_image`: VACE 模型的参考图片,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。 -* `vace_scale`: VACE 模型对基础模型的影响程度,默认为1。数值越大,控制强度越高,但画面崩坏概率越大。 -* `seed`: 随机种子。默认为 `None`,即完全随机。 -* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 -* `height`: 帧高度,默认为 480。需设置为 16 的倍数,不满足时向上取整。 -* `width`: 帧宽度,默认为 832。需设置为 16 的倍数,不满足时向上取整。 -* `num_frames`: 帧数,默认为 81。需设置为 4 的倍数 + 1,不满足时向上取整,最小值为 1。 -* `cfg_scale`: Classifier-free guidance 机制的数值,默认为 5。数值越大,提示词的控制效果越强,但画面崩坏的概率越大。 -* `cfg_merge`: 是否合并 Classifier-free guidance 的两侧进行统一推理,默认为 `False`。该参数目前仅在基础的文生视频和图生视频模型上生效。 -* `num_inference_steps`: 推理次数,默认值为 50。 -* `sigma_shift`: Rectified Flow 理论中的参数,默认为 
5。数值越大,模型在去噪的开始阶段停留的步骤数越多,可适当调大这个参数来提高画面质量,但会因生成过程与训练过程不一致导致生成的视频内容与训练数据存在差异。 -* `motion_bucket_id`: 运动幅度,范围为 [0, 100]。适用于速度控制模块,例如 [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1),数值越大,运动幅度越大。 -* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。 -* `tile_size`: VAE 编解码阶段的分块大小,默认为 (30, 52),仅在 `tiled=True` 时生效。 -* `tile_stride`: VAE 编解码阶段的分块步长,默认为 (15, 26),仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 -* `sliding_window_size`: DiT 部分的滑动窗口大小。实验性功能,效果不稳定。 -* `sliding_window_stride`: DiT 部分的滑动窗口步长。实验性功能,效果不稳定。 -* `tea_cache_l1_thresh`: TeaCache 的阈值,数值越大,速度越快,画面质量越差。请注意,开启 TeaCache 后推理速度并非均匀,因此进度条上显示的剩余时间将会变得不准确。 -* `tea_cache_model_id`: TeaCache 的参数模板,可选 `"Wan2.1-T2V-1.3B"`、`Wan2.1-T2V-14B`、`Wan2.1-I2V-14B-480P`、`Wan2.1-I2V-14B-720P` 之一。 -* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 - -
- - -## 模型训练 - -Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_training/train.py) 脚本进行。 - -
- -脚本参数 - -脚本包含以下参数: - -* 数据集 - * `--dataset_base_path`: 数据集的根路径。 - * `--dataset_metadata_path`: 数据集的元数据文件路径。 - * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 - * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 - * `--num_frames`: 每个视频中的帧数。帧从视频开头依次采样。 - * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 - * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 -* 模型 - * `--model_paths`: 要加载的模型路径。JSON 格式。 - * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors`。用逗号分隔。 -* 训练 - * `--learning_rate`: 学习率。 - * `--num_epochs`: 训练轮数(Epoch)。 - * `--output_path`: 保存路径。 - * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 -* 可训练模块 - * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 - * `--lora_base_model`: LoRA 添加到哪个模型上。 - * `--lora_target_modules`: LoRA 添加到哪一层上。 - * `--lora_rank`: LoRA 的秩(Rank)。 -* 额外模型输入 - * `--input_contains_input_image`: 模型输入包含 `input_image`。 - * `--input_contains_end_image`: 模型输入包含 `end_image`。 - * `--input_contains_control_video`: 模型输入包含 `control_video`。 - * `--input_contains_reference_image`: 模型输入包含 `reference_image`。 - * `--input_contains_vace_video`: 模型输入包含 `vace_video`。 - * `--input_contains_vace_reference_image`: 模型输入包含 `vace_reference_image`。 - * `--input_contains_motion_bucket_id`: 模型输入包含 `motion_bucket_id`。 -* 显存管理 - * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 - -
- - -
- -Step 1: 准备数据集 - -数据集包含一系列文件,我们建议您这样组织数据集文件: - -``` -data/example_video_dataset/ -├── metadata.csv -├── video1.mp4 -└── video2.mp4 -``` - -其中 `video1.mp4`、`video2.mp4` 为训练用视频数据,`metadata.csv` 为元数据列表,例如: - -``` -video,prompt -video1.mp4,"from sunset to night, a small town, light, house, river" -video2.mp4,"a dog is running" -``` - -我们构建了一个样例视频数据集,以方便您进行测试,通过以下命令可以下载这个数据集: - -```shell -modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset -``` - -数据集支持视频和图片混合训练,支持的视频文件格式包括 `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`,支持的图片格式包括 `"jpg", "jpeg", "png", "webp"`。 - -视频的尺寸可通过脚本参数 `--height`、`--width`、`--num_frames` 控制。在每个视频中,前 `num_frames` 帧会被用于训练,因此当视频长度不足 `num_frames` 帧时会报错,图片文件会被视为单帧视频。当 `--height` 和 `--width` 为空时将会开启动态分辨率,按照数据集中每个视频或图片的实际宽高训练。 - -**我们强烈建议使用固定分辨率训练,并避免图像和视频混合训练,因为在多卡训练中存在负载均衡问题。** - -当模型需要额外输入时,例如具备控制能力的模型 [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) 所需的 `control_video`,请在数据集中补充相应的列,例如: - -``` -video,prompt,control_video -video1.mp4,"from sunset to night, a small town, light, house, river",video1_softedge.mp4 -``` - -额外输入若包含视频和图像文件,则需要在 `--data_file_keys` 参数中指定要解析的列名。该参数的默认值为 `"image,video"`,即解析列名为 `image` 和 `video` 的列。可根据额外输入增加相应的列名,例如 `--data_file_keys "image,video,control_video"`,同时启用 `--input_contains_control_video`。 - -
- - -
- -Step 2: 加载模型 - -类似于推理时的模型加载逻辑,可直接通过模型 ID 配置要加载的模型。例如,推理时我们通过以下设置加载模型 - -```python -model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), -] -``` - -那么在训练时,填入以下参数即可加载对应的模型。 - -```shell ---model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" -``` - -如果您希望从本地文件加载模型,例如推理时 - -```python -model_configs=[ - ModelConfig(path=[ - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", - ]), - ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth"), - ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"), -] -``` - -那么训练时需设置为 - -```shell ---model_paths '[ - [ - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors" - ], - 
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth" -]' \ -``` - -
- - -
- -Step 3: 设置可训练模块 - -训练框架支持训练基础模型,或 LoRA 模型。以下是几个例子: - -* 全量训练 DiT 部分:`--trainable_models dit` -* 训练 DiT 部分的 LoRA 模型:`--lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32` -* 训练 DiT 部分的 LoRA 和 Motion Controller 部分(是的,可以训练这种花里胡哨的结构):`--trainable_models motion_controller --lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32` - -此外,由于训练脚本中加载了多个模块(text encoder、dit、vae),保存模型文件时需要移除前缀,例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时,请设置 `--remove_prefix_in_ckpt pipe.dit.` - -
- - -
- -Step 4: 启动训练程序 - -我们为每一个模型编写了训练命令,请参考本文档开头的表格。 - -请注意,14B 模型全量训练需要 8 个 GPU,每个 GPU 的显存至少为 80 GB。全量训练这些 14B 模型时需要安装 `deepspeed`(`pip install deepspeed`),我们编写了建议的[配置文件](./model_training/full/accelerate_config_14B.yaml),这个配置文件会在对应的训练脚本中被加载,这些脚本已在 8*A100 上测试过。 - -训练脚本的默认视频尺寸为 `480*832*81`,提升分辨率将可能导致显存不足,请添加参数 `--use_gradient_checkpointing_offload` 降低显存占用。 - -
diff --git a/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py deleted file mode 100644 index 3061398..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] -) - -# First and last frame to video -video = pipe( - prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), - end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), - seed=0, tiled=True, - height=960, width=960, num_frames=33, - sigma_shift=16, -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git 
a/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py deleted file mode 100644 index 43374d2..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/control_video.mp4" -) - -# Control video -control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) -video = pipe( - prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=control_video, height=832, width=576, num_frames=49, - seed=1, tiled=True -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py deleted file mode 100644 index 
d921c0c..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) -image = Image.open("data/examples/wan/input_image.jpg") - -# First and last frame to video -video = pipe( - prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=image, - seed=0, tiled=True - # You can input `end_image=xxx` to control the last frame of the video. - # The model will automatically generate the dynamic content between `input_image` and `end_image`. 
-) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py deleted file mode 100644 index db9e5c8..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/control_video.mp4" -) - -# Control video -control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) -video = pipe( - prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=control_video, height=832, width=576, num_frames=49, - seed=1, tiled=True -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py 
b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py deleted file mode 100644 index af227cb..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) -image = Image.open("data/examples/wan/input_image.jpg") - -# First and last frame to video -video = pipe( - prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=image, - seed=0, tiled=True - # You can input `end_image=xxx` to control the last frame of the video. - # The model will automatically generate the dynamic content between `input_image` and `end_image`. 
-) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py deleted file mode 100644 index d0f5fab..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download -from dchen.camera_compute import process_pose_file - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] -) - -# Control video -control_video = None -reference_image = None -control_camera_text = "/mnt/nas2/dchen/Work/add_0609/DiffSynth-Studio/dchen/camera_information.txt" -input_image = Image.open("/mnt/nas2/dchen/Work/add_0609/DiffSynth-Studio/dchen/7.png") -sigma_shift = 3 -height = 480 -width = 832 - -control_camera_video = process_pose_file(control_camera_text, width, height) - -video = pipe( - 
prompt="一个小女孩正在户外玩耍。她穿着一件蓝色的短袖上衣和粉色的短裤,头发扎成一个可爱的辫子。她的脚上没有穿鞋,显得非常自然和随意。她正用一把红色的小铲子在泥土里挖土,似乎在进行某种有趣的活动,可能是种花或是挖掘宝藏。地上有一根长长的水管,可能是用来浇水的。背景是一片草地和一些绿色植物,阳光明媚,整个场景充满了童趣和生机。小女孩专注的表情和认真的动作让人感受到她的快乐和好奇心。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=control_video, reference_image=reference_image, - height=height, width=width, num_frames=81, - seed=1, tiled=True, - - control_camera_video = control_camera_video, - input_image = input_image, - sigma_shift = sigma_shift, -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py deleted file mode 100644 index 0f7e4c8..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - 
allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] -) - -# Control video -control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) -reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) -video = pipe( - prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=control_video, reference_image=reference_image, - height=832, width=576, num_frames=49, - seed=1, tiled=True -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py deleted file mode 100644 index 1e43c59..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download -from dchen.camera_compute import process_pose_file - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", 
offload_device="cpu"), - ], -) -pipe.enable_vram_management() -print("success!") - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] -) - -# Control video -control_video = None -reference_image = None -control_camera_text = "/mnt/nas2/dchen/Work/add_0609/DiffSynth-Studio/dchen/camera_information.txt" -input_image = Image.open("/mnt/nas2/dchen/Work/add_0609/DiffSynth-Studio/dchen/7.png") -sigma_shift = 3 -height = 480 -width = 832 - -control_camera_video = process_pose_file(control_camera_text, width, height) - -video = pipe( - prompt="一个小女孩正在户外玩耍。她穿着一件蓝色的短袖上衣和粉色的短裤,头发扎成一个可爱的辫子。她的脚上没有穿鞋,显得非常自然和随意。她正用一把红色的小铲子在泥土里挖土,似乎在进行某种有趣的活动,可能是种花或是挖掘宝藏。地上有一根长长的水管,可能是用来浇水的。背景是一片草地和一些绿色植物,阳光明媚,整个场景充满了童趣和生机。小女孩专注的表情和认真的动作让人感受到她的快乐和好奇心。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=control_video, reference_image=reference_image, - height=height, width=width, num_frames=81, - seed=1, tiled=True, - - control_camera_video = control_camera_video, - input_image = input_image, - sigma_shift = sigma_shift, -) -save_video(video, "video2.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py deleted file mode 100644 index 78635ff..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - 
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] -) - -# Control video -control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) -reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) -video = pipe( - prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=control_video, reference_image=reference_image, - height=832, width=576, num_frames=49, - seed=1, tiled=True -) -save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py deleted file mode 100644 index 334e981..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - 
torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) -image = Image.open("data/examples/wan/input_image.jpg") - -# First and last frame to video -video = pipe( - prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=image, - seed=0, tiled=True - # You can input `end_image=xxx` to control the last frame of the video. - # The model will automatically generate the dynamic content between `input_image` and `end_image`. 
-) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py deleted file mode 100644 index eb2e5b0..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) -image = Image.open("data/examples/wan/input_image.jpg") - -# Image-to-video -video = pipe( - prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=image, - seed=0, tiled=True -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py 
b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py deleted file mode 100644 index fb14d24..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) -image = Image.open("data/examples/wan/input_image.jpg") - -# Image-to-video -video = pipe( - prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=image, - seed=0, tiled=True, - height=720, width=1280, -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py b/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py deleted file mode 100644 index 40cb02d..0000000 --- 
a/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management() - -# Text-to-video -video = pipe( - prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=True, -) -save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh b/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh deleted file mode 100644 index e70fd13..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh +++ /dev/null @@ -1,13 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_motion_bucket_id.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths 
"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.motion_controller." \ - --output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_full" \ - --trainable_models "motion_controller" \ - --input_contains_motion_bucket_id \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh deleted file mode 100644 index c0591ca..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-FLF2V-14B-720P_full" \ - --trainable_models "dit" \ - --input_contains_input_image \ - --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh deleted file mode 100644 index 499c787..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ - --data_file_keys "video,control_video" \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-1.3B-Control_full" \ - --trainable_models "dit" \ - --input_contains_control_video \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh deleted file mode 100644 index 1fec876..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-1.3B-InP_full" \ - --trainable_models "dit" \ - --input_contains_input_image \ - --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh deleted file mode 100644 index 2d7272d..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ - --data_file_keys "video,control_video" \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-14B-Control_full" \ - --trainable_models "dit" \ - --input_contains_control_video \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh deleted file mode 100644 index 3463670..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-14B-InP_full" \ - --trainable_models "dit" \ - --input_contains_input_image \ - --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh deleted file mode 100644 index 5acda18..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh +++ /dev/null @@ -1,15 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ - --data_file_keys "video,control_video,reference_image" \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_full" \ - --trainable_models "dit" \ - --input_contains_control_video \ - --input_contains_reference_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh deleted file mode 100644 index d3b280f..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-InP_full" \ - --trainable_models "dit" \ - --input_contains_input_image \ - --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh deleted file mode 100644 index 2a63311..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh +++ /dev/null @@ -1,15 +0,0 @@ -accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ - --data_file_keys "video,control_video,reference_image" \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_full" \ - --trainable_models "dit" \ - --input_contains_control_video \ - --input_contains_reference_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh deleted file mode 100644 index 11e7cc3..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-V1.1-14B-InP_full" \ - --trainable_models "dit" \ - --input_contains_input_image \ - --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh deleted file mode 100644 index 5cea09b..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh +++ /dev/null @@ -1,13 +0,0 @@ -accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-I2V-14B-480P_full" \ - --trainable_models "dit" \ - --input_contains_input_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh deleted file mode 100644 index 4b0ed11..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh +++ /dev/null @@ -1,13 +0,0 @@ -accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-I2V-14B-720P_full" \ - --trainable_models "dit" \ - --input_contains_input_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh deleted file mode 100644 index e0d6e84..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh +++ /dev/null @@ -1,12 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/Wan2.1-T2V-1.3B_full" \ - --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh b/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh deleted file mode 100644 index ae804b0..0000000 --- a/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh +++ /dev/null @@ -1,12 +0,0 @@ -accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ - --learning_rate 1e-5 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-T2V-14B_full" \ - --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/accelerate_config_14B.yaml b/examples/wanvideo/model_training/full/accelerate_config_14B.yaml deleted file mode 100644 index 3875a9d..0000000 --- a/examples/wanvideo/model_training/full/accelerate_config_14B.yaml +++ /dev/null @@ -1,22 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - gradient_accumulation_steps: 1 - offload_optimizer_device: cpu - offload_param_device: cpu - zero3_init_flag: false - zero_stage: 2 -distributed_type: DEEPSPEED -downcast_bf16: 'no' -enable_cpu_affinity: false -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 8 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/examples/wanvideo/model_training/full/run_test.py b/examples/wanvideo/model_training/full/run_test.py deleted file mode 100644 index 093becd..0000000 --- a/examples/wanvideo/model_training/full/run_test.py +++ /dev/null @@ -1,38 +0,0 @@ -import multiprocessing, os - - -def run_task(scripts, thread_id, thread_num): - for script_id, script in enumerate(scripts): - if script_id % thread_num == thread_id: - log_file_name = script.replace("/", "_") + ".txt" - cmd = f"CUDA_VISIBLE_DEVICES={thread_id} bash {script} > data/log/{log_file_name} 2>&1" - os.makedirs("data/log", exist_ok=True) - print(cmd, flush=True) - os.system(cmd) - - -if __name__ == "__main__": - # 1.3B - scripts = [] - for file_name in os.listdir("examples/wanvideo/model_training/full"): - if file_name != "run_test.py" and "14B" not in file_name: - scripts.append(os.path.join("examples/wanvideo/model_training/full", file_name)) - - processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] - for p in processes: - p.start() - for p in processes: - p.join() - - # 14B 
- scripts = [] - for file_name in os.listdir("examples/wanvideo/model_training/full"): - if file_name != "run_test.py" and "14B" in file_name: - scripts.append(os.path.join("examples/wanvideo/model_training/full", file_name)) - for script in scripts: - log_file_name = script.replace("/", "_") + ".txt" - cmd = f"bash {script} > data/log/{log_file_name} 2>&1" - print(cmd, flush=True) - os.system(cmd) - - print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh deleted file mode 100644 index 4fb08bd..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh +++ /dev/null @@ -1,15 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_motion_bucket_id.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_motion_bucket_id \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh deleted file mode 100644 index 8b98631..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh +++ /dev/null @@ -1,16 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-FLF2V-14B-720P_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_input_image \ - --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh deleted file mode 100644 index 72522f2..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh +++ /dev/null @@ -1,16 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ - --data_file_keys "video,control_video" \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-1.3B-Control_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_control_video \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh deleted file mode 100644 index 182fccc..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh +++ /dev/null @@ -1,16 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-1.3B-InP_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_input_image \ - --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh deleted file mode 100644 index a45203c..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh +++ /dev/null @@ -1,16 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ - --data_file_keys "video,control_video" \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-14B-Control_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_control_video \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh deleted file mode 100644 index 5392658..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh +++ /dev/null @@ -1,16 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-14B-InP_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_input_image \ - --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh deleted file mode 100644 index a342981..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh +++ /dev/null @@ -1,17 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ - --data_file_keys "video,control_video,reference_image" \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_control_video \ - --input_contains_reference_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh deleted file mode 100644 index a902522..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh +++ /dev/null @@ -1,17 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ - --data_file_keys "video,control_video,reference_image" \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_control_video \ - --input_contains_reference_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh deleted file mode 100644 index 3c085fa..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh +++ /dev/null @@ -1,15 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-I2V-14B-480P_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_input_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh deleted file mode 100644 index 6193df7..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh +++ /dev/null @@ -1,15 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ - --input_contains_input_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh deleted file mode 100644 index d16a287..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/Wan2.1-T2V-1.3B_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh deleted file mode 100644 index 1fb55ac..0000000 --- a/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh +++ /dev/null @@ -1,14 +0,0 @@ -accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata.csv \ - --height 480 \ - --width 832 \ - --dataset_repeat 100 \ - --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/Wan2.1-T2V-14B_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/run_test.py b/examples/wanvideo/model_training/lora/run_test.py deleted file mode 100644 index ec0f9e2..0000000 --- a/examples/wanvideo/model_training/lora/run_test.py +++ /dev/null @@ -1,25 +0,0 @@ -import multiprocessing, os - - -def run_task(scripts, thread_id, thread_num): - for script_id, script in enumerate(scripts): - if script_id % thread_num == thread_id: - log_file_name = script.replace("/", "_") + ".txt" - cmd = f"CUDA_VISIBLE_DEVICES={thread_id} bash {script} > data/log/{log_file_name} 2>&1" - os.makedirs("data/log", exist_ok=True) - print(cmd, flush=True) - os.system(cmd) - - -if __name__ == "__main__": - scripts = [] - for file_name in os.listdir("examples/wanvideo/model_training/lora"): - if file_name != "run_test.py": - scripts.append(os.path.join("examples/wanvideo/model_training/lora", file_name)) - - processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] - for p in processes: - p.start() - for p in processes: - p.join() - print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py deleted file mode 100644 index cbace5a..0000000 --- a/examples/wanvideo/model_training/train.py +++ /dev/null @@ -1,129 +0,0 @@ -import torch, os, json -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - - -class WanTrainingModule(DiffusionTrainingModule): - def __init__( - self, - model_paths=None, model_id_with_origin_paths=None, - trainable_models=None, - lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, - 
use_gradient_checkpointing=True, - use_gradient_checkpointing_offload=False, - # Extra inputs - input_contains_input_image=False, - input_contains_end_image=False, - input_contains_control_video=False, - input_contains_reference_image=False, - input_contains_vace_video=False, - input_contains_vace_reference_image=False, - input_contains_motion_bucket_id=False, - ): - super().__init__() - # Load models - model_configs = [] - if model_paths is not None: - model_paths = json.loads(model_paths) - model_configs += [ModelConfig(path=path) for path in model_paths] - if model_id_with_origin_paths is not None: - model_id_with_origin_paths = model_id_with_origin_paths.split(",") - model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] - self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) - - # Reset training scheduler - self.pipe.scheduler.set_timesteps(1000, training=True) - - # Freeze untrainable models - self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) - - # Add LoRA to the base models - if lora_base_model is not None: - model = self.add_lora_to_model( - getattr(self.pipe, lora_base_model), - target_modules=lora_target_modules.split(","), - lora_rank=lora_rank - ) - setattr(self.pipe, lora_base_model, model) - - # Store other configs - self.use_gradient_checkpointing = use_gradient_checkpointing - self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload - self.input_contains_input_image = input_contains_input_image - self.input_contains_end_image = input_contains_end_image - self.input_contains_control_video = input_contains_control_video - self.input_contains_reference_image = input_contains_reference_image - self.input_contains_vace_video = input_contains_vace_video - self.input_contains_vace_reference_image = input_contains_vace_reference_image - 
self.input_contains_motion_bucket_id = input_contains_motion_bucket_id - - - def forward_preprocess(self, data): - # CFG-sensitive parameters - inputs_posi = {"prompt": data["prompt"]} - inputs_nega = {} - - # CFG-unsensitive parameters - inputs_shared = { - # Assume you are using this pipeline for inference, - # please fill in the input parameters. - "input_video": data["video"], - "height": data["video"][0].size[1], - "width": data["video"][0].size[0], - "num_frames": len(data["video"]), - # Please do not modify the following parameters - # unless you clearly know what this will cause. - "cfg_scale": 1, - "tiled": False, - "rand_device": self.pipe.device, - "use_gradient_checkpointing": self.use_gradient_checkpointing, - "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, - "cfg_merge": False, - "vace_scale": 1, - } - - # Extra inputs - if self.input_contains_input_image: inputs_shared["input_image"] = data["video"][0] - if self.input_contains_end_image: inputs_shared["end_image"] = data["video"][-1] - if self.input_contains_control_video: inputs_shared["control_video"] = data["control_video"] - if self.input_contains_reference_image: inputs_shared["reference_image"] = data["reference_image"] - if self.input_contains_vace_video: inputs_shared["vace_video"] = data["vace_video"] - if self.input_contains_vace_reference_image: inputs_shared["vace_reference_image"] = data["vace_reference_image"] - if self.input_contains_motion_bucket_id: inputs_shared["motion_bucket_id"] = data["motion_bucket_id"] - - # Pipeline units will automatically process the input parameters. 
- for unit in self.pipe.units: - inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) - return {**inputs_shared, **inputs_posi} - - - def forward(self, data, inputs=None): - if inputs is None: inputs = self.forward_preprocess(data) - models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} - loss = self.pipe.training_loss(**models, **inputs) - return loss - - -if __name__ == "__main__": - parser = wan_parser() - args = parser.parse_args() - dataset = VideoDataset(args=args) - model = WanTrainingModule( - model_paths=args.model_paths, - model_id_with_origin_paths=args.model_id_with_origin_paths, - trainable_models=args.trainable_models, - lora_base_model=args.lora_base_model, - lora_target_modules=args.lora_target_modules, - lora_rank=args.lora_rank, - use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, - input_contains_input_image=args.input_contains_input_image, - input_contains_end_image=args.input_contains_end_image, - input_contains_control_video=args.input_contains_control_video, - input_contains_reference_image=args.input_contains_reference_image, - input_contains_vace_video=args.input_contains_vace_video, - input_contains_vace_reference_image=args.input_contains_vace_reference_image, - input_contains_motion_bucket_id=args.input_contains_motion_bucket_id, - ) - launch_training_task(model, dataset, args=args) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py deleted file mode 100644 index 124749a..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig - - -pipe = 
WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-1.3b-speedcontrol-v1_full/epoch-1.safetensors") -pipe.motion_controller.load_state_dict(state_dict) -pipe.enable_vram_management() - -# Text-to-video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=1, tiled=True, - motion_bucket_id=50 -) -save_video(video, "video_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py deleted file mode 100644 index 41a67ed..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - 
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-FLF2V-14B-720P_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) - -# First and last frame to video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=video[0], - end_image=video[80], - seed=0, tiled=True, - sigma_shift=16, -) -save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py deleted file mode 100644 index 6726e9c..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", 
origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-Fun-1.3B-Control_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) -video = [video[i] for i in range(81)] - -# Control video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=video, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py deleted file mode 100644 index 3e1e6f3..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", 
origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-Fun-1.3B-InP_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) - -# First and last frame to video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=video[0], end_image=video[80], - seed=0, tiled=True -) -save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py deleted file mode 100644 index 08b0acb..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", 
origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-Fun-14B-Control_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) -video = [video[i] for i in range(81)] - -# Control video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=video, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py deleted file mode 100644 index d7e39d7..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = 
load_state_dict("models/train/Wan2.1-Fun-14B-InP_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) - -# First and last frame to video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=video[0], end_image=video[80], - seed=0, tiled=True -) -save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py deleted file mode 100644 index 6497e1b..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-1.3B-Control_full/epoch-1.safetensors") 
-pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) -video = [video[i] for i in range(81)] -reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] - -# Control video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=video, reference_image=reference_image, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py deleted file mode 100644 index cd8ee20..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = 
load_state_dict("models/train/Wan2.1-Fun-V1.1-1.3B-InP_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) - -# First and last frame to video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=video[0], end_image=video[80], - seed=0, tiled=True -) -save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py deleted file mode 100644 index 0dd2516..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-14B-Control_full/epoch-1.safetensors") 
-pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) -video = [video[i] for i in range(81)] -reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] - -# Control video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=video, reference_image=reference_image, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py deleted file mode 100644 index 7e944b0..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = 
load_state_dict("models/train/Wan2.1-Fun-V1.1-14B-InP_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) - -# First and last frame to video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=video[0], end_image=video[80], - seed=0, tiled=True -) -save_video(video, "video_Wan2.1-Fun-V1.1-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py deleted file mode 100644 index c1c8615..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-I2V-14B-480P_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) 
-pipe.enable_vram_management() - -input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] - -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=input_image, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py deleted file mode 100644 index a8610f3..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-I2V-14B-720P_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] - -video = pipe( - prompt="from sunset to night, a small town, 
light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=input_image, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py deleted file mode 100644 index 1420514..0000000 --- a/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-T2V-1.3B_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py deleted file mode 100644 index a0107ae..0000000 --- 
a/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ], -) -state_dict = load_state_dict("models/train/Wan2.1-T2V-14B_full/epoch-1.safetensors") -pipe.dit.load_state_dict(state_dict) -pipe.enable_vram_management() - -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/run_test.py b/examples/wanvideo/model_training/validate_full/run_test.py deleted file mode 100644 index a4e3203..0000000 --- a/examples/wanvideo/model_training/validate_full/run_test.py +++ /dev/null @@ -1,25 +0,0 @@ -import multiprocessing, os - - -def run_task(scripts, thread_id, thread_num): - for script_id, script in enumerate(scripts): - if script_id % thread_num == thread_id: - log_file_name = script.replace("/", "_") + ".txt" - cmd = f"CUDA_VISIBLE_DEVICES={thread_id} python -u {script} > data/log/{log_file_name} 2>&1" - os.makedirs("data/log", exist_ok=True) - print(cmd, flush=True) - os.system(cmd) - - -if __name__ == "__main__": - scripts = [] - for file_name in 
os.listdir("examples/wanvideo/model_training/validate_full"): - if file_name != "run_test.py": - scripts.append(os.path.join("examples/wanvideo/model_training/validate_full", file_name)) - - processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] - for p in processes: - p.start() - for p in processes: - p.join() - print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py deleted file mode 100644 index 167b871..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData, load_state_dict -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-1.3b-speedcontrol-v1_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -# Text-to-video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=1, tiled=True, - motion_bucket_id=50 -) -save_video(video, 
"video_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py deleted file mode 100644 index cd68f0e..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-FLF2V-14B-720P_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) - -# First and last frame to video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=video[0], - end_image=video[80], - seed=0, tiled=True, - sigma_shift=16, -) -save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py 
b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py deleted file mode 100644 index 7270c38..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-1.3B-Control_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) -video = [video[i] for i in range(81)] - -# Control video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=video, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py deleted file mode 100644 index c904dfa..0000000 --- 
a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-1.3B-InP_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) - -# First and last frame to video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=video[0], end_image=video[80], - seed=0, tiled=True -) -save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py deleted file mode 100644 index 8631d05..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from 
diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-14B-Control_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) -video = [video[i] for i in range(81)] - -# Control video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=video, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py deleted file mode 100644 index e020aac..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - 
torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-14B-InP_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) - -# First and last frame to video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=video[0], end_image=video[80], - seed=0, tiled=True -) -save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py deleted file mode 100644 index ebcfd2f..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", 
offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) -video = [video[i] for i in range(81)] -reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] - -# Control video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=video, reference_image=reference_image, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py deleted file mode 100644 index 6b11098..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", 
origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-14B-Control_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) -video = [video[i] for i in range(81)] -reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] - -# Control video -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - control_video=video, reference_image=reference_image, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py deleted file mode 100644 index 1687e36..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", 
origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-480P_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] - -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=input_image, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py deleted file mode 100644 index 9893e26..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - 
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-720P_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] - -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=input_image, - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py deleted file mode 100644 index 7cb6c02..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-T2V-1.3B_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = pipe( - prompt="from sunset to night, a small town, light, 
house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py deleted file mode 100644 index 3b66a49..0000000 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig - - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ], -) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-T2V-14B_lora/epoch-4.safetensors", alpha=1) -pipe.enable_vram_management() - -video = pipe( - prompt="from sunset to night, a small town, light, house, river", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=1, tiled=True -) -save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/run_test.py b/examples/wanvideo/model_training/validate_lora/run_test.py deleted file mode 100644 index 367ee9d..0000000 --- a/examples/wanvideo/model_training/validate_lora/run_test.py +++ /dev/null @@ -1,25 +0,0 @@ -import multiprocessing, os - - -def run_task(scripts, 
thread_id, thread_num): - for script_id, script in enumerate(scripts): - if script_id % thread_num == thread_id: - log_file_name = script.replace("/", "_") + ".txt" - cmd = f"CUDA_VISIBLE_DEVICES={thread_id} python -u {script} > data/log/{log_file_name} 2>&1" - os.makedirs("data/log", exist_ok=True) - print(cmd, flush=True) - os.system(cmd) - - -if __name__ == "__main__": - scripts = [] - for file_name in os.listdir("examples/wanvideo/model_training/validate_lora"): - if file_name != "run_test.py": - scripts.append(os.path.join("examples/wanvideo/model_training/validate_lora", file_name)) - - processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] - for p in processes: - p.start() - for p in processes: - p.join() - print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py new file mode 100644 index 0000000..cd10096 --- /dev/null +++ b/examples/wanvideo/train_wan_t2v.py @@ -0,0 +1,593 @@ +import torch, os, imageio, argparse +from torchvision.transforms import v2 +from einops import rearrange +import lightning as pl +import pandas as pd +from diffsynth import WanVideoPipeline, ModelManager, load_state_dict +from peft import LoraConfig, inject_adapter_in_model +import torchvision +from PIL import Image +import numpy as np + + + +class TextVideoDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + self.text = metadata["text"].to_list() + + self.max_num_frames = max_num_frames + self.frame_interval = frame_interval + self.num_frames = num_frames + self.height = height + self.width = width + self.is_i2v = is_i2v + + self.frame_process = v2.Compose([ + v2.CenterCrop(size=(height, width)), + 
v2.Resize(size=(height, width), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + + def crop_and_resize(self, image): + width, height = image.size + scale = max(self.width / width, self.height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + return image + + + def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): + reader = imageio.get_reader(file_path) + if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + + frames = [] + first_frame = None + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = Image.fromarray(frame) + frame = self.crop_and_resize(frame) + if first_frame is None: + first_frame = frame + frame = frame_process(frame) + frames.append(frame) + reader.close() + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + + first_frame = v2.functional.center_crop(first_frame, output_size=(self.height, self.width)) + first_frame = np.array(first_frame) + + if self.is_i2v: + return frames, first_frame + else: + return frames + + + def load_video(self, file_path): + start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0] + frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) + return frames + + + def is_image(self, file_path): + file_ext_name = file_path.split(".")[-1] + if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]: + return True + return False + + + def load_image(self, file_path): + frame = Image.open(file_path).convert("RGB") + frame = self.crop_and_resize(frame) + 
first_frame = frame + frame = self.frame_process(frame) + frame = rearrange(frame, "C H W -> C 1 H W") + return frame + + + def __getitem__(self, data_id): + text = self.text[data_id] + path = self.path[data_id] + if self.is_image(path): + if self.is_i2v: + raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.") + video = self.load_image(path) + else: + video = self.load_video(path) + if self.is_i2v: + video, first_frame = video + data = {"text": text, "video": video, "path": path, "first_frame": first_frame} + else: + data = {"text": text, "video": video, "path": path} + return data + + + def __len__(self): + return len(self.path) + + + +class LightningModelForDataProcess(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_path = [text_encoder_path, vae_path] + if image_encoder_path is not None: + model_path.append(image_encoder_path) + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models(model_path) + self.pipe = WanVideoPipeline.from_model_manager(model_manager) + + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + def test_step(self, batch, batch_idx): + text, video, path = batch["text"][0], batch["video"], batch["path"][0] + + self.pipe.device = self.device + if video is not None: + # prompt + prompt_emb = self.pipe.encode_prompt(text) + # video + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] + # image + if "first_frame" in batch: + first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) + _, _, num_frames, height, width = video.shape + image_emb = self.pipe.encode_image(first_frame, None, num_frames, height, width) + else: + image_emb = {} + data = {"latents": latents, "prompt_emb": prompt_emb, 
"image_emb": image_emb} + torch.save(data, path + ".tensors.pth") + + + +class TensorDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, steps_per_epoch): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + print(len(self.path), "videos in metadata.") + self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] + print(len(self.path), "tensors cached in metadata.") + assert len(self.path) > 0 + + self.steps_per_epoch = steps_per_epoch + + + def __getitem__(self, index): + data_id = torch.randint(0, len(self.path), (1,))[0] + data_id = (data_id + index) % len(self.path) # For fixed seed. + path = self.path[data_id] + data = torch.load(path, weights_only=True, map_location="cpu") + return data + + + def __len__(self): + return self.steps_per_epoch + + + +class LightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", + use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, + pretrained_lora_path=None + ): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + + self.pipe = WanVideoPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + self.freeze_parameters() + if train_architecture == "lora": + self.add_lora_to_model( + self.pipe.denoising_model(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_target_modules=lora_target_modules, + init_lora_weights=init_lora_weights, + pretrained_lora_path=pretrained_lora_path, + ) + else: + self.pipe.denoising_model().requires_grad_(True) + + 
self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + + def freeze_parameters(self): + # Freeze parameters + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + + def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None): + # Add LoRA to UNet + self.lora_alpha = lora_alpha + if init_lora_weights == "kaiming": + init_lora_weights = True + + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + init_lora_weights=init_lora_weights, + target_modules=lora_target_modules.split(","), + ) + model = inject_adapter_in_model(lora_config, model) + for param in model.parameters(): + # Upcast LoRA parameters into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + + # Lora pretrained lora weights + if pretrained_lora_path is not None: + state_dict = load_state_dict(pretrained_lora_path) + if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + all_keys = [i for i, _ in model.named_parameters()] + num_updated_keys = len(all_keys) - len(missing_keys) + num_unexpected_keys = len(unexpected_keys) + print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. 
{num_unexpected_keys} parameters are unexpected.") + + + def training_step(self, batch, batch_idx): + # Data + latents = batch["latents"].to(self.device) + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + # Loss + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + extra_input = self.pipe.prepare_extra_input(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # Compute loss + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + + # Record log + self.log("train_loss", loss, prog_bar=True) + return loss + + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + checkpoint.clear() + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + 
state_dict = self.pipe.denoising_model().state_dict() + lora_state_dict = {} + for name, param in state_dict.items(): + if name in trainable_param_names: + lora_state_dict[name] = param + checkpoint.update(lora_state_dict) + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--task", + type=str, + default="data_process", + required=True, + choices=["data_process", "train"], + help="Task. `data_process` or `train`.", + ) + parser.add_argument( + "--dataset_path", + type=str, + default=None, + required=True, + help="The path of the Dataset.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./", + help="Path to save the model.", + ) + parser.add_argument( + "--text_encoder_path", + type=str, + default=None, + help="Path of text encoder.", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + help="Path of image encoder.", + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help="Path of VAE.", + ) + parser.add_argument( + "--dit_path", + type=str, + default=None, + help="Path of DiT.", + ) + parser.add_argument( + "--tiled", + default=False, + action="store_true", + help="Whether enable tile encode in VAE. 
This option can reduce VRAM required.", + ) + parser.add_argument( + "--tile_size_height", + type=int, + default=34, + help="Tile size (height) in VAE.", + ) + parser.add_argument( + "--tile_size_width", + type=int, + default=34, + help="Tile size (width) in VAE.", + ) + parser.add_argument( + "--tile_stride_height", + type=int, + default=18, + help="Tile stride (height) in VAE.", + ) + parser.add_argument( + "--tile_stride_width", + type=int, + default=16, + help="Tile stride (width) in VAE.", + ) + parser.add_argument( + "--steps_per_epoch", + type=int, + default=500, + help="Number of steps per epoch.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=81, + help="Number of frames.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="Image height.", + ) + parser.add_argument( + "--width", + type=int, + default=832, + help="Image width.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Learning rate.", + ) + parser.add_argument( + "--accumulate_grad_batches", + type=int, + default=1, + help="The number of batches in gradient accumulation.", + ) + parser.add_argument( + "--max_epochs", + type=int, + default=1, + help="Number of epochs.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default="q,k,v,o,ffn.0,ffn.2", + help="Layers with LoRA modules.", + ) + parser.add_argument( + "--init_lora_weights", + type=str, + default="kaiming", + choices=["gaussian", "kaiming"], + help="The initializing method of LoRA weight.", + ) + parser.add_argument( + "--training_strategy", + type=str, + default="auto", + choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], + help="Training strategy", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help="The dimension of the LoRA update matrices.", + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=4.0, + help="The weight of the LoRA update matrices.", + ) + parser.add_argument( + "--use_gradient_checkpointing", + default=False, + action="store_true", + help="Whether to use gradient checkpointing.", + ) + parser.add_argument( + "--use_gradient_checkpointing_offload", + default=False, + action="store_true", + help="Whether to use gradient checkpointing offload.", + ) + parser.add_argument( + "--train_architecture", + type=str, + default="lora", + choices=["lora", "full"], + help="Model structure to train. LoRA training or full training.", + ) + parser.add_argument( + "--pretrained_lora_path", + type=str, + default=None, + help="Pretrained LoRA path. 
Required if the training is resumed.", + ) + parser.add_argument( + "--use_swanlab", + default=False, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_mode", + default=None, + help="SwanLab mode (cloud or local).", + ) + args = parser.parse_args() + return args + + +def data_process(args): + dataset = TextVideoDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + max_num_frames=args.num_frames, + frame_interval=1, + num_frames=args.num_frames, + height=args.height, + width=args.width, + is_i2v=args.image_encoder_path is not None + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=False, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForDataProcess( + text_encoder_path=args.text_encoder_path, + image_encoder_path=args.image_encoder_path, + vae_path=args.vae_path, + tiled=args.tiled, + tile_size=(args.tile_size_height, args.tile_size_width), + tile_stride=(args.tile_stride_height, args.tile_stride_width), + ) + trainer = pl.Trainer( + accelerator="gpu", + devices="auto", + default_root_dir=args.output_path, + ) + trainer.test(model, dataloader) + + +def train(args): + dataset = TensorDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + steps_per_epoch=args.steps_per_epoch, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + train_architecture=args.train_architecture, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_target_modules=args.lora_target_modules, + init_lora_weights=args.init_lora_weights, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + pretrained_lora_path=args.pretrained_lora_path, + ) + if 
args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project="wan", + name="wan", + config=swanlab_config, + mode=args.swanlab_mode, + logdir=os.path.join(args.output_path, "swanlog"), + ) + logger = [swanlab_logger] + else: + logger = None + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + logger=logger, + ) + trainer.fit(model, dataloader) + + +if __name__ == '__main__': + args = parse_args() + if args.task == "data_process": + data_process(args) + elif args.task == "train": + train(args) diff --git a/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/wan_1.3b_motion_controller.py similarity index 63% rename from examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py rename to examples/wanvideo/wan_1.3b_motion_controller.py index 6efdc65..8036819 100644 --- a/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py +++ b/examples/wanvideo/wan_1.3b_motion_controller.py @@ -1,25 +1,31 @@ import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", 
offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"), +# Download models +snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B") +snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + "models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors", ], + torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. ) -pipe.enable_vram_management() +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) # Text-to-video video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, seed=1, tiled=True, motion_bucket_id=0 ) @@ -28,7 +34,8 @@ save_video(video, "video_slow.mp4", fps=15, quality=5) video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + 
num_inference_steps=50, seed=1, tiled=True, motion_bucket_id=100 ) -save_video(video, "video_fast.mp4", fps=15, quality=5) +save_video(video, "video_fast.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py b/examples/wanvideo/wan_1.3b_text_to_video.py similarity index 70% rename from examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py rename to examples/wanvideo/wan_1.3b_text_to_video.py index 83e300b..e444cd2 100644 --- a/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py +++ b/examples/wanvideo/wan_1.3b_text_to_video.py @@ -1,25 +1,30 @@ import torch -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), +# Download models +snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", ], + torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. 
) -pipe.enable_vram_management() +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) # Text-to-video video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=True, + num_inference_steps=50, + seed=0, tiled=True ) save_video(video, "video1.mp4", fps=15, quality=5) @@ -29,6 +34,7 @@ video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", input_video=video, denoising_strength=0.7, + num_inference_steps=50, seed=1, tiled=True ) save_video(video, "video2.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py b/examples/wanvideo/wan_1.3b_vace.py similarity index 72% rename from examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py rename to examples/wanvideo/wan_1.3b_vace.py index 99c0242..fb987a7 100644 --- a/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py +++ b/examples/wanvideo/wan_1.3b_vace.py @@ -1,21 +1,26 @@ import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download, dataset_snapshot_download from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - 
device="cuda", - model_configs=[ - ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), +# Download models +snapshot_download("iic/VACE-Wan2.1-1.3B-Preview", local_dir="models/iic/VACE-Wan2.1-1.3B-Preview") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + "models/iic/VACE-Wan2.1-1.3B-Preview/diffusion_pytorch_model.safetensors", + "models/iic/VACE-Wan2.1-1.3B-Preview/models_t5_umt5-xxl-enc-bf16.pth", + "models/iic/VACE-Wan2.1-1.3B-Preview/Wan2.1_VAE.pth", ], + torch_dtype=torch.bfloat16, ) -pipe.enable_vram_management() +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) +# Download example video dataset_snapshot_download( dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", @@ -27,6 +32,8 @@ control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + height=480, width=832, num_frames=81, vace_video=control_video, seed=1, tiled=True ) @@ -36,6 +43,8 @@ save_video(video, "video1.mp4", fps=15, quality=5) video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + height=480, width=832, num_frames=81, 
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), seed=1, tiled=True ) @@ -45,6 +54,8 @@ save_video(video, "video2.mp4", fps=15, quality=5) video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + height=480, width=832, num_frames=81, vace_video=control_video, vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), seed=1, tiled=True diff --git a/examples/wanvideo/wan_14B_flf2v.py b/examples/wanvideo/wan_14B_flf2v.py new file mode 100644 index 0000000..23109df --- /dev/null +++ b/examples/wanvideo/wan_14B_flf2v.py @@ -0,0 +1,52 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download, dataset_snapshot_download +from PIL import Image + + +# Download models +snapshot_download("Wan-AI/Wan2.1-FLF2V-14B-720P", local_dir="models/Wan-AI/Wan2.1-FLF2V-14B-720P") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + ["models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], + torch_dtype=torch.float32, # Image Encoder is loaded with float32 +) +model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00001-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00002-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00003-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00004-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00005-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00006-of-00007.safetensors", + 
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00007-of-00007.safetensors", + ], + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/Wan2.1_VAE.pth", + ], + torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. +) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) + +# Download example image +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] +) + +# First and last frame to video +video = pipe( + prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=30, + input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), + end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), + height=960, width=960, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14b_image_to_video.py b/examples/wanvideo/wan_14b_image_to_video.py new file mode 100644 index 0000000..91894ae --- /dev/null +++ b/examples/wanvideo/wan_14b_image_to_video.py @@ -0,0 +1,51 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download, dataset_snapshot_download +from PIL import Image + + +# Download models +snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-I2V-14B-480P") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + ["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], + 
torch_dtype=torch.float32, # Image Encoder is loaded with float32 +) +model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors", + ], + "models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", + ], + torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. +) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. 
+ +# Download example image +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py new file mode 100644 index 0000000..654565d --- /dev/null +++ b/examples/wanvideo/wan_14b_text_to_video.py @@ -0,0 +1,36 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download + + +# Download models +snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + ], + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", + ], + torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` 
to disable FP8 quantization. +) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video1.mp4", fps=25, quality=5) diff --git a/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py b/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py new file mode 100644 index 0000000..77c230c --- /dev/null +++ b/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py @@ -0,0 +1,149 @@ +import torch +import lightning as pl +from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import parallelize_module +from lightning.pytorch.strategies import ModelParallelStrategy +from diffsynth import ModelManager, WanVideoPipeline, save_video +from tqdm import tqdm +from modelscope import snapshot_download + + + +class ToyDataset(torch.utils.data.Dataset): + def __init__(self, tasks=[]): + self.tasks = tasks + + def __getitem__(self, data_id): + return self.tasks[data_id] + + def __len__(self): + return len(self.tasks) + + +class LitModel(pl.LightningModule): + def __init__(self): + super().__init__() + model_manager = ModelManager(device="cpu") + model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + 
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + ], + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", + ], + torch_dtype=torch.bfloat16, + ) + self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") + + def configure_model(self): + tp_mesh = self.device_mesh["tensor_parallel"] + plan = { + "text_embedding.0": ColwiseParallel(), + "text_embedding.2": RowwiseParallel(), + "time_projection.1": ColwiseParallel(output_layouts=Replicate()), + "text_embedding.0": ColwiseParallel(), + "text_embedding.2": RowwiseParallel(), + "blocks.0": PrepareModuleInput( + input_layouts=(Replicate(), None, None, None), + desired_input_layouts=(Replicate(), None, None, None), + ), + "head": PrepareModuleInput( + input_layouts=(Replicate(), None), + desired_input_layouts=(Replicate(), None), + use_local_output=True, + ) + } + self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan) + for block_id, block in enumerate(self.pipe.dit.blocks): + layer_tp_plan = { + "self_attn": PrepareModuleInput( + input_layouts=(Shard(1), Replicate()), + desired_input_layouts=(Shard(1), Shard(0)), + ), + "self_attn.q": SequenceParallel(), + "self_attn.k": SequenceParallel(), + "self_attn.v": SequenceParallel(), + "self_attn.norm_q": SequenceParallel(), + "self_attn.norm_k": SequenceParallel(), + "self_attn.attn": PrepareModuleInput( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(2), Shard(2), Shard(2)), + ), + "self_attn.o": RowwiseParallel(input_layouts=Shard(2), 
output_layouts=Replicate()), + + "cross_attn": PrepareModuleInput( + input_layouts=(Shard(1), Replicate()), + desired_input_layouts=(Shard(1), Replicate()), + ), + "cross_attn.q": SequenceParallel(), + "cross_attn.k": SequenceParallel(), + "cross_attn.v": SequenceParallel(), + "cross_attn.norm_q": SequenceParallel(), + "cross_attn.norm_k": SequenceParallel(), + "cross_attn.attn": PrepareModuleInput( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(2), Shard(2), Shard(2)), + ), + "cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False), + + "ffn.0": ColwiseParallel(input_layouts=Shard(1)), + "ffn.2": RowwiseParallel(output_layouts=Replicate()), + + "norm1": SequenceParallel(use_local_output=True), + "norm2": SequenceParallel(use_local_output=True), + "norm3": SequenceParallel(use_local_output=True), + "gate": PrepareModuleInput( + input_layouts=(Shard(1), Replicate(), Replicate()), + desired_input_layouts=(Replicate(), Replicate(), Replicate()), + ) + } + parallelize_module( + module=block, + device_mesh=tp_mesh, + parallelize_plan=layer_tp_plan, + ) + + + def test_step(self, batch): + data = batch[0] + data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x + output_path = data.pop("output_path") + with torch.no_grad(), torch.inference_mode(False): + video = self.pipe(**data) + if self.local_rank == 0: + save_video(video, output_path, fps=15, quality=5) + + +if __name__ == "__main__": + snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") + dataloader = torch.utils.data.DataLoader( + ToyDataset([ + { + "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + 
"num_inference_steps": 50, + "seed": 0, + "tiled": False, + "output_path": "video1.mp4", + }, + { + "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "num_inference_steps": 50, + "seed": 1, + "tiled": False, + "output_path": "video2.mp4", + }, + ]), + collate_fn=lambda x: x + ) + model = LitModel() + trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy()) + trainer.test(model, dataloader) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/wan_fun_InP.py similarity index 56% rename from examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py rename to examples/wanvideo/wan_fun_InP.py index f2fc560..ae23ee0 100644 --- a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py +++ b/examples/wanvideo/wan_fun_InP.py @@ -1,22 +1,27 @@ import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download, dataset_snapshot_download from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", 
offload_device="cpu"), - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), +# Download models +snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + "models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors", + "models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth", + "models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth", + "models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", ], + torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. ) -pipe.enable_vram_management() +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) +# Download example image dataset_snapshot_download( dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", @@ -24,13 +29,14 @@ dataset_snapshot_download( ) image = Image.open("data/examples/wan/input_image.jpg") -# First and last frame to video +# Image-to-video video = pipe( prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, input_image=image, - seed=0, tiled=True # You can input `end_image=xxx` to control the last frame of the video. # The model will automatically generate the dynamic content between `input_image` and `end_image`. 
+ seed=1, tiled=True ) -save_video(video, "video.mp4", fps=15, quality=5) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_fun_control.py b/examples/wanvideo/wan_fun_control.py new file mode 100644 index 0000000..e2c4d0c --- /dev/null +++ b/examples/wanvideo/wan_fun_control.py @@ -0,0 +1,40 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download, dataset_snapshot_download +from PIL import Image + + +# Download models +snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + "models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors", + "models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth", + "models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth", + "models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", + ], + torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. 
+) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) + +# Download example video +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control-to-video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/requirements.txt b/requirements.txt index 92d8b48..63a871b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch>=2.0.0 torchvision cupy-cuda12x -transformers +transformers==4.46.2 controlnet-aux==0.0.7 imageio imageio[ffmpeg] @@ -11,4 +11,3 @@ sentencepiece protobuf modelscope ftfy -pynvml