diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 6bb9350..f39b87a 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -131,7 +131,12 @@ model_loader_configs = [ (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"), (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"), + (None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"), + (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"), + (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), + (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), + (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), diff --git a/diffsynth/lora/__init__.py b/diffsynth/lora/__init__.py new file mode 100644 index 0000000..33bd89c --- /dev/null +++ b/diffsynth/lora/__init__.py @@ -0,0 +1,45 @@ +import torch + + + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = torch_dtype + + + def get_name_dict(self, lora_state_dict): + lora_name_dict = {} + for key in lora_state_dict: + if ".lora_B." 
not in key:
+                continue
+            keys = key.split(".")
+            if len(keys) > keys.index("lora_B") + 2:
+                keys.pop(keys.index("lora_B") + 1)
+            keys.pop(keys.index("lora_B"))
+            if keys[0] == "diffusion_model":
+                keys.pop(0)
+            keys.pop(-1)
+            target_name = ".".join(keys)
+            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
+        return lora_name_dict
+
+
+    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
+        updated_num = 0
+        lora_name_dict = self.get_name_dict(state_dict_lora)
+        # Merge alpha * (B @ A) into the weight of every module named in the LoRA.
+        for name, module in model.named_modules():
+            if name in lora_name_dict:
+                weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
+                weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
+                # Conv LoRA weights are 4-D; flatten to 2-D for the matmul, then restore the spatial dims.
+                if len(weight_up.shape) == 4:
+                    weight_up = weight_up.squeeze(3).squeeze(2)
+                    weight_down = weight_down.squeeze(3).squeeze(2)
+                    weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+                else:
+                    weight_lora = alpha * torch.mm(weight_up, weight_down)
+                state_dict = module.state_dict()
+                state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
+                module.load_state_dict(state_dict)
+                updated_num += 1
+        print(f"{updated_num} tensors are updated by LoRA.")
diff --git a/diffsynth/models/utils.py b/diffsynth/models/utils.py
index 99f5dee..0d58e4e 100644
--- a/diffsynth/models/utils.py
+++ b/diffsynth/models/utils.py
@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
     return state_dict
 
 
-def load_state_dict(file_path, torch_dtype=None):
+def load_state_dict(file_path, torch_dtype=None, device="cpu"):
     if file_path.endswith(".safetensors"):
-        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
+        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
     else:
-        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
+        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
 
 
-def load_state_dict_from_safetensors(file_path, torch_dtype=None):
+def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
     state_dict = {}
-    with safe_open(file_path, framework="pt", device="cpu") as f:
+    with safe_open(file_path, framework="pt", device=device) as f:
         for k in f.keys():
             state_dict[k] = f.get_tensor(k)
             if torch_dtype is not None:
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
     return state_dict
 
 
-def load_state_dict_from_bin(file_path, torch_dtype=None):
-    state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
+def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
+    state_dict = torch.load(file_path, map_location=device, weights_only=True)
     if torch_dtype is not None:
         for i in state_dict:
             if isinstance(state_dict[i], torch.Tensor):
diff --git a/diffsynth/models/wan_video_camera_controller.py b/diffsynth/models/wan_video_camera_controller.py
new file mode 100644
index 0000000..026b558
--- /dev/null
+++ b/diffsynth/models/wan_video_camera_controller.py
@@ -0,0 +1,202 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+import os
+from typing_extensions import Literal
+
+class SimpleAdapter(nn.Module):
+    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
+        super(SimpleAdapter, self).__init__()
+
+        # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
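+        # PixelUnshuffle(8) rearranges each (C, H, W) frame into (64*C, H/8, W/8),
+        # which is why the convolution below expects in_dim * 64 input channels.
+        self.pixel_unshuffle = 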
nn.PixelUnshuffle(downscale_factor=8) + + # Convolution: reduce spatial dimensions by a factor + # of 2 (without overlap) + self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0) + + # Residual blocks for feature extraction + self.residual_blocks = nn.Sequential( + *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)] + ) + + def forward(self, x): + # Reshape to merge the frame dimension into batch + bs, c, f, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) + + # Pixel Unshuffle operation + x_unshuffled = self.pixel_unshuffle(x) + + # Convolution operation + x_conv = self.conv(x_unshuffled) + + # Feature extraction with residual blocks + out = self.residual_blocks(x_conv) + + # Reshape to restore original bf dimension + out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) + + # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames + out = out.permute(0, 2, 1, 3, 4) + + return out + + def process_camera_coordinates( + self, + direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"], + length: int, + height: int, + width: int, + speed: float = 1/54, + origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) + ): + if origin is None: + origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) + coordinates = generate_camera_coordinates(direction, length, speed, origin) + plucker_embedding = process_pose_file(coordinates, width, height) + return plucker_embedding + + + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + + def forward(self, x): + residual = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + out += residual + return out + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + w2c_mat = np.array(entry[7:]).reshape(3, 4) + w2c_mat_4x4 = np.eye(4) + w2c_mat_4x4[:3, :] = w2c_mat + self.w2c_mat = w2c_mat_4x4 + self.c2w_mat = np.linalg.inv(w2c_mat_4x4) + +def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + +def custom_meshgrid(*args): + # torch>=2.0.0 only + return torch.meshgrid(*args, indexing='ij') + + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = custom_meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + 
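+    # Build per-pixel ray directions from the intrinsics, normalize them, rotate
+    # them into world space with c2w, and pair each direction d with o x d
+    # (o = ray origin) to form a 6-D Plücker embedding for every pixel.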
+    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B,V, 1
+
+    zs = torch.ones_like(i)  # [B, HxW]
+    xs = (i - cx) / fx * zs
+    ys = (j - cy) / fy * zs
+    zs = zs.expand_as(ys)
+
+    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
+    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
+
+    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
+    rays_o = c2w[..., :3, 3]  # B, V, 3
+    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
+    # c2w @ directions
+    rays_dxo = torch.linalg.cross(rays_o, rays_d)
+    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
+    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
+    # plucker = plucker.permute(0, 1, 4, 2, 3)
+    return plucker
+
+
+def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
+    if return_poses:
+        return cam_params
+    else:
+        cam_params = [Camera(cam_param) for cam_param in cam_params]
+
+        sample_wh_ratio = width / height
+        pose_wh_ratio = original_pose_width / original_pose_height  # aspect ratio of the original pose sequence
+
+        if pose_wh_ratio > sample_wh_ratio:
+            resized_ori_w = height * pose_wh_ratio
+            for cam_param in cam_params:
+                cam_param.fx = resized_ori_w * cam_param.fx / width
+        else:
+            resized_ori_h = width / pose_wh_ratio
+            for cam_param in cam_params:
+                cam_param.fy = resized_ori_h * cam_param.fy / height
+
+        intrinsic = np.asarray([[cam_param.fx * width,
+                                 cam_param.fy * height,
+                                 cam_param.cx * width,
+                                 cam_param.cy * height]
+                                for cam_param in cam_params], dtype=np.float32)
+
+        K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
+        c2ws = get_relative_pose(cam_params)
+        c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
+        plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
+        plucker_embedding = plucker_embedding[None]
+        plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
+        return plucker_embedding
+
+
+
+def generate_camera_coordinates(
+    direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
+    length: int,
+    speed: float = 1/54,
+    origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
+):
+    coordinates = [list(origin)]
+    while len(coordinates) < length:
+        coor = coordinates[-1].copy()
+        if "Left" in direction:
+            coor[9] += speed
+        if "Right" in direction:
+            coor[9] -= speed
+        if "Up" in direction:
+            coor[13] += speed
+        if "Down" in direction:
+            coor[13] -= speed
+        coordinates.append(coor)
+    return coordinates
diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index d9be8ab..50c06bf 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -5,6 +5,7 @@ import math
 from typing import Tuple, Optional
 from einops import rearrange
 from .utils import hash_state_dict_keys
+from .wan_video_camera_controller import SimpleAdapter
 try:
     import flash_attn_interface
     FLASH_ATTN_3_AVAILABLE = True
@@ -272,6 +273,9 @@ class WanModel(torch.nn.Module):
         num_layers: int,
         has_image_input: bool,
         has_image_pos_emb: bool = False,
+        has_ref_conv: bool = False,
+        add_control_adapter: bool = False,
+        in_dim_control_adapter: int = 24,
     ):
         super().__init__()
         self.dim = dim
@@ -303,10 +307,21 @@ class WanModel(torch.nn.Module):
         if has_image_input:
             self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
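+        # v1.1 control models: `ref_conv` turns a 16-channel reference latent into
+        # DiT tokens with a 2x2 convolution, and `control_adapter` injects camera
+        # (Plücker) features right after patch embedding (see `patchify` below).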
+        if has_ref_conv:
+            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
         self.has_image_pos_emb = has_image_pos_emb
+        self.has_ref_conv = has_ref_conv
+        if add_control_adapter:
+            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
+        else:
+            self.control_adapter = None
 
-    def patchify(self, x: torch.Tensor):
+    def patchify(self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None):
         x = self.patch_embedding(x)
+        if self.control_adapter is not None and control_camera_latents_input is not None:
+            y_camera = self.control_adapter(control_camera_latents_input)
+            # Add the camera features sample-by-sample; batch size 1 is assumed here.
+            x = [u + v for u, v in zip(x, y_camera)]
+            x = x[0].unsqueeze(0)
         grid_size = x.shape[2:]
         x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
         return x, grid_size  # x, grid_size: (f, h, w)
@@ -532,6 +547,7 @@ class WanModelStateDictConverter:
             "eps": 1e-6
         }
         elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
+            # 1.3B PAI control
             config = {
                 "has_image_input": True,
                 "patch_size": [1, 2, 2],
@@ -546,6 +562,7 @@ class WanModelStateDictConverter:
             "eps": 1e-6
         }
         elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
+            # 14B PAI control
             config = {
                 "has_image_input": True,
                 "patch_size": [1, 2, 2],
@@ -574,6 +591,74 @@ class WanModelStateDictConverter:
             "eps": 1e-6,
             "has_image_pos_emb": True
         }
+        elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
+            # 1.3B PAI control v1.1
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 48,
+                "dim": 1536,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 12,
+                "num_layers": 30,
+                "eps": 1e-6,
+                "has_ref_conv": True
+            }
+        elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
+            # 14B PAI control v1.1
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 48,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6,
+                "has_ref_conv": True
+            }
+        elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
+            # 1.3B PAI control-camera v1.1
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 32,
+                "dim": 1536,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 12,
+                "num_layers": 30,
+                "eps": 1e-6,
+                "has_ref_conv": False,
+                "add_control_adapter": True,
+                "in_dim_control_adapter": 24,
+            }
+        elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
+            # 14B PAI control-camera v1.1
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 32,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6,
+                "has_ref_conv": False,
+                "add_control_adapter": True,
+                "in_dim_control_adapter": 24,
+            }
         else:
             config = {}
         return state_dict, config
diff --git a/diffsynth/models/wan_video_vace.py b/diffsynth/models/wan_video_vace.py
index 0c9c2d7..40f3804 100644
--- a/diffsynth/models/wan_video_vace.py
+++ b/diffsynth/models/wan_video_vace.py
@@ -1,6 +1,6 @@
 import torch
 from .wan_video_dit import DiTBlock
-
+from .utils import hash_state_dict_keys
 
 class VaceWanAttentionBlock(DiTBlock):
     def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
@@ -50,7 +50,11 @@ class VaceWanModel(torch.nn.Module):
         # vace patch embeddings
self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size) - def forward(self, x, vace_context, context, t_mod, freqs): + def forward( + self, x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + ): c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] c = [u.flatten(2).transpose(1, 2) for u in c] c = torch.cat([ @@ -58,8 +62,27 @@ class VaceWanModel(torch.nn.Module): dim=1) for u in c ]) + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + for block in self.vace_blocks: - c = block(c, x, context, t_mod, freqs) + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + c = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + c, x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + c = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + c, x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + c = block(c, x, context, t_mod, freqs) hints = torch.unbind(c)[:-1] return hints @@ -74,4 +97,17 @@ class VaceWanModelDictConverter: def from_civitai(self, state_dict): state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")} - return state_dict_ + if hash_state_dict_keys(state_dict_) == '3b2726384e4f64837bdf216eea3f310d': # vace 14B + config = { + "vace_layers": (0, 5, 10, 15, 20, 25, 30, 35), + "vace_in_dim": 96, + "patch_size": (1, 2, 2), + "has_image_input": False, + "dim": 5120, + "num_heads": 40, + "ffn_dim": 13824, + "eps": 1e-06, + } + else: + config = {} + return state_dict_, config diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index df23076..137fd28 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -774,18 +774,11 @@ class WanVideoVAE(nn.Module): def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): - hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] - videos = [] - for hidden_state in hidden_states: - hidden_state = hidden_state.unsqueeze(0) - if tiled: - video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) - else: - video = self.single_decode(hidden_state, device) - video = video.squeeze(0) - videos.append(video) - videos = torch.stack(videos) - return videos + if tiled: + video = self.tiled_decode(hidden_states, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_states, device) + return video @staticmethod diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 77835a4..b84b1b9 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -68,6 +68,7 @@ class WanVideoPipeline(BasePipeline): torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, @@ -237,6 +238,18 @@ class WanVideoPipeline(BasePipeline): return latents + def prepare_reference_image(self, reference_image, height, width): + if reference_image is not None: + self.load_models_to_device(["vae"]) + reference_image = reference_image.resize((width, height)) + reference_image = self.preprocess_images([reference_image]) + reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, 
device=self.device)
+            reference_latents = self.vae.encode(reference_image, device=self.device)
+            return {"reference_latents": reference_latents}
+        else:
+            return {}
+
+
     def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
         if control_video is not None:
             control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
@@ -339,6 +352,7 @@
         end_image=None,
         input_video=None,
         control_video=None,
+        reference_image=None,
         vace_video=None,
         vace_video_mask=None,
         vace_reference_image=None,
@@ -398,6 +412,9 @@
         else:
             image_emb = {}
 
+        # Reference image
+        reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
+
         # ControlNet
         if control_video is not None:
             self.load_models_to_device(["image_encoder", "vae"])
@@ -435,14 +452,14 @@
                 self.dit, motion_controller=self.motion_controller, vace=self.vace,
                 x=latents, timestep=timestep,
                 **prompt_emb_posi, **image_emb, **extra_input,
-                **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+                **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
             )
             if cfg_scale != 1.0:
                 noise_pred_nega = model_fn_wan_video(
                     self.dit, motion_controller=self.motion_controller, vace=self.vace,
                     x=latents, timestep=timestep,
                     **prompt_emb_nega, **image_emb, **extra_input,
-                    **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+                    **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
                 )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
@@ -526,6 +543,7 @@ def model_fn_wan_video(
     context: torch.Tensor = None,
     clip_feature: Optional[torch.Tensor] = None,
     y: Optional[torch.Tensor] = None,
+    reference_latents = None,
     vace_context = None,
     vace_scale = 1.0,
     tea_cache: TeaCache = None,
@@ -552,6 +570,12 @@ def model_fn_wan_video(
 
     x, (f, h, w) = dit.patchify(x)
 
+    # Reference image
+    if reference_latents is not None:
+        reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
+        x = torch.concat([reference_latents, x], dim=1)
+        f += 1
+
     freqs = torch.cat([
         dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
         dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
@@ -580,6 +604,10 @@ def model_fn_wan_video(
             x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
         if tea_cache is not None:
             tea_cache.store(x)
+
+    if reference_latents is not None:
+        x = x[:, reference_latents.shape[1]:]
+        f -= 1
 
     x = dit.head(x, t)
     if use_unified_sequence_parallel:
diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
new file mode 100644
index 0000000..44a64c9
--- /dev/null
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -0,0 +1,1166 @@
+import torch, warnings, glob, os, types
+import numpy as np
+from PIL import Image
+from einops import repeat, reduce, rearrange
+from tqdm import tqdm
+from typing import Optional, Union
+from typing_extensions import Literal
+from dataclasses import dataclass
+from modelscope import snapshot_download
+
+from ..models import ModelManager, load_state_dict
+from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
+from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
+from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vace import VaceWanModel
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+from ..schedulers.flow_match import FlowMatchScheduler
+from ..prompters import WanPrompter
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
+from ..lora import GeneralLoRALoader
+
+
+
+class BasePipeline(torch.nn.Module):
+
+    def __init__(
+        self,
+        device="cuda", torch_dtype=torch.float16,
+        height_division_factor=64, width_division_factor=64,
+        time_division_factor=None, time_division_remainder=None,
+    ):
+        super().__init__()
+        # The device and torch_dtype are used for the storage of intermediate variables, not models.
+        self.device = device
+        self.torch_dtype = torch_dtype
+        # The following parameters are used for shape checks.
+        self.height_division_factor = height_division_factor
+        self.width_division_factor = width_division_factor
+        self.time_division_factor = time_division_factor
+        self.time_division_remainder = time_division_remainder
+        self.vram_management_enabled = False
+
+
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        if device is not None:
+            self.device = device
+        if dtype is not None:
+            self.torch_dtype = dtype
+        super().to(*args, **kwargs)
+        return self
+
+
+    def check_resize_height_width(self, height, width, num_frames=None):
+        # Shape check
+        if height % self.height_division_factor != 0:
+            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
+            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
+        if width % self.width_division_factor != 0:
+            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
+            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
+        if num_frames is None:
+            return height, width
+        else:
+            if num_frames % self.time_division_factor != self.time_division_remainder:
+                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
+                print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. 
We round it up to {num_frames}.") + return height, width, num_frames + + + def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1): + # Transform a PIL.Image to torch.Tensor + image = torch.Tensor(np.array(image, dtype=np.float32)) + image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + image = image * ((max_value - min_value) / 255) + min_value + image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})) + return image + + + def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a list of PIL.Image to torch.Tensor + video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video] + video = torch.stack(video, dim=pattern.index("T") // 2) + return video + + + def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to PIL.Image + if pattern != "H W C": + vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean") + image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255) + image = image.to(device="cpu", dtype=torch.uint8) + image = Image.fromarray(image.numpy()) + return image + + + def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to list of PIL.Image + if pattern != "T H W C": + vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean") + video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] + return video + + + def load_models_to_device(self, model_names=[]): + if self.vram_management_enabled: + # offload models + for name, model in self.named_children(): + if name not in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + for module in model.modules(): + if hasattr(module, "offload"): + module.offload() + else: + model.cpu() + torch.cuda.empty_cache() + # onload models + for name, model in self.named_children(): + if name in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + for module in model.modules(): + if hasattr(module, "onload"): + module.onload() + else: + model.to(self.device) + + + def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None): + # Initialize Gaussian noise + generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype) + noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + return noise + + + def enable_cpu_offload(self): + warnings.warn("`enable_cpu_offload` will be deprecated. 
Please use `enable_vram_management`.") + self.vram_management_enabled = True + + + def get_free_vram(self): + total_memory = torch.cuda.get_device_properties(self.device).total_memory + allocated_memory = torch.cuda.device_memory_used(self.device) + return (total_memory - allocated_memory) / (1024 ** 3) + + + def freeze_except(self, model_names): + for name, model in self.named_children(): + if name in model_names: + model.train() + model.requires_grad_(True) + else: + model.eval() + model.requires_grad_(False) + + +@dataclass +class ModelConfig: + path: Union[str, list[str]] = None + model_id: str = None + origin_file_pattern: Union[str, list[str]] = None + download_resource: str = "ModelScope" + offload_device: Optional[Union[str, torch.device]] = None + offload_dtype: Optional[torch.dtype] = None + + def download_if_necessary(self, local_model_path="./models", skip_download=False): + if self.path is None: + if self.model_id is None or self.origin_file_pattern is None: + raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") + if not skip_download: + downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) + snapshot_download( + self.model_id, + local_dir=os.path.join(local_model_path, self.model_id), + allow_file_pattern=self.origin_file_pattern, + ignore_file_pattern=downloaded_files, + local_files_only=False + ) + self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) + if isinstance(self.path, list) and len(self.path) == 1: + self.path = self.path[0] + + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.prompter = WanPrompter(tokenizer_path=tokenizer_path) + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = None + self.vace: VaceWanModel = None + self.in_iteration_models = ("dit", "motion_controller", "vace") + self.unit_runner = PipelineUnitRunner() + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_ImageEmbedder(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger(), + ] + self.model_fn = model_fn_wan_video + + + def load_lora(self, module, path, alpha=1): + loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device) + lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) + loader.load(module, lora, alpha=alpha) + + + def training_loss(self, **inputs): + timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) + timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) + + inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) + training_target = 
self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep) + + noise_pred = self.model_fn(**inputs, timestep=timestep) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.scheduler.training_weight(timestep) + return loss + + + def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): + self.vram_management_enabled = True + if num_persistent_param_in_dit is not None: + vram_limit = None + else: + if vram_limit is None: + vram_limit = self.get_free_vram() + vram_limit = vram_limit - vram_buffer + if self.text_encoder is not None: + dtype = next(iter(self.text_encoder.parameters())).dtype + enable_vram_management( + self.text_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Embedding: AutoWrappedModule, + T5RelativeEmbedding: AutoWrappedModule, + T5LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) + if self.dit is not None: + dtype = next(iter(self.dit.parameters())).dtype + device = "cpu" if vram_limit is not None else self.device + enable_vram_management( + self.dit, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: WanAutoCastLayerNorm, + RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) + if self.vae is not None: + dtype = next(iter(self.vae.parameters())).dtype + enable_vram_management( + self.vae, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + RMS_norm: AutoWrappedModule, + CausalConv3d: AutoWrappedModule, + Upsample: AutoWrappedModule, + torch.nn.SiLU: AutoWrappedModule, + torch.nn.Dropout: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + if self.image_encoder is not None: + dtype = next(iter(self.image_encoder.parameters())).dtype + enable_vram_management( + self.image_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=dtype, + computation_device=self.device, + ), + ) + if self.motion_controller is not None: + dtype = next(iter(self.motion_controller.parameters())).dtype + enable_vram_management( + self.motion_controller, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=dtype, + computation_device=self.device, + ), + ) + if self.vace is not None: + device = "cpu" if vram_limit is not None else 
self.device
+            dtype = next(iter(self.vace.parameters())).dtype
+            enable_vram_management(
+                self.vace,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Conv3d: AutoWrappedModule,
+                    torch.nn.LayerNorm: AutoWrappedModule,
+                    RMSNorm: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device=device,
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+                vram_limit=vram_limit,
+            )
+
+
+    def initialize_usp(self):
+        import torch.distributed as dist
+        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
+        dist.init_process_group(backend="nccl", init_method="env://")
+        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
+        initialize_model_parallel(
+            sequence_parallel_degree=dist.get_world_size(),
+            ring_degree=1,
+            ulysses_degree=dist.get_world_size(),
+        )
+        torch.cuda.set_device(dist.get_rank())
+
+
+    def enable_usp(self):
+        from xfuser.core.distributed import get_sequence_parallel_world_size
+        from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+        for block in self.dit.blocks:
+            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+        self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
+        self.sp_size = get_sequence_parallel_world_size()
+        self.use_unified_sequence_parallel = True
+
+
+    @staticmethod
+    def from_pretrained(
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = "cuda",
+        model_configs: list[ModelConfig] = [],
+        tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
+        local_model_path: str = "./models",
+        skip_download: bool = False,
+        redirect_common_files: bool = True,
+        use_usp=False,
+    ):
+        # Redirect model path
+        if redirect_common_files:
+            redirect_dict = {
+                "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
+                "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
+                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
+            }
+            for model_config in model_configs:
+                if model_config.origin_file_pattern is None or model_config.model_id is None:
+                    continue
+                if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
+                    print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). 
You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern] + + # Download and load models + model_manager = ModelManager() + for model_config in model_configs: + model_config.download_if_necessary(local_model_path, skip_download=skip_download) + model_manager.load_model( + model_config.path, + device=model_config.offload_device or device, + torch_dtype=model_config.offload_dtype or torch_dtype + ) + + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + if use_usp: pipe.initialize_usp() + pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") + pipe.dit = model_manager.fetch_model("wan_video_dit") + pipe.vae = model_manager.fetch_model("wan_video_vae") + pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") + pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") + pipe.vace = model_manager.fetch_model("wan_video_vace") + + # Initialize tokenizer + tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download) + pipe.prompter.fetch_models(pipe.text_encoder) + pipe.prompter.fetch_tokenizer(tokenizer_config.path) + + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + input_video: Optional[list[Image.Image]] = None, + denoising_strength: Optional[float] = 1.0, + # ControlNet + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, + # Camera control + camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, + camera_control_speed: Optional[float] = 1/54, + camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + # VACE + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + vace_scale: Optional[float] = 1.0, + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 480, + width: Optional[int] = 832, + num_frames=81, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + cfg_merge: Optional[bool] = False, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # Speed control + motion_bucket_id: Optional[int] = None, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "tea_cache_l1_thresh": 
tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "input_video": input_video, "denoising_strength": denoising_strength, + "control_video": control_video, "reference_image": reference_image, + "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "sigma_shift": sigma_shift, + "motion_bucket_id": motion_bucket_id, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + + # VACE (TODO: remove it) + if vace_reference_image is not None: + inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] + + # Decode + self.load_models_to_device(['vae']) + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device([]) + + return video + + + +class PipelineUnit: + def __init__( + self, + seperate_cfg: bool = False, + take_over: bool = False, + input_params: tuple[str] = None, + input_params_posi: dict[str, str] = None, + input_params_nega: dict[str, str] = None, + onload_model_names: tuple[str] = None + ): + self.seperate_cfg = seperate_cfg + self.take_over = take_over + self.input_params = input_params + self.input_params_posi = input_params_posi + self.input_params_nega = input_params_nega + self.onload_model_names = onload_model_names + + + def process(self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs) -> dict: + raise NotImplementedError("`process` is not implemented.") + + + +class PipelineUnitRunner: + def __init__(self): + pass + + def __call__(self, unit: PipelineUnit, pipe: WanVideoPipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: + if unit.take_over: + # Let the pipeline unit take over this function. 
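+            # The runner has three dispatch modes: take_over units transform all
+            # three input dicts at once, seperate_cfg units run once per CFG side,
+            # and plain units read from and write back to the shared dict.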
+ inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) + elif unit.seperate_cfg: + # Positive side + processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_posi.update(processor_outputs) + # Negative side + if inputs_shared["cfg_scale"] != 1: + processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_nega.update(processor_outputs) + else: + inputs_nega.update(processor_outputs) + else: + processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_shared.update(processor_outputs) + return inputs_shared, inputs_posi, inputs_nega + + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "num_frames")) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image")) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + length += 1 + noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + if input_video is None: + return {"latents": noise} + pipe.load_models_to_device(["vae"]) + input_video = pipe.preprocess_video(input_video) + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if vace_reference_image is not None: + vace_reference_image = pipe.preprocess_video([vace_reference_image]) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + onload_model_names=("text_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + 
pipe.load_models_to_device(self.onload_model_names)
+        prompt_emb = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device)
+        return {"context": prompt_emb}
+
+
+
+class WanVideoUnit_ImageEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
+            onload_model_names=("image_encoder", "vae")
+        )
+
+    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
+        if input_image is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+        clip_context = pipe.image_encoder.encode_image([image])
+        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+        msk[:, 1:] = 0
+        if end_image is not None:
+            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
+            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+            if pipe.dit.has_image_pos_emb:
+                clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
+            msk[:, -1:] = 1
+        else:
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+        msk = msk.transpose(1, 2)[0]
+
+        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+        y = torch.concat([msk, y])
+        y = y.unsqueeze(0)
+        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
+        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"clip_feature": clip_context, "y": y}
+
+
+
+class WanVideoUnit_FunControl(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
+        if control_video is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        control_video = pipe.preprocess_video(control_video)
+        control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+        control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
+        if clip_feature is None or y is None:
+            clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
+            y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
+        else:
+            y = y[:, -16:]
+        y = torch.concat([control_latents, y], dim=1)
+        return {"clip_feature": clip_feature, "y": y}
+
+
+
+class WanVideoUnit_FunReference(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("reference_image", "height", "width"),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: WanVideoPipeline, reference_image, height, width):
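+        """Encode the reference image into VAE latents and a CLIP image feature for reference-controlled generation."""
+        if reference_image is None: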
+            return {}
+        pipe.load_models_to_device(["vae"])
+        reference_image = reference_image.resize((width, height))
+        reference_latents = pipe.preprocess_video([reference_image])
+        reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
+        clip_feature = pipe.preprocess_image(reference_image)
+        clip_feature = pipe.image_encoder.encode_image([clip_feature])
+        return {"reference_latents": reference_latents, "clip_feature": clip_feature}
+
+
+
+class WanVideoUnit_FunCameraControl(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image")
+        )
+
+    def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
+        if camera_control_direction is None:
+            return {}
+        camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
+            camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
+
+        control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
+        control_camera_latents = torch.concat(
+            [
+                torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
+                control_camera_video[:, :, 1:]
+            ], dim=2
+        ).transpose(1, 2)
+        b, f, c, h, w = control_camera_latents.shape
+        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
+        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
+        control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
+
+        input_image = input_image.resize((width, height))
+        input_latents = pipe.preprocess_video([input_image])
+        input_latents = pipe.vae.encode(input_latents, device=pipe.device)
+        y = torch.zeros_like(latents).to(pipe.device)
+        y[:, :, :1] = input_latents
+        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"control_camera_latents_input": control_camera_latents_input, "y": y}
+
+
+
+class WanVideoUnit_SpeedControl(PipelineUnit):
+    def __init__(self):
+        super().__init__(input_params=("motion_bucket_id",))
+
+    def process(self, pipe: WanVideoPipeline, motion_bucket_id):
+        if motion_bucket_id is None:
+            return {}
+        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"motion_bucket_id": motion_bucket_id}
+
+
+
+class WanVideoUnit_VACE(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
+            onload_model_names=("vae",)
+        )
+
+    def process(
+        self,
+        pipe: WanVideoPipeline,
+        vace_video, vace_video_mask, vace_reference_image, vace_scale,
+        height, width, num_frames,
+        tiled, tile_size, tile_stride
+    ):
+        vace_mask = vace_video_mask  # the shared inputs use the key "vace_video_mask"
+        if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
+            pipe.load_models_to_device(["vae"])
+            if vace_video is None:
+                vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
+            else:
+                vace_video = pipe.preprocess_video(vace_video)
+
+            if vace_mask is None:
+                vace_mask = torch.ones_like(vace_video)
+            else:
+                vace_mask = pipe.preprocess_video(vace_mask)
+
+            inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
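+            # VACE conditioning: "inactive" keeps the pixels outside the mask and
+            # "reactive" the pixels inside it; both are VAE-encoded and concatenated.
+            reactive = vace_video * vace_mask + 0 * (1 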
- vace_mask) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + vace_reference_image = pipe.preprocess_video([vace_reference_image]) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return {"vace_context": None, "vace_scale": vace_scale} + + + +class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=()) + + def process(self, pipe: WanVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return inputs_shared, inputs_posi, inputs_nega + + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = 
num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, 
T)
+            model_kwargs.update({
+                tensor_name: tensor_dict[tensor_name][:, :, t: t_].to(device=computation_device, dtype=computation_dtype)
+                for tensor_name in tensor_names
+            })
+            model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype)
+            mask = self.build_mask(
+                model_output,
+                is_bound=(t == 0, t_ == T),
+                border_width=(sliding_window_size - sliding_window_stride,)
+            ).to(device=data_device, dtype=data_dtype)
+            value[:, :, t: t_, :, :] += model_output * mask
+            weight[:, :, t: t_, :, :] += mask
+        value /= weight
+        model_kwargs.update(tensor_dict)
+        return value
+
+
+
+def model_fn_wan_video(
+    dit: WanModel,
+    motion_controller: WanMotionControllerModel = None,
+    vace: VaceWanModel = None,
+    latents: torch.Tensor = None,
+    timestep: torch.Tensor = None,
+    context: torch.Tensor = None,
+    clip_feature: Optional[torch.Tensor] = None,
+    y: Optional[torch.Tensor] = None,
+    reference_latents = None,
+    vace_context = None,
+    vace_scale = 1.0,
+    tea_cache: TeaCache = None,
+    use_unified_sequence_parallel: bool = False,
+    motion_bucket_id: Optional[torch.Tensor] = None,
+    sliding_window_size: Optional[int] = None,
+    sliding_window_stride: Optional[int] = None,
+    cfg_merge: bool = False,
+    use_gradient_checkpointing: bool = False,
+    use_gradient_checkpointing_offload: bool = False,
+    control_camera_latents_input = None,
+    **kwargs,
+):
+    if sliding_window_size is not None and sliding_window_stride is not None:
+        model_kwargs = dict(
+            dit=dit,
+            motion_controller=motion_controller,
+            vace=vace,
+            latents=latents,
+            timestep=timestep,
+            context=context,
+            clip_feature=clip_feature,
+            y=y,
+            reference_latents=reference_latents,
+            vace_context=vace_context,
+            vace_scale=vace_scale,
+            tea_cache=tea_cache,
+            use_unified_sequence_parallel=use_unified_sequence_parallel,
+            motion_bucket_id=motion_bucket_id,
+        )
+        return TemporalTiler_BCTHW().run(
+            model_fn_wan_video,
+            sliding_window_size, sliding_window_stride,
+            latents.device, latents.dtype,
+            model_kwargs=model_kwargs,
+            tensor_names=["latents", "y"],
+            batch_size=2 if cfg_merge else 1
+        )
+
+    if use_unified_sequence_parallel:
+        import torch.distributed as dist
+        from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                             get_sequence_parallel_world_size,
+                                             get_sp_group)
+
+    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+    if motion_bucket_id is not None and motion_controller is not None:
+        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
+    context = dit.text_embedding(context)
+
+    x = latents
+    # Merged cfg
+    if x.shape[0] != context.shape[0]:
+        x = torch.concat([x] * context.shape[0], dim=0)
+    if timestep.shape[0] != context.shape[0]:
+        timestep = torch.concat([timestep] * context.shape[0], dim=0)
+
+    if dit.has_image_input:
+        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
+        clip_embedding = dit.img_emb(clip_feature)
+        context = torch.cat([clip_embedding, context], dim=1)
+
+    # Add camera control
+    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
+
+    # Reference image
+    if reference_latents is not None:
+        if len(reference_latents.shape) == 5:
+            reference_latents = reference_latents[:, :, 0]
+        reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
+        x = torch.concat([reference_latents, x], dim=1)
+        f += 1
+
+    freqs = torch.cat([
+        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+        
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + if vace_context is not None: + vace_hints = vace(x, vace_context, context, t_mod, freqs) + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + if vace_context is not None and block_id in vace.vace_layers_mapping: + x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale + if tea_cache is not None: + tea_cache.store(x) + + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + + x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = dit.unpatchify(x, (f, h, w)) + return x diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index d6d0219..9754b98 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -35,6 +35,9 @@ class FlowMatchScheduler(): y_shifted = y - y.min() bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) self.linear_timesteps_weights = bsmntw_weighing + self.training = True + else: + self.training = False def step(self, model_output, timestep, sample, to_final=False, **kwargs): diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py new file mode 100644 index 0000000..8d37277 --- /dev/null +++ b/diffsynth/trainers/utils.py @@ -0,0 +1,251 @@ +import imageio, os, torch, warnings, torchvision, argparse +from peft import LoraConfig, inject_adapter_in_model +from PIL import Image +import pandas as pd +from tqdm import tqdm +from accelerate import Accelerator + + + +class VideoDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path=None, metadata_path=None, + frame_interval=1, num_frames=81, + dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + data_file_keys=("video",), + image_file_extension=("jpg", "jpeg", "png", "webp"), + video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), + repeat=1, + args=None, + ): + if args is not None: + base_path = args.dataset_base_path + metadata_path = args.dataset_metadata_path + height = args.height + width = args.width + num_frames = args.num_frames + data_file_keys = args.data_file_keys.split(",") + repeat = args.dataset_repeat + + metadata = pd.read_csv(metadata_path) + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + + self.base_path = base_path + self.frame_interval = frame_interval + self.num_frames = num_frames + self.dynamic_resolution = 
dynamic_resolution
+        self.max_pixels = max_pixels
+        self.height = height
+        self.width = width
+        self.height_division_factor = height_division_factor
+        self.width_division_factor = width_division_factor
+        self.data_file_keys = data_file_keys
+        self.image_file_extension = image_file_extension
+        self.video_file_extension = video_file_extension
+        self.repeat = repeat
+
+        if height is not None and width is not None and dynamic_resolution:
+            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
+            self.dynamic_resolution = False
+
+
+    def crop_and_resize(self, image, target_height, target_width):
+        width, height = image.size
+        scale = max(target_width / width, target_height / height)
+        image = torchvision.transforms.functional.resize(
+            image,
+            (round(height*scale), round(width*scale)),
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
+        )
+        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
+        return image
+
+
+    def get_height_width(self, image):
+        if self.dynamic_resolution:
+            width, height = image.size
+            if width * height > self.max_pixels:
+                scale = (width * height / self.max_pixels) ** 0.5
+                height, width = int(height / scale), int(width / scale)
+            height = height // self.height_division_factor * self.height_division_factor
+            width = width // self.width_division_factor * self.width_division_factor
+        else:
+            height, width = self.height, self.width
+        return height, width
+
+
+    def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
+        reader = imageio.get_reader(file_path)
+        if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
+            reader.close()
+            return None
+        frames = []
+        for frame_id in range(num_frames):
+            frame = reader.get_data(start_frame_id + frame_id * interval)
+            frame = Image.fromarray(frame)
+            frame = self.crop_and_resize(frame, *self.get_height_width(frame))
+            frames.append(frame)
+        reader.close()
+        return frames
+
+
+    def load_image(self, file_path):
+        image = Image.open(file_path).convert("RGB")
+        image = self.crop_and_resize(image, *self.get_height_width(image))
+        return image
+
+
+    def load_video(self, file_path):
+        frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
+        return frames
+
+
+    def is_image(self, file_path):
+        file_ext_name = file_path.split(".")[-1]
+        return file_ext_name.lower() in self.image_file_extension
+
+
+    def is_video(self, file_path):
+        file_ext_name = file_path.split(".")[-1]
+        return file_ext_name.lower() in self.video_file_extension
+
+
+    def load_data(self, file_path):
+        if self.is_image(file_path):
+            return self.load_image(file_path)
+        elif self.is_video(file_path):
+            return self.load_video(file_path)
+        else:
+            return None
+
+
+    def __getitem__(self, data_id):
+        data = self.data[data_id % len(self.data)].copy()
+        for key in self.data_file_keys:
+            if key in data:
+                path = os.path.join(self.base_path, data[key])
+                data[key] = self.load_data(path)
+                if data[key] is None:
+                    warnings.warn(f"cannot load file {path}.")
+                    return None
+        return data
+
+
+    def __len__(self):
+        return len(self.data) * self.repeat
+
+
+
+class DiffusionTrainingModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+
+    def to(self, *args, **kwargs):
+        for name, model in self.named_children():
+            model.to(*args, **kwargs)
+        return self
+
+
+    def trainable_modules(self):
+        trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
+        return trainable_modules
+
+
+    def 
trainable_param_names(self): + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + return trainable_param_names + + + def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None): + if lora_alpha is None: + lora_alpha = lora_rank + lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) + model = inject_adapter_in_model(lora_config, model) + return model + + + def export_trainable_state_dict(self, state_dict, remove_prefix=None): + trainable_param_names = self.trainable_param_names() + state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names} + if remove_prefix is not None: + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith(remove_prefix): + name = name[len(remove_prefix):] + state_dict_[name] = param + state_dict = state_dict_ + return state_dict + + + +def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=1e-4, num_epochs=1, output_path="./models", remove_prefix_in_ckpt=None, args=None): + if args is not None: + learning_rate = args.learning_rate + num_epochs = args.num_epochs + output_path = args.output_path + remove_prefix_in_ckpt = args.remove_prefix_in_ckpt + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0]) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) + + accelerator = Accelerator(gradient_accumulation_steps=1) + model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + + for epoch in range(num_epochs): + for data in tqdm(dataloader): + with accelerator.accumulate(model): + optimizer.zero_grad() + loss = model(data) + accelerator.backward(loss) + optimizer.step() + scheduler.step() + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt) + os.makedirs(output_path, exist_ok=True) + path = os.path.join(output_path, f"epoch-{epoch}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) + + + +def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"): + dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0]) + accelerator = Accelerator() + model, dataloader = accelerator.prepare(model, dataloader) + os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True) + for data_id, data in enumerate(tqdm(dataloader)): + with torch.no_grad(): + inputs = model.forward_preprocess(data) + inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs} + torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth")) + + + +def wan_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.") + parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Path to the metadata file of the dataset.") + parser.add_argument("--height", type=int, default=None, help="Height of images or videos. 
Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.") + parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.") + parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") + parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.") + parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") + parser.add_argument("--output_path", type=str, default="./models", help="Output save path.") + parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.") + parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.") + parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.") + parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") + parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") + parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") + parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") + return parser + diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index a9df39e..dd4a245 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -8,8 +8,32 @@ def cast_to(weight, dtype, device): return r -class AutoWrappedModule(torch.nn.Module): - def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): +class AutoTorchModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def check_free_vram(self): + used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3) + return used_memory < self.vram_limit + + def offload(self): + if self.state != 0: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state != 1: + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def keep(self): + if self.state != 2: + self.to(dtype=self.computation_dtype, device=self.computation_device) + self.state = 2 + + +class AutoWrappedModule(AutoTorchModule): + def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs): super().__init__() self.module = module.to(dtype=offload_dtype, device=offload_device) self.offload_dtype = offload_dtype @@ -18,28 +42,57 @@ class AutoWrappedModule(torch.nn.Module): self.onload_device = onload_device self.computation_dtype = 
computation_dtype self.computation_device = computation_device + self.vram_limit = vram_limit self.state = 0 - def offload(self): - if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.module.to(dtype=self.offload_dtype, device=self.offload_device) - self.state = 0 - - def onload(self): - if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.module.to(dtype=self.onload_dtype, device=self.onload_device) - self.state = 1 - def forward(self, *args, **kwargs): - if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + if self.state == 2: module = self.module else: - module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + module = self.module + elif self.vram_limit is not None and self.check_free_vram(): + self.keep() + module = self.module + else: + module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) return module(*args, **kwargs) -class AutoWrappedLinear(torch.nn.Linear): - def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): +class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule): + def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs): + with init_weights_on_device(device=torch.device("meta")): + super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) + self.weight = module.weight + self.bias = module.bias + self.offload_dtype = offload_dtype + self.offload_device = offload_device + self.onload_dtype = onload_dtype + self.onload_device = onload_device + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.vram_limit = vram_limit + self.state = 0 + + def forward(self, x, *args, **kwargs): + if self.state == 2: + weight, bias = self.weight, self.bias + else: + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + weight, bias = self.weight, self.bias + elif self.vram_limit is not None and self.check_free_vram(): + self.keep() + weight, bias = self.weight, self.bias + else: + weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) + with torch.amp.autocast(device_type=x.device.type): + x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x) + return x + + +class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): + def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs): with init_weights_on_device(device=torch.device("meta")): super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) self.weight = module.weight @@ -50,29 +103,28 @@ class AutoWrappedLinear(torch.nn.Linear): self.onload_device 
= onload_device self.computation_dtype = computation_dtype self.computation_device = computation_device + self.vram_limit = vram_limit self.state = 0 - - def offload(self): - if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.to(dtype=self.offload_dtype, device=self.offload_device) - self.state = 0 - - def onload(self): - if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.to(dtype=self.onload_dtype, device=self.onload_device) - self.state = 1 + self.name = name def forward(self, x, *args, **kwargs): - if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + if self.state == 2: weight, bias = self.weight, self.bias else: - weight = cast_to(self.weight, self.computation_dtype, self.computation_device) - bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + weight, bias = self.weight, self.bias + elif self.vram_limit is not None and self.check_free_vram(): + self.keep() + weight, bias = self.weight, self.bias + else: + weight = cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) return torch.nn.functional.linear(x, weight, bias) -def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): +def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""): for name, module in model.named_children(): + layer_name = name if name_prefix == "" else name_prefix + "." 
+ name for source_module, target_module in module_map.items(): if isinstance(module, source_module): num_param = sum(p.numel() for p in module.parameters()) @@ -80,16 +132,16 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config_ = overflow_module_config else: module_config_ = module_config - module_ = target_module(module, **module_config_) + module_ = target_module(module, **module_config_, vram_limit=vram_limit, name=layer_name) setattr(model, name, module_) total_num_param += num_param break else: - total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param) + total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=layer_name) return total_num_param -def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None): - enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0) +def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None): + enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit) model.vram_management_enabled = True diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 92c3c59..cece7d3 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -1,6 +1,8 @@ -# Wan-Video +# Wan 2.1 -Wan-Video is a collection of video synthesis models open-sourced by Alibaba. +[切换到中文](./README_zh.md) + +Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba. Before using this model, please install DiffSynth-Studio from **source code**. @@ -10,267 +12,378 @@ cd DiffSynth-Studio pip install -e . 
``` -## Model Zoo +| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-| +|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)| +|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)| +|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| +|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-720P.py)|[code](./model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| +|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| +|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| 
+|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| +|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, 
`vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| +|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| +|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| +|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| -|Developer|Name|Link|Scripts| -|-|-|-|-| -|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)| -|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)| -|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)| -|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)| -|Wan Team|14B first-last-frame-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|[wan_14B_flf2v.py](./wan_14B_flf2v.py)| -|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).| -|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).| -|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).| -|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)| -|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)| -|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)| -|PAI Team|1.3B 
Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)| -|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)| -|IIC Team|1.3B VACE|[Link](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|[wan_1.3b_vace.py](./wan_1.3b_vace.py)| -Base model features +## Model Inference -||Text-to-video|Image-to-video|End frame|Control|Reference image| -|-|-|-|-|-|-| -|1.3B text-to-video|✅||||| -|14B text-to-video|✅||||| -|14B image-to-video 480P||✅|||| -|14B image-to-video 720P||✅|||| -|14B first-last-frame-to-video 720P||✅|✅||| -|1.3B InP||✅|✅||| -|14B InP||✅|✅||| -|1.3B Control||||✅|| -|14B Control||||✅|| -|1.3B VACE||||✅|✅| +The following sections will help you understand our functionalities and write inference code. -Adapter model compatibility +
-||1.3B text-to-video|1.3B InP|1.3B VACE|
-|-|-|-|
-|1.3B aesthetics LoRA|✅||✅|
-|1.3B Highres-fix LoRA|✅||✅|
-|1.3B ExVideo LoRA|✅||✅|
-|1.3B Speed Control adapter|✅|✅|✅|
-
-## VRAM Usage
+### Loading the Model
+
+The model is loaded using `from_pretrained`:
-
-* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
+    ],
+)
+```
-
-* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).
+
+Here, `torch_dtype` and `device` specify the computation precision and device respectively. The `model_configs` can be used to configure model paths in various ways:
-
-We present a detailed table here. The model (14B text-to-video) is tested on a single A100.
+
+* Downloading the model from [ModelScope](https://modelscope.cn/) and loading it. In this case, both `model_id` and `origin_file_pattern` need to be specified, for example:
-
-|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
-|-|-|-|-|-|
-|torch.bfloat16|None (unlimited)|18.5s/it|48G||
-|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G||
-|torch.bfloat16|0|23.4s/it|10G||
-|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
-|torch.float8_e4m3fn|0|24.0s/it|10G||
+
+```python
+ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+```
-
-**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
+
+* Loading the model from a local file path. In this case, the `path` parameter needs to be specified, for example:
-
-## Efficient Attention Implementation
+
+```python
+ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
+```
-
-DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA.
+
+For models that are loaded from multiple files, simply use a list, for example:
-
-* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
-* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
-* [Sage Attention](https://github.com/thu-ml/SageAttention)
-* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
+```python
+ModelConfig(path=[
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
+])
+```
-
-## Acceleration
+
+The `from_pretrained` function also provides additional parameters to control the behavior during model loading:
-
-We support multiple acceleration solutions:
-* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
+
+* `tokenizer_config`: Path to the tokenizer of the Wan model. Default value is `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`.
+* `local_model_path`: Path where downloaded models are saved. Default value is `"./models"`.
+* `skip_download`: Whether to skip downloading models. Default value is `False`. When your network cannot access [ModelScope](https://modelscope.cn/), manually download the necessary files and set this to `True`.
+* `redirect_common_files`: Whether to redirect duplicate model files. Default value is `True`. Since the Wan series models include multiple base models, some modules like the text encoder are shared across these models; to avoid redundant downloads, we redirect the model paths.
+* `use_usp`: Whether to enable Unified Sequence Parallel. Default value is `False`. Used for multi-GPU parallel inference.
-* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
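+
+For instance, a sketch that combines these loading options (the cache directory shown below is just the default, and the flag values are illustrative):
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
+    ],
+    local_model_path="./models",  # where downloaded files are cached
+    skip_download=False,          # set to True if the files above are already on disk
+)
+```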
-```bash +
+
+### VRAM Management
+
+DiffSynth-Studio provides fine-grained VRAM management for the Wan model, allowing it to run on devices with limited VRAM. You can enable offloading functionality via the following code, which moves parts of the model to system memory on devices with limited VRAM:
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+pipe.enable_vram_management()
+```
+
+FP8 quantization is also supported:
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_dtype=torch.float8_e4m3fn),
+    ],
+)
+pipe.enable_vram_management()
+```
+
+Both FP8 quantization and offloading can be enabled simultaneously:
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+    ],
+)
+pipe.enable_vram_management()
+```
+
+FP8 quantization significantly reduces VRAM usage but does not accelerate computation. Some models may produce blurry, torn, or distorted outputs under FP8 due to insufficient precision, so use FP8 quantization with caution.
+
+The `enable_vram_management` function provides the following parameters to control VRAM usage:
+
+* `vram_limit`: VRAM usage limit (in GB). By default, it uses all available VRAM on the device. Note that this is not an absolute limit; if the specified VRAM is insufficient but more VRAM is actually available, inference will proceed using the minimum required VRAM.
+* `vram_buffer`: Size of the VRAM buffer (in GB). Default is 0.5 GB. Since certain large neural network layers may consume more VRAM than expected during execution, a buffer is necessary; ideally, it should match the maximum VRAM consumed by any single layer in the model.
+* `num_persistent_param_in_dit`: Number of persistent parameters in the DiT model. By default, there is no limit. We plan to remove this parameter in the future, so please avoid relying on it.
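+
+As a minimal sketch (the 6 GB budget below is arbitrary), the limits can be passed directly to `enable_vram_management`:
+
+```python
+pipe.enable_vram_management(
+    vram_limit=6,     # soft VRAM budget in GB; exceeded only when a layer cannot fit otherwise
+    vram_buffer=0.5,  # headroom in GB for the activations of a single large layer
+)
+```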
+ +
+
+### Inference Acceleration
+
+Wan supports multiple acceleration techniques, including:
+
+* **Efficient attention implementations**: If any of these attention implementations are installed in your Python environment, they will be automatically enabled in the following priority:
+    * [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
+    * [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
+    * [Sage Attention](https://github.com/thu-ml/SageAttention)
+    * [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default setting; we recommend installing `torch>=2.5.0`)
+* **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command:
+
 ```shell
 pip install xfuser>=0.4.3
-torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
+torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
 ```
 
-* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
+* **TeaCache**: Acceleration technique [TeaCache](https://github.com/ali-vilab/TeaCache). Please refer to [this example](./acceleration/teacache.py).
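+
+As a rough sketch (the threshold value is illustrative), TeaCache is switched on per call via the TeaCache parameters documented in the next section:
+
+```python
+video = pipe(
+    prompt="a cat runs on the grass",
+    seed=0,
+    tea_cache_l1_thresh=0.05,              # larger values are faster but lower quality
+    tea_cache_model_id="Wan2.1-T2V-1.3B",  # must match the loaded base model
+)
+```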
-1.3B text-to-video. -https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8 +
-Put sunglasses on the dog. +Input Parameters -https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb +The pipeline accepts the following input parameters during inference: -14B text-to-video. +* `prompt`: Prompt describing the content to appear in the video. +* `negative_prompt`: Negative prompt describing content that should not appear in the video. Default is `""`. +* `input_image`: Input image, applicable for image-to-video models such as [`Wan-AI/Wan2.1-I2V-14B-480P`](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) and [`PAI/Wan2.1-Fun-1.3B-InP`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP), as well as first-and-last-frame models like [`Wan-AI/Wan2.1-FLF2V-14B-720P`](Wan-AI/Wan2.1-FLF2V-14B-720P). +* `end_image`: End frame, applicable for first-and-last-frame models such as [`Wan-AI/Wan2.1-FLF2V-14B-720P`](Wan-AI/Wan2.1-FLF2V-14B-720P). +* `input_video`: Input video used for video-to-video generation. Applicable to any Wan series model and must be used together with `denoising_strength`. +* `denoising_strength`: Denoising strength in range [0, 1]. A smaller value results in a video closer to `input_video`. +* `control_video`: Control video, applicable to Wan models with control capabilities such as [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control). +* `reference_image`: Reference image, applicable to Wan models supporting reference images such as [`PAI/Wan2.1-Fun-V1.1-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control). +* `camera_control_direction`: Camera control direction, optional values are "Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown". Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera). +* `camera_control_speed`: Camera control speed. Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera). +* `camera_control_origin`: Origin coordinate of the camera control sequence. Please refer to the [original paper](https://arxiv.org/pdf/2404.02101) for proper configuration. Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera). +* `vace_video`: Input video for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview). +* `vace_video_mask`: Mask video for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview). +* `vace_reference_image`: Reference image for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview). +* `vace_scale`: Influence of the VACE model on the base model, default is 1. Higher values increase control strength but may lead to visual artifacts or breakdowns. +* `seed`: Random seed. Default is `None`, meaning fully random. +* `rand_device`: Device used to generate random Gaussian noise matrix. Default is `"cpu"`. When set to `"cuda"`, different GPUs may produce different generation results. +* `height`: Frame height, default is 480. Must be a multiple of 16; if not, it will be rounded up. +* `width`: Frame width, default is 832. 
Must be a multiple of 16; if not, it will be rounded up. +* `num_frames`: Number of frames, default is 81. Must be a multiple of 4 plus 1; if not, it will be rounded up, minimum is 1. +* `cfg_scale`: Classifier-free guidance scale, default is 5. Higher values increase adherence to the prompt but may cause visual artifacts. +* `cfg_merge`: Whether to merge both sides of classifier-free guidance for unified inference. Default is `False`. This parameter currently only works for basic text-to-video and image-to-video models. +* `num_inference_steps`: Number of inference steps, default is 50. +* `sigma_shift`: Parameter from Rectified Flow theory, default is 5. Higher values make the model stay longer at the initial denoising stage. Increasing this may improve video quality but may also cause inconsistency between generated videos and training data due to deviation from training behavior. +* `motion_bucket_id`: Motion intensity, range [0, 100], applicable to motion control modules such as [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1). Larger values indicate more intense motion. +* `tiled`: Whether to enable tiled VAE inference, default is `False`. Setting to `True` significantly reduces VRAM usage during VAE encoding/decoding but introduces small errors and slightly increases inference time. +* `tile_size`: Tile size during VAE encoding/decoding, default is (30, 52), only effective when `tiled=True`. +* `tile_stride`: Stride of tiles during VAE encoding/decoding, default is (15, 26), only effective when `tiled=True`. Must be less than or equal to `tile_size`. +* `sliding_window_size`: Sliding window size for DiT part. Experimental feature, effects are unstable. +* `sliding_window_stride`: Sliding window stride for DiT part. Experimental feature, effects are unstable. +* `tea_cache_l1_thresh`: Threshold for TeaCache. Larger values result in faster speed but lower quality. Note that after enabling TeaCache, the inference speed is not uniform, so the remaining time shown on the progress bar becomes inaccurate. +* `tea_cache_model_id`: TeaCache parameter template, options include `"Wan2.1-T2V-1.3B"`, `"Wan2.1-T2V-14B"`, `"Wan2.1-I2V-14B-480P"`, `"Wan2.1-I2V-14B-720P"`. +* `progress_bar_cmd`: Progress bar implementation, default is `tqdm.tqdm`. You can set it to `lambda x:x` to disable the progress bar. -https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f +
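+
+Putting a few of these parameters together, a representative text-to-video call might look like this (the prompt and seed are illustrative):
+
+```python
+video = pipe(
+    prompt="a small town at dusk, warm light, a river in the foreground",
+    negative_prompt="static, blurry, low quality",
+    seed=0,
+    height=480, width=832,  # rounded up to multiples of 16 if needed
+    num_frames=81,          # must be a multiple of 4 plus 1
+    cfg_scale=5,
+    num_inference_steps=50,
+    tiled=True,             # lowers VRAM usage during VAE encoding/decoding
+)
+```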
-14B image-to-video. +## Model Training -https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75 +Wan series models are trained using a unified script located at [`./model_training/train.py`](./model_training/train.py). -14B first-last-frame-to-video +
-
-|First frame|Last frame|Video|
-|-|-|-|
-|![Image](https://github.com/user-attachments/assets/b0d8225b-aee0-4129-b8e5-58c8523221a6)|![Image](https://github.com/user-attachments/assets/2f0c9bc5-07e2-45fa-8320-53d63a4fd203)|https://github.com/user-attachments/assets/2a6a2681-622c-4512-b852-5f22e73830b1|
-
-## Train
+### Script Parameters
-
-We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA:
+
+The script includes the following parameters:
-
-https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9
+
+* Dataset
+    * `--dataset_base_path`: Base path of the dataset.
+    * `--dataset_metadata_path`: Path to the metadata file of the dataset.
+    * `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
+    * `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
+    * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
+    * `--data_file_keys`: Data file keys in the metadata. Comma-separated.
+    * `--dataset_repeat`: Number of times to repeat the dataset per epoch.
+* Models
+    * `--model_paths`: Paths to load models. In JSON format.
+    * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., `Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors`. Comma-separated.
+* Training
+    * `--learning_rate`: Learning rate.
+    * `--num_epochs`: Number of epochs.
+    * `--output_path`: Output save path.
+    * `--remove_prefix_in_ckpt`: Prefix to remove from parameter names in the saved checkpoint.
+* Trainable Modules
+    * `--trainable_models`: Models to train, e.g., `dit`, `vae`, `text_encoder`.
+    * `--lora_base_model`: Which model LoRA is added to.
+    * `--lora_target_modules`: Which layers LoRA is added to.
+    * `--lora_rank`: Rank of LoRA.
+* Extra Inputs
+    * `--extra_inputs`: Additional model inputs, comma-separated.
+* VRAM Management
+    * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
-
-Step 1: Install additional packages
+
+Additionally, the training framework is built upon [`accelerate`](https://huggingface.co/docs/accelerate/index). Before starting training, run `accelerate config` to configure GPU-related parameters. For certain training scripts (e.g., full fine-tuning of 14B models), we provide recommended `accelerate` configuration files, which can be found in the corresponding training scripts.
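+
+For reference, a minimal LoRA launch might look like the following. This is an illustrative sketch: the paths refer to the example dataset prepared in the next step, and every flag is one of the script parameters listed above; adjust them to your setup.
+
+```shell
+accelerate launch examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/example_video_dataset \
+  --dataset_metadata_path data/example_video_dataset/metadata.csv \
+  --height 480 --width 832 --num_frames 81 \
+  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
+  --learning_rate 1e-4 \
+  --num_epochs 1 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path ./models/train/Wan2.1-T2V-1.3B_lora \
+  --lora_base_model "dit" \
+  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+  --lora_rank 32
+```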
+ + +
+ +Step 1: Prepare the Dataset + +The dataset consists of a series of files. We recommend organizing your dataset as follows: ``` -pip install peft lightning pandas -``` - -Step 2: Prepare your dataset - -You need to manage the training videos as follows: - -``` -data/example_dataset/ +data/example_video_dataset/ ├── metadata.csv -└── train - ├── video_00001.mp4 - └── image_00002.jpg +├── video1.mp4 +└── video2.mp4 ``` -`metadata.csv`: +Here, `video1.mp4` and `video2.mp4` are training video files, and `metadata.csv` is the metadata list, for example: ``` -file_name,text -video_00001.mp4,"video description" -image_00002.jpg,"video description" +video,prompt +video1.mp4,"from sunset to night, a small town, light, house, river" +video2.mp4,"a dog is running" ``` -We support both images and videos. An image is treated as a single frame of video. - -Step 3: Data process +We have prepared a sample video dataset for testing. You can download it using the following command: ```shell -CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ - --task data_process \ - --dataset_path data/example_dataset \ - --output_path ./models \ - --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" \ - --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" \ - --tiled \ - --num_frames 81 \ - --height 480 \ - --width 832 +modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset ``` -After that, some cached files will be stored in the dataset folder. +The dataset supports mixed training of videos and images. Supported video formats include `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`, and supported image formats include `"jpg", "jpeg", "png", "webp"`. + +The size of each training sample can be controlled via the script parameters `--height`, `--width`, and `--num_frames`. For each video, the first `num_frames` frames will be used for training; therefore, an error will occur if the video length is less than `num_frames`. Image files will be treated as single-frame videos. When both `--height` and `--width` are left empty, dynamic resolution will be enabled, meaning training will use the actual resolution of each video or image in the dataset. + +**We strongly recommend using fixed-resolution training and avoiding mixing images and videos in the same dataset due to load balancing issues in multi-GPU training.** + +When the model requires additional inputs, such as the `control_video` needed by control-capable models like [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control), please add corresponding columns in the metadata file, for example: ``` -data/example_dataset/ -├── metadata.csv -└── train - ├── video_00001.mp4 - ├── video_00001.mp4.tensors.pth - ├── video_00002.mp4 - └── video_00002.mp4.tensors.pth +video,prompt,control_video +video1.mp4,"from sunset to night, a small town, light, house, river",video1_softedge.mp4 ``` -Step 4: Train +If additional inputs contain video or image files, their column names need to be specified in the `--data_file_keys` parameter. The default value of this parameter is `"image,video"`, meaning it parses columns named `image` and `video`. You can extend this list based on the additional input requirements, for example: `--data_file_keys "image,video,control_video"`, and also declare the new column as an extra model input via `--extra_inputs "control_video"`. -LoRA training: +
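For larger datasets, it can be convenient to generate `metadata.csv` programmatically. Below is a sketch using `pandas` (an assumption made for illustration, not a requirement of the training script); `pandas` quotes prompts containing commas automatically, matching the format shown above.

```python
# Sketch: build metadata.csv for the dataset layout shown above.
import pandas as pd

metadata = pd.DataFrame([
    {"video": "video1.mp4", "prompt": "from sunset to night, a small town, light, house, river"},
    {"video": "video2.mp4", "prompt": "a dog is running"},
])
metadata.to_csv("data/example_video_dataset/metadata.csv", index=False)
```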
-```shell -CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ - --task train \ - --train_architecture lora \ - --dataset_path data/example_dataset \ - --output_path ./models \ - --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \ - --steps_per_epoch 500 \ - --max_epochs 10 \ - --learning_rate 1e-4 \ - --lora_rank 16 \ - --lora_alpha 16 \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --accumulate_grad_batches 1 \ - --use_gradient_checkpointing -``` -Full training: +
-```shell -CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ - --task train \ - --train_architecture full \ - --dataset_path data/example_dataset \ - --output_path ./models \ - --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \ - --steps_per_epoch 500 \ - --max_epochs 10 \ - --learning_rate 1e-4 \ - --accumulate_grad_batches 1 \ - --use_gradient_checkpointing -``` +Step 2: Load the Model -If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`. - -If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`. - -For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`. - -Step 5: Test - -Test LoRA: +Similar to the model loading logic during inference, you can configure the model to be loaded directly via its model ID. For instance, during inference we load the model using: ```python -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData - - -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") -model_manager.load_models([ - "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", - "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", -]) -model_manager.load_lora("models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0) -pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) - -video = pipe( - prompt="...", - negative_prompt="...", - num_inference_steps=50, - seed=0, tiled=True -) -save_video(video, "video.mp4", fps=30, quality=5) +model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), +] ``` -Test fine-tuned base model: +During training, simply use the following parameter to load the corresponding model: + +```shell +--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" +``` + +If you want to load the model from local files, for example during inference: ```python -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData - - -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") -model_manager.load_models([ - "models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", - 
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", -]) -pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) - -video = pipe( - prompt="...", - negative_prompt="...", - num_inference_steps=50, - seed=0, tiled=True -) -save_video(video, "video.mp4", fps=30, quality=5) +model_configs=[ + ModelConfig(path=[ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + ]), + ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"), +] ``` + +Then during training, set the parameter as: + +```shell +--model_paths '[ + [ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors" + ], + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth" +]' \ +``` + +
+ + +
+ +Step 3: Configure Trainable Modules + +The training framework supports full fine-tuning of base models or LoRA-based training. Here are some examples: + +* Full fine-tuning of the DiT module: `--trainable_models dit` +* Training a LoRA model for the DiT module: `--lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32` +* Training both a LoRA model for DiT and the Motion Controller (yes, you can train such advanced structures): `--trainable_models motion_controller --lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32` + +Additionally, since multiple modules (text encoder, dit, vae) are loaded in the training script, you need to remove prefixes when saving model files. For example, when fully fine-tuning the DiT module or training a LoRA version of DiT, please set `--remove_prefix_in_ckpt pipe.dit.` + +
+ + +
+ +Step 4: Launch the Training Process + +We have prepared training commands for each model. Please refer to the table at the beginning of this document. + +Note that full fine-tuning of the 14B model requires 8 GPUs, each with at least 80GB VRAM. During full fine-tuning of these 14B models, you must install `deepspeed` (`pip install deepspeed`). We have provided recommended [configuration files](./model_training/full/accelerate_config_14B.yaml), which will be loaded automatically in the corresponding training scripts. These scripts have been tested on 8×A100 GPUs. + +The default video size in the training script is `480*832*81` (height × width × number of frames). Increasing the resolution may cause out-of-memory errors. To reduce VRAM usage, add the parameter `--use_gradient_checkpointing_offload`. +
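As a concrete illustration, a hypothetical LoRA run for the 1.3B model could combine the flags documented above as follows. Every path and hyperparameter here is a placeholder, and the `train.py` path follows the repository layout referenced in this document; for 14B full fine-tuning, the recommended configuration file can be passed explicitly via `accelerate launch --config_file ./model_training/full/accelerate_config_14B.yaml ...`.

```shell
accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 480 --width 832 --num_frames 81 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --lora_base_model dit \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --remove_prefix_in_ckpt pipe.dit. \
  --output_path ./models/train \
  --use_gradient_checkpointing_offload
```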
diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md new file mode 100644 index 0000000..fc91990 --- /dev/null +++ b/examples/wanvideo/README_zh.md @@ -0,0 +1,392 @@ +# 通义万相 2.1(Wan 2.1) + +[Switch to English](./README.md) + +Wan 2.1 是由阿里巴巴通义实验室开源的一系列视频生成模型。 + +在使用本系列模型之前,请通过源码安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)| +|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)| +|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| +|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-720P.py)|[code](./model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| +|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| +|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| 
+|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| +|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, 
`vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| +|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| +|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| +|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| + +## 模型推理 + +以下部分将会帮助您理解我们的功能并编写推理代码。 + + +
+ +加载模型 + +模型通过 `from_pretrained` 加载: + +```python +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], +) +``` + +其中 `torch_dtype` 和 `device` 是计算精度和计算设备。`model_configs` 可通过多种方式配置模型路径: + +* 从[魔搭社区](https://modelscope.cn/)下载模型并加载。此时需要填写 `model_id` 和 `origin_file_pattern`,例如 + +```python +ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors") +``` + +* 从本地文件路径加载模型。此时需要填写 `path`,例如 + +```python +ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") +``` + +对于从多个文件加载的单一模型,使用列表即可,例如 + +```python +ModelConfig(path=[ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", +]) +``` + +`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为: + +* `tokenizer_config`: Wan 模型的 tokenizer 路径,默认值为 `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`。 +* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。 +* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。 +* `redirect_common_files`: 是否重定向重复模型文件,默认值为 `True`。由于 Wan 系列模型包括多个基础模型,每个基础模型的 text encoder 等模块都是相同的,为避免重复下载,我们会对模型路径进行重定向。 +* `use_usp`: 是否启用 Unified Sequence Parallel,默认值为 `False`。用于多 GPU 并行推理。 + +
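下面是一段示意代码(仅作演示):当网络无法访问魔搭社区时,可以将手动下载好的文件通过 `path` 直接加载,并设置 `skip_download=True`。其中 tokenizer 的本地路径为假设值,请以实际下载位置为准。

```python
# 示意:假设必要文件已手动放置在 ./models 目录下
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors"),
        ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
    ],
    tokenizer_config=ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/google/umt5-xxl"),  # 假设的本地路径
    skip_download=True,
)
```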
+ + +
+ +显存管理 + +DiffSynth-Studio 为 Wan 模型提供了细粒度的显存管理,让模型能够在低显存设备上进行推理,可通过以下代码开启 offload 功能,在显存有限的设备上将部分模块 offload 到内存中。 + +```python +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() +``` + +FP8 量化功能也是支持的: + +```python +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_dtype=torch.float8_e4m3fn), + ], +) +pipe.enable_vram_management() +``` + +FP8 量化和 offload 可同时开启: + +```python +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ], +) +pipe.enable_vram_management() +``` + +FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在 FP8 量化下会出现精度不足导致的画面模糊、撕裂、失真问题,请谨慎使用 FP8 量化。 + +`enable_vram_management` 函数提供了以下参数,用于控制显存使用情况: + +* `vram_limit`: 显存占用量(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。 +* `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。 +* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。 + +
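下面给出 `enable_vram_management` 参数的一个示意用法(数值仅作演示,最优取值取决于具体硬件):

```python
# 示意:将显存占用限制在约 6GB,并保留 0.5GB 缓冲区
pipe.enable_vram_management(vram_limit=6, vram_buffer=0.5)
```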
+ + +
+ +推理加速 + +Wan 支持多种加速方案,包括 + +* 高效注意力机制实现:当您的 Python 环境中安装过这些注意力机制实现方案时,我们将会按照以下优先级自动启用。 + * [Flash Attention 3](https://github.com/Dao-AILab/flash-attention) + * [Flash Attention 2](https://github.com/Dao-AILab/flash-attention) + * [Sage Attention](https://github.com/thu-ml/SageAttention) + * [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (默认设置,建议安装 `torch>=2.5.0`) +* 统一序列并行:基于 [xDiT](https://github.com/xdit-project/xDiT) 实现的序列并行,请参考[示例代码](./acceleration/unified_sequence_parallel.py),使用以下命令运行: + +```shell +pip install "xfuser>=0.4.3" +torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py +``` + +* TeaCache:加速技术 [TeaCache](https://github.com/ali-vilab/TeaCache),请参考[示例代码](./acceleration/teacache.py)。 +
+ + +
+ +输入参数 + +Pipeline 在推理阶段能够接收以下输入参数: + +* `prompt`: 提示词,描述画面中出现的内容。 +* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 +* `input_image`: 输入图片,适用于图生视频模型,例如 [`Wan-AI/Wan2.1-I2V-14B-480P`](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)、[`PAI/Wan2.1-Fun-1.3B-InP`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP),以及首尾帧模型,例如 [`Wan-AI/Wan2.1-FLF2V-14B-720P`](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)。 +* `end_image`: 结尾帧,适用于首尾帧模型,例如 [`Wan-AI/Wan2.1-FLF2V-14B-720P`](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)。 +* `input_video`: 输入视频,用于视频生视频,适用于任意 Wan 系列模型,需与参数 `denoising_strength` 配合使用。 +* `denoising_strength`: 去噪强度,范围为 [0, 1]。数值越小,生成的视频越接近 `input_video`。 +* `control_video`: 控制视频,适用于带控制能力的 Wan 系列模型,例如 [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)。 +* `reference_image`: 参考图片,适用于带参考图能力的 Wan 系列模型,例如 [`PAI/Wan2.1-Fun-V1.1-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)。 +* `camera_control_direction`: 镜头控制方向,可选 "Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown" 之一,适用于 Camera-Control 模型,例如 [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)。 +* `camera_control_speed`: 镜头控制速度,适用于 Camera-Control 模型,例如 [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)。 +* `camera_control_origin`: 镜头控制序列的原点坐标,请参考[原论文](https://arxiv.org/pdf/2404.02101)进行设置,适用于 Camera-Control 模型,例如 [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)。 +* `vace_video`: VACE 模型的输入视频,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。 +* `vace_video_mask`: VACE 模型的 mask 视频,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。 +* `vace_reference_image`: VACE 模型的参考图片,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。 +* `vace_scale`: VACE 模型对基础模型的影响程度,默认为1。数值越大,控制强度越高,但画面崩坏概率越大。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `height`: 帧高度,默认为 480。需设置为 16 的倍数,不满足时向上取整。 +* `width`: 帧宽度,默认为 832。需设置为 16 的倍数,不满足时向上取整。 +* `num_frames`: 帧数,默认为 81。需设置为 4 的倍数 + 1,不满足时向上取整,最小值为 1。 +* `cfg_scale`: Classifier-free guidance 机制的数值,默认为 5。数值越大,提示词的控制效果越强,但画面崩坏的概率越大。 +* `cfg_merge`: 是否合并 Classifier-free guidance 的两侧进行统一推理,默认为 `False`。该参数目前仅在基础的文生视频和图生视频模型上生效。 +* `num_inference_steps`: 推理次数,默认值为 50。 +* `sigma_shift`: Rectified Flow 理论中的参数,默认为 5。数值越大,模型在去噪的开始阶段停留的步骤数越多,可适当调大这个参数来提高画面质量,但会因生成过程与训练过程不一致导致生成的视频内容与训练数据存在差异。 +* `motion_bucket_id`: 运动幅度,范围为 [0, 100]。适用于速度控制模块,例如 [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1),数值越大,运动幅度越大。 +* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。 +* `tile_size`: VAE 编解码阶段的分块大小,默认为 (30, 52),仅在 `tiled=True` 时生效。 +* `tile_stride`: VAE 编解码阶段的分块步长,默认为 (15, 26),仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 +* `sliding_window_size`: DiT 部分的滑动窗口大小。实验性功能,效果不稳定。 +* `sliding_window_stride`: DiT 部分的滑动窗口步长。实验性功能,效果不稳定。 +* `tea_cache_l1_thresh`: TeaCache 的阈值,数值越大,速度越快,画面质量越差。请注意,开启 TeaCache 后推理速度并非均匀,因此进度条上显示的剩余时间将会变得不准确。 +* `tea_cache_model_id`: TeaCache 的参数模板,可选 `"Wan2.1-T2V-1.3B"`、`"Wan2.1-T2V-14B"`、`"Wan2.1-I2V-14B-480P"`、`"Wan2.1-I2V-14B-720P"` 之一。 +* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 +
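例如,下面是一段视频生视频的示意代码(假设 `pipe` 已按上文加载,输入文件路径仅作演示),展示 `input_video` 与 `denoising_strength` 的配合:

```python
from diffsynth import VideoData

# 示意:以 input_video 为基础重绘;denoising_strength 越小,结果越接近原视频
input_video = VideoData("input_video.mp4", height=480, width=832)
video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。",
    input_video=input_video,
    denoising_strength=0.7,
    height=480, width=832, num_frames=81,
    seed=0, tiled=True,
)
```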
+ + +## 模型训练 + +Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_training/train.py) 脚本进行。 + +
+ +脚本参数 + +脚本包含以下参数: + +* 数据集 + * `--dataset_base_path`: 数据集的根路径。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。 + * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 +* 模型 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。 +* 训练 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)数量。 + * `--output_path`: 保存路径。 + * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 +* 可训练模块 + * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪一层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 +* 额外模型输入 + * `--extra_inputs`: 额外的模型输入,以逗号分隔。 +* 显存管理 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + +此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如 14B 模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。 + +
+ + +
+ +Step 1: 准备数据集 + +数据集包含一系列文件,我们建议您这样组织数据集文件: + +``` +data/example_video_dataset/ +├── metadata.csv +├── video1.mp4 +└── video2.mp4 +``` + +其中 `video1.mp4`、`video2.mp4` 为训练用视频数据,`metadata.csv` 为元数据列表,例如 + +``` +video,prompt +video1.mp4,"from sunset to night, a small town, light, house, river" +video2.mp4,"a dog is running" +``` + +我们构建了一个样例视频数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset +``` + +数据集支持视频和图片混合训练,支持的视频文件格式包括 `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`,支持的图片格式包括 `"jpg", "jpeg", "png", "webp"`。 + +视频的尺寸可通过脚本参数 `--height`、`--width`、`--num_frames` 控制。在每个视频中,前 `num_frames` 帧会被用于训练,因此当视频长度不足 `num_frames` 帧时会报错,图片文件会被视为单帧视频。当 `--height` 和 `--width` 为空时将会开启动态分辨率,按照数据集中每个视频或图片的实际宽高训练。 + +**我们强烈建议使用固定分辨率训练,并避免图像和视频混合训练,因为在多卡训练中存在负载均衡问题。** + +当模型需要额外输入时,例如具备控制能力的模型 [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) 所需的 `control_video`,请在数据集中补充相应的列,例如: + +``` +video,prompt,control_video +video1.mp4,"from sunset to night, a small town, light, house, river",video1_softedge.mp4 +``` + +额外输入若包含视频和图像文件,则需要在 `--data_file_keys` 参数中指定要解析的列名。该参数的默认值为 `"image,video"`,即解析列名为 `image` 和 `video` 的列。可根据额外输入增加相应的列名,例如 `--data_file_keys "image,video,control_video"`,同时通过 `--extra_inputs "control_video"` 将其声明为额外的模型输入。 +
+ + +
+ +Step 2: 加载模型 + +类似于推理时的模型加载逻辑,可直接通过模型 ID 配置要加载的模型。例如,推理时我们通过以下设置加载模型 + +```python +model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), +] +``` + +那么在训练时,填入以下参数即可加载对应的模型。 + +```shell +--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" +``` + +如果您希望从本地文件加载模型,例如推理时 + +```python +model_configs=[ + ModelConfig(path=[ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + ]), + ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"), +] +``` + +那么训练时需设置为 + +```shell +--model_paths '[ + [ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors" + ], + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth" +]' \ +``` + +
+ + +
+ +Step 3: 设置可训练模块 + +训练框架支持训练基础模型,或 LoRA 模型。以下是几个例子: + +* 全量训练 DiT 部分:`--trainable_models dit` +* 训练 DiT 部分的 LoRA 模型:`--lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32` +* 训练 DiT 部分的 LoRA 和 Motion Controller 部分(是的,可以训练这种花里胡哨的结构):`--trainable_models motion_controller --lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32` + +此外,由于训练脚本中加载了多个模块(text encoder、dit、vae),保存模型文件时需要移除前缀,例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时,请设置 `--remove_prefix_in_ckpt pipe.dit.` + +
+ + +
+ +Step 4: 启动训练程序 + +我们为每一个模型编写了训练命令,请参考本文档开头的表格。 + +请注意,14B 模型全量训练需要8个GPU,每个GPU的显存至少为80G。全量训练这些14B模型时需要安装 `deepspeed`(`pip install deepspeed`),我们编写了建议的[配置文件](./model_training/full/accelerate_config_14B.yaml),这个配置文件会在对应的训练脚本中被加载,这些脚本已在 8*A100 上测试过。 + +训练脚本的默认视频尺寸为 `480*832*81`,提升分辨率将可能导致显存不足,请添加参数 `--use_gradient_checkpointing_offload` 降低显存占用。 + +
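作为示意(假设性示例:所有路径与超参数仅供参考,`train.py` 的路径以仓库实际布局为准),一次 1.3B 模型的 LoRA 训练命令大致如下:

```shell
accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 480 --width 832 --num_frames 81 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --lora_base_model dit \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --remove_prefix_in_ckpt pipe.dit. \
  --output_path ./models/train \
  --use_gradient_checkpointing_offload
```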
diff --git a/examples/wanvideo/wan_1.3b_text_to_video_accelerate.py b/examples/wanvideo/acceleration/teacache.py similarity index 59% rename from examples/wanvideo/wan_1.3b_text_to_video_accelerate.py rename to examples/wanvideo/acceleration/teacache.py index b56915c..b88656a 100644 --- a/examples/wanvideo/wan_1.3b_text_to_video_accelerate.py +++ b/examples/wanvideo/acceleration/teacache.py @@ -1,34 +1,27 @@ import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -# Download models -snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", - "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. ) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) +pipe.enable_vram_management() + -# Text-to-video video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, seed=0, tiled=True, # TeaCache parameters tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality. tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P). 
) save_video(video, "video1.mp4", fps=15, quality=5) - -# TeaCache doesn't support video-to-video diff --git a/examples/wanvideo/acceleration/unified_sequence_parallel.py b/examples/wanvideo/acceleration/unified_sequence_parallel.py new file mode 100644 index 0000000..44b580b --- /dev/null +++ b/examples/wanvideo/acceleration/unified_sequence_parallel.py @@ -0,0 +1,27 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +import torch.distributed as dist + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + use_usp=True, + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + + +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +if dist.get_rank() == 0: + save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_1.3b_motion_controller.py b/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py similarity index 63% rename from examples/wanvideo/wan_1.3b_motion_controller.py rename to examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py index 8036819..6efdc65 100644 --- a/examples/wanvideo/wan_1.3b_motion_controller.py +++ b/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py @@ -1,31 +1,25 @@ import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -# Download models -snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B") -snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", - "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", - "models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors", +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"), ], - torch_dtype=torch.bfloat16, # You can set 
`torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. ) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) +pipe.enable_vram_management() # Text-to-video video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, seed=1, tiled=True, motion_bucket_id=0 ) @@ -34,8 +28,7 @@ save_video(video, "video_slow.mp4", fps=15, quality=5) video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, seed=1, tiled=True, motion_bucket_id=100 ) -save_video(video, "video_fast.mp4", fps=15, quality=5) \ No newline at end of file +save_video(video, "video_fast.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py new file mode 100644 index 0000000..3061398 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] +) + +# First and last frame to video +video = pipe( + prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), + end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), + seed=0, tiled=True, + height=960, width=960, num_frames=33, + sigma_shift=16, +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py new file mode 100644 index 0000000..43374d2 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import 
save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_fun_InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py similarity index 56% rename from examples/wanvideo/wan_fun_InP.py rename to examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py index ae23ee0..d921c0c 100644 --- a/examples/wanvideo/wan_fun_InP.py +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py @@ -1,27 +1,22 @@ import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download -# Download models -snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors", - "models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth", - "models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth", - "models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. 
) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) +pipe.enable_vram_management() -# Download example image dataset_snapshot_download( dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", @@ -29,14 +24,13 @@ dataset_snapshot_download( ) image = Image.open("data/examples/wan/input_image.jpg") -# Image-to-video +# First and last frame to video video = pipe( prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, input_image=image, + seed=0, tiled=True # You can input `end_image=xxx` to control the last frame of the video. # The model will automatically generate the dynamic content between `input_image` and `end_image`. - seed=1, tiled=True ) -save_video(video, "video1.mp4", fps=15, quality=5) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py new file mode 100644 index 0000000..db9e5c8 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py new file mode 100644 index 0000000..af227cb --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + 
torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py new file mode 100644 index 0000000..002c34c --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py @@ -0,0 +1,44 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, "video_left.mp4", fps=15, quality=5) + +video = pipe( + 
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py new file mode 100644 index 0000000..0f7e4c8 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py new file mode 100644 index 0000000..f2fc560 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + 
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py new file mode 100644 index 0000000..0e0c9a8 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py @@ -0,0 +1,44 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, "video_left.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up.mp4", fps=15, quality=5) diff 
--git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py new file mode 100644 index 0000000..78635ff --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py new file mode 100644 index 0000000..334e981 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + 
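    # Only the first frame is supplied here; as the comment below notes, pass end_image as well to also pin the final frame. +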
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py new file mode 100644 index 0000000..eb2e5b0 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py new file mode 100644 index 0000000..fb14d24 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py @@ -0,0 +1,35 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) 
+pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True, + height=720, width=1280, +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_1.3b_text_to_video.py b/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py similarity index 70% rename from examples/wanvideo/wan_1.3b_text_to_video.py rename to examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py index e444cd2..83e300b 100644 --- a/examples/wanvideo/wan_1.3b_text_to_video.py +++ b/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py @@ -1,30 +1,25 @@ import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -# Download models -snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", - "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. 
) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) +pipe.enable_vram_management() # Text-to-video video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - seed=0, tiled=True + seed=0, tiled=True, ) save_video(video, "video1.mp4", fps=15, quality=5) @@ -34,7 +29,6 @@ video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", input_video=video, denoising_strength=0.7, - num_inference_steps=50, seed=1, tiled=True ) save_video(video, "video2.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py b/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py new file mode 100644 index 0000000..40cb02d --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py @@ -0,0 +1,24 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py new file mode 100644 index 0000000..99c0242 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py @@ -0,0 +1,52 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) 
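+# No separate VACE checkpoint is listed: the VACE control branch ships inside diffusion_pytorch_model*.safetensors and is detected automatically when the state dict is loaded.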
+pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video1.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video2.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video3.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_1.3b_vace.py b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py similarity index 72% rename from examples/wanvideo/wan_1.3b_vace.py rename to examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py index fb987a7..59ca2d4 100644 --- a/examples/wanvideo/wan_1.3b_vace.py +++ b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py @@ -1,26 +1,22 @@ import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download -# Download models -snapshot_download("iic/VACE-Wan2.1-1.3B-Preview", local_dir="models/iic/VACE-Wan2.1-1.3B-Preview") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/iic/VACE-Wan2.1-1.3B-Preview/diffusion_pytorch_model.safetensors", - "models/iic/VACE-Wan2.1-1.3B-Preview/models_t5_umt5-xxl-enc-bf16.pth", - "models/iic/VACE-Wan2.1-1.3B-Preview/Wan2.1_VAE.pth", - ], +pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], ) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) -# Download example video +pipe.enable_vram_management() + dataset_snapshot_download( dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", @@ -32,8 +28,6 @@ control_video = 
VideoData("data/examples/wan/depth_video.mp4", height=480, width video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - height=480, width=832, num_frames=81, vace_video=control_video, seed=1, tiled=True ) @@ -43,8 +37,6 @@ save_video(video, "video1.mp4", fps=15, quality=5) video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - height=480, width=832, num_frames=81, vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), seed=1, tiled=True ) @@ -54,8 +46,6 @@ save_video(video, "video2.mp4", fps=15, quality=5) video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - height=480, width=832, num_frames=81, vace_video=control_video, vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), seed=1, tiled=True diff --git a/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py b/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py new file mode 100644 index 0000000..e504b2c --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py @@ -0,0 +1,54 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) + + +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video1_14b.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video2_14b.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video3_14b.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh b/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh new file mode 100644 index 0000000..3d580ab --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_motion_bucket_id.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.motion_controller." \ + --output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_full" \ + --trainable_models "motion_controller" \ + --extra_inputs "motion_bucket_id" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh new file mode 100644 index 0000000..c9f6f64 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-FLF2V-14B-720P_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh new file mode 100644 index 0000000..45a99de --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-1.3B-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh new file mode 100644 index 0000000..a202bf9 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-1.3B-InP_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh new file mode 100644 index 0000000..8a17c3f --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-14B-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh new file mode 100644 index 0000000..86feae7 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-14B-InP_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh new file mode 100644 index 0000000..b59ed32 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh new file mode 100644 index 0000000..34273c1 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video,reference_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh new file mode 100644 index 0000000..f6eed97 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-InP_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh new file mode 100644 index 0000000..41b87e9 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh new file mode 100644 index 0000000..ce6640e --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video,reference_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh new file mode 100644 index 0000000..afb5d3d --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-InP_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh new file mode 100644 index 0000000..6d257ff --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-I2V-14B-480P_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh new file mode 100644 index 0000000..b64bcbf --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-I2V-14B-720P_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh new file mode 100644 index 0000000..e0d6e84 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh @@ -0,0 +1,12 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-1.3B_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh b/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh new file mode 100644 index 0000000..ae804b0 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-14B_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh new file mode 100644 index 0000000..b348874 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh new file mode 100644 index 0000000..763252e --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." \ + --output_path "./models/train/Wan2.1-VACE-1.3B_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh new file mode 100644 index 0000000..c549263 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh @@ -0,0 +1,16 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.1-VACE-14B_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/accelerate_config_14B.yaml b/examples/wanvideo/model_training/full/accelerate_config_14B.yaml new file mode 100644 index 0000000..3875a9d --- /dev/null +++ b/examples/wanvideo/model_training/full/accelerate_config_14B.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/wanvideo/model_training/full/run_test.py b/examples/wanvideo/model_training/full/run_test.py new file mode 100644 index 0000000..093becd --- /dev/null +++ b/examples/wanvideo/model_training/full/run_test.py @@ -0,0 +1,38 @@ +import multiprocessing, os + + +def run_task(scripts, thread_id, thread_num): + for script_id, script in enumerate(scripts): + if script_id % thread_num == thread_id: + log_file_name = script.replace("/", "_") + ".txt" + cmd = f"CUDA_VISIBLE_DEVICES={thread_id} bash {script} > data/log/{log_file_name} 2>&1" + os.makedirs("data/log", exist_ok=True) + print(cmd, flush=True) + os.system(cmd) + + +if __name__ == "__main__": + # 1.3B + scripts = [] + for file_name in os.listdir("examples/wanvideo/model_training/full"): + if file_name != "run_test.py" and "14B" not in file_name: + scripts.append(os.path.join("examples/wanvideo/model_training/full", file_name)) + + processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] + for p in processes: + p.start() + for p in processes: + p.join() + + # 14B + scripts = [] + for file_name in os.listdir("examples/wanvideo/model_training/full"): + if file_name != "run_test.py" and "14B" in file_name: + scripts.append(os.path.join("examples/wanvideo/model_training/full", file_name)) + for script in scripts: + log_file_name = script.replace("/", "_") + ".txt" + cmd = f"bash {script} > data/log/{log_file_name} 2>&1" + print(cmd, flush=True) + os.system(cmd) + + print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh new file mode 100644 index 0000000..51ebfe4 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_motion_bucket_id.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "motion_bucket_id" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh new file mode 100644 index 0000000..9a9622d --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-FLF2V-14B-720P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh new file mode 100644 index 0000000..03c1f45 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-1.3B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh new file mode 100644 index 0000000..d5f509b --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-1.3B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh new file mode 100644 index 0000000..608df5f --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-14B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh new file mode 100644 index 0000000..37b2518 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-14B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh new file mode 100644 index 0000000..721d88f --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh new file mode 100644 index 0000000..1e7156d --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video,reference_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh new file mode 100644 index 0000000..5879f59 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh new file mode 100644 index 0000000..9fd30c4 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh new file mode 100644 index 0000000..3ead12c --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video,reference_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh new file mode 100644 index 0000000..40a8ad0 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh new file mode 100644 index 0000000..473d519 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-I2V-14B-480P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh new file mode 100644 index 0000000..2037616 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh new file mode 100644 index 0000000..d16a287 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-1.3B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh new file mode 100644 index 0000000..1fb55ac --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-14B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh new file mode 100644 index 0000000..2bcb55b --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh @@ -0,0 +1,17 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh new file mode 100644 index 0000000..b565078 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh @@ -0,0 +1,17 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." \ + --output_path "./models/train/Wan2.1-VACE-1.3B_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh new file mode 100644 index 0000000..633ea0e --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh @@ -0,0 +1,18 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.1-VACE-14B_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/run_test.py b/examples/wanvideo/model_training/lora/run_test.py new file mode 100644 index 0000000..ec0f9e2 --- /dev/null +++ b/examples/wanvideo/model_training/lora/run_test.py @@ -0,0 +1,25 @@ +import multiprocessing, os + + +def run_task(scripts, thread_id, thread_num): + for script_id, script in enumerate(scripts): + if script_id % thread_num == thread_id: + log_file_name = script.replace("/", "_") + ".txt" + cmd = f"CUDA_VISIBLE_DEVICES={thread_id} bash {script} > data/log/{log_file_name} 2>&1" + os.makedirs("data/log", exist_ok=True) + print(cmd, flush=True) + os.system(cmd) + + +if __name__ == "__main__": + scripts = [] + for file_name in os.listdir("examples/wanvideo/model_training/lora"): + if file_name != "run_test.py": + scripts.append(os.path.join("examples/wanvideo/model_training/lora", file_name)) + + processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] + for p in processes: + p.start() + for p in processes: + p.join() + print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py new file mode 100644 index 0000000..871bbd6 --- /dev/null +++ b/examples/wanvideo/model_training/train.py @@ -0,0 +1,110 @@ +import torch, os, json +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + + +class WanTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + ): + super().__init__() + # Load models + model_configs = [] + if model_paths is not None: + model_paths = json.loads(model_paths) + model_configs += [ModelConfig(path=path) for path in model_paths] + if model_id_with_origin_paths is not None: + model_id_with_origin_paths = model_id_with_origin_paths.split(",") + model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] + self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) + + # Reset training scheduler + self.pipe.scheduler.set_timesteps(1000, training=True) + + # Freeze untrainable models + self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) + + # Add LoRA to the base models + if lora_base_model is not None: + model = self.add_lora_to_model( + getattr(self.pipe, lora_base_model), + target_modules=lora_target_modules.split(","), + lora_rank=lora_rank + ) + setattr(self.pipe, lora_base_model, model) + + # Store other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + + + def forward_preprocess(self, data): + # CFG-sensitive parameters + inputs_posi = 
{"prompt": data["prompt"]} + inputs_nega = {} + + # CFG-unsensitive parameters + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_video": data["video"], + "height": data["video"][0].size[1], + "width": data["video"][0].size[0], + "num_frames": len(data["video"]), + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "tiled": False, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + "cfg_merge": False, + "vace_scale": 1, + } + + # Extra inputs + for extra_input in self.extra_inputs: + if extra_input == "input_image": + inputs_shared["input_image"] = data["video"][0] + elif extra_input == "end_image": + inputs_shared["end_image"] = data["video"][-1] + else: + inputs_shared[extra_input] = data[extra_input] + + # Pipeline units will automatically process the input parameters. + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + return {**inputs_shared, **inputs_posi} + + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.forward_preprocess(data) + models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} + loss = self.pipe.training_loss(**models, **inputs) + return loss + + +if __name__ == "__main__": + parser = wan_parser() + args = parser.parse_args() + dataset = VideoDataset(args=args) + model = WanTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + ) + launch_training_task(model, dataset, args=args) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py new file mode 100644 index 0000000..124749a --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-1.3b-speedcontrol-v1_full/epoch-1.safetensors") +pipe.motion_controller.load_state_dict(state_dict) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=50 +) +save_video(video, "video_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py new file mode 100644 index 0000000..41a67ed --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-FLF2V-14B-720P_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], + end_image=video[80], + seed=0, tiled=True, + sigma_shift=16, +) +save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py new file mode 100644 index 0000000..6726e9c --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-1.3B-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = 
VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py new file mode 100644 index 0000000..3e1e6f3 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-1.3B-InP_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py new file mode 100644 index 0000000..08b0acb --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = 
load_state_dict("models/train/Wan2.1-Fun-14B-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py new file mode 100644 index 0000000..d7e39d7 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-14B-InP_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py new file mode 100644 index 0000000..c5a210a --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", 
origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], + camera_control_direction="Left", camera_control_speed=0.0, + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py new file mode 100644 index 0000000..6497e1b --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-1.3B-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, reference_image=reference_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py new file mode 100644 index 0000000..cd8ee20 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import 
dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-1.3B-InP_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py new file mode 100644 index 0000000..2de5102 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# Image-to-video with camera control +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], + camera_control_direction="Left", camera_control_speed=0.0, + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py new file
mode 100644 index 0000000..0dd2516 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-14B-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, reference_image=reference_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py new file mode 100644 index 0000000..7e944b0 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-14B-InP_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py new file mode 100644 index 0000000..c1c8615 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-I2V-14B-480P_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py new file mode 100644 index 0000000..a8610f3 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-I2V-14B-720P_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + 
+video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py new file mode 100644 index 0000000..1420514 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py @@ -0,0 +1,25 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-T2V-1.3B_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py new file mode 100644 index 0000000..a0107ae --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py @@ -0,0 +1,25 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-T2V-14B_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py new file mode 100644 index 0000000..7db26e0 --- /dev/null +++ 
b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/VACE-Wan2.1-1.3B-Preview_full/epoch-1.safetensors") +pipe.vace.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(49)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=video, vace_reference_image=reference_image, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py new file mode 100644 index 0000000..5a371e7 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-VACE-1.3B_full/epoch-1.safetensors") +pipe.vace.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(49)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=video, vace_reference_image=reference_image, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py new file mode 100644 index 0000000..5553471 --- /dev/null 
+++ b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-VACE-14B_full/epoch-1.safetensors") +pipe.vace.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(17)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=video, vace_reference_image=reference_image, num_frames=17, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-VACE-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/run_test.py b/examples/wanvideo/model_training/validate_full/run_test.py new file mode 100644 index 0000000..a4e3203 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/run_test.py @@ -0,0 +1,25 @@ +import multiprocessing, os + + +def run_task(scripts, thread_id, thread_num): + for script_id, script in enumerate(scripts): + if script_id % thread_num == thread_id: + log_file_name = script.replace("/", "_") + ".txt" + cmd = f"CUDA_VISIBLE_DEVICES={thread_id} python -u {script} > data/log/{log_file_name} 2>&1" + os.makedirs("data/log", exist_ok=True) + print(cmd, flush=True) + os.system(cmd) + + +if __name__ == "__main__": + scripts = [] + for file_name in os.listdir("examples/wanvideo/model_training/validate_full"): + if file_name != "run_test.py": + scripts.append(os.path.join("examples/wanvideo/model_training/validate_full", file_name)) + + processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] + for p in processes: + p.start() + for p in processes: + p.join() + print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py new file mode 100644 index 0000000..167b871 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py @@ -0,0 +1,27 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", 
offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-1.3b-speedcontrol-v1_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=50 +) +save_video(video, "video_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py new file mode 100644 index 0000000..cd68f0e --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-FLF2V-14B-720P_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], + end_image=video[80], + seed=0, tiled=True, + sigma_shift=16, +) +save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py new file mode 100644 index 0000000..7270c38 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + 
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-1.3B-Control_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py new file mode 100644 index 0000000..c904dfa --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-1.3B-InP_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py new file mode 100644 index 0000000..8631d05 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", 
origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-14B-Control_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py new file mode 100644 index 0000000..e020aac --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-14B-InP_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py new file mode 100644 index 0000000..3023692 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", 
origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], + camera_control_direction="Left", camera_control_speed=0.0, + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py new file mode 100644 index 0000000..ebcfd2f --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, reference_image=reference_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py new file mode 100644 index 0000000..99eb2b4 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py @@ -0,0 
+1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-InP_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py new file mode 100644 index 0000000..0edb2d0 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# Image-to-video with camera control +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], + camera_control_direction="Left", camera_control_speed=0.0, + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5) diff --git
a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py new file mode 100644 index 0000000..6b11098 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-14B-Control_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, reference_image=reference_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py new file mode 100644 index 0000000..35088fb --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-14B-InP_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py new file mode 100644 index 0000000..1687e36 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-480P_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py new file mode 100644 index 0000000..9893e26 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-720P_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py new file mode 100644 index 0000000..7cb6c02 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py @@ -0,0 +1,24 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-T2V-1.3B_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py new file mode 100644 index 0000000..3b66a49 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py @@ -0,0 +1,24 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-T2V-14B_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py new file mode 100644 index 0000000..91cbf92 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + 
+pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.vace, "models/train/Wan2.1-VACE-1.3B-Preview_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(49)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=video, vace_reference_image=reference_image, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py new file mode 100644 index 0000000..b5fd203 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.vace, "models/train/Wan2.1-VACE-1.3B_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(49)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=video, vace_reference_image=reference_image, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py new file mode 100644 index 0000000..bec5df3 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ 
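+        # model_id names the remote repo and origin_file_pattern picks files within it;
+        # downloads are assumed to resolve via ModelScope, matching the modelscope
+        # imports used elsewhere in these examples.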
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.vace, "models/train/Wan2.1-VACE-14B_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(17)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=video, vace_reference_image=reference_image, num_frames=17, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-VACE-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/run_test.py b/examples/wanvideo/model_training/validate_lora/run_test.py new file mode 100644 index 0000000..367ee9d --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/run_test.py @@ -0,0 +1,25 @@ +import multiprocessing, os + + +def run_task(scripts, thread_id, thread_num): + for script_id, script in enumerate(scripts): + if script_id % thread_num == thread_id: + log_file_name = script.replace("/", "_") + ".txt" + cmd = f"CUDA_VISIBLE_DEVICES={thread_id} python -u {script} > data/log/{log_file_name} 2>&1" + os.makedirs("data/log", exist_ok=True) + print(cmd, flush=True) + os.system(cmd) + + +if __name__ == "__main__": + scripts = [] + for file_name in os.listdir("examples/wanvideo/model_training/validate_lora"): + if file_name != "run_test.py": + scripts.append(os.path.join("examples/wanvideo/model_training/validate_lora", file_name)) + + processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] + for p in processes: + p.start() + for p in processes: + p.join() + print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py deleted file mode 100644 index cd10096..0000000 --- a/examples/wanvideo/train_wan_t2v.py +++ /dev/null @@ -1,593 +0,0 @@ -import torch, os, imageio, argparse -from torchvision.transforms import v2 -from einops import rearrange -import lightning as pl -import pandas as pd -from diffsynth import WanVideoPipeline, ModelManager, load_state_dict -from peft import LoraConfig, inject_adapter_in_model -import torchvision -from PIL import Image -import numpy as np - - - -class TextVideoDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False): - metadata = pd.read_csv(metadata_path) - self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] - self.text = metadata["text"].to_list() - - self.max_num_frames = max_num_frames - self.frame_interval = frame_interval - self.num_frames = num_frames - self.height = height - self.width = width - self.is_i2v = is_i2v - - self.frame_process = v2.Compose([ - v2.CenterCrop(size=(height, width)), - v2.Resize(size=(height, width), antialias=True), - v2.ToTensor(), - v2.Normalize(mean=[0.5, 0.5, 0.5], 
std=[0.5, 0.5, 0.5]), - ]) - - - def crop_and_resize(self, image): - width, height = image.size - scale = max(self.width / width, self.height / height) - image = torchvision.transforms.functional.resize( - image, - (round(height*scale), round(width*scale)), - interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) - return image - - - def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): - reader = imageio.get_reader(file_path) - if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: - reader.close() - return None - - frames = [] - first_frame = None - for frame_id in range(num_frames): - frame = reader.get_data(start_frame_id + frame_id * interval) - frame = Image.fromarray(frame) - frame = self.crop_and_resize(frame) - if first_frame is None: - first_frame = frame - frame = frame_process(frame) - frames.append(frame) - reader.close() - - frames = torch.stack(frames, dim=0) - frames = rearrange(frames, "T C H W -> C T H W") - - first_frame = v2.functional.center_crop(first_frame, output_size=(self.height, self.width)) - first_frame = np.array(first_frame) - - if self.is_i2v: - return frames, first_frame - else: - return frames - - - def load_video(self, file_path): - start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0] - frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) - return frames - - - def is_image(self, file_path): - file_ext_name = file_path.split(".")[-1] - if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]: - return True - return False - - - def load_image(self, file_path): - frame = Image.open(file_path).convert("RGB") - frame = self.crop_and_resize(frame) - first_frame = frame - frame = self.frame_process(frame) - frame = rearrange(frame, "C H W -> C 1 H W") - return frame - - - def __getitem__(self, data_id): - text = self.text[data_id] - path = self.path[data_id] - if self.is_image(path): - if self.is_i2v: - raise ValueError(f"{path} is not a video. 
I2V model doesn't support image-to-image training.") - video = self.load_image(path) - else: - video = self.load_video(path) - if self.is_i2v: - video, first_frame = video - data = {"text": text, "video": video, "path": path, "first_frame": first_frame} - else: - data = {"text": text, "video": video, "path": path} - return data - - - def __len__(self): - return len(self.path) - - - -class LightningModelForDataProcess(pl.LightningModule): - def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): - super().__init__() - model_path = [text_encoder_path, vae_path] - if image_encoder_path is not None: - model_path.append(image_encoder_path) - model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - model_manager.load_models(model_path) - self.pipe = WanVideoPipeline.from_model_manager(model_manager) - - self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} - - def test_step(self, batch, batch_idx): - text, video, path = batch["text"][0], batch["video"], batch["path"][0] - - self.pipe.device = self.device - if video is not None: - # prompt - prompt_emb = self.pipe.encode_prompt(text) - # video - video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) - latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] - # image - if "first_frame" in batch: - first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) - _, _, num_frames, height, width = video.shape - image_emb = self.pipe.encode_image(first_frame, None, num_frames, height, width) - else: - image_emb = {} - data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} - torch.save(data, path + ".tensors.pth") - - - -class TensorDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, steps_per_epoch): - metadata = pd.read_csv(metadata_path) - self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] - print(len(self.path), "videos in metadata.") - self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] - print(len(self.path), "tensors cached in metadata.") - assert len(self.path) > 0 - - self.steps_per_epoch = steps_per_epoch - - - def __getitem__(self, index): - data_id = torch.randint(0, len(self.path), (1,))[0] - data_id = (data_id + index) % len(self.path) # For fixed seed. 
- path = self.path[data_id] - data = torch.load(path, weights_only=True, map_location="cpu") - return data - - - def __len__(self): - return self.steps_per_epoch - - - -class LightningModelForTrain(pl.LightningModule): - def __init__( - self, - dit_path, - learning_rate=1e-5, - lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", - use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, - pretrained_lora_path=None - ): - super().__init__() - model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - if os.path.isfile(dit_path): - model_manager.load_models([dit_path]) - else: - dit_path = dit_path.split(",") - model_manager.load_models([dit_path]) - - self.pipe = WanVideoPipeline.from_model_manager(model_manager) - self.pipe.scheduler.set_timesteps(1000, training=True) - self.freeze_parameters() - if train_architecture == "lora": - self.add_lora_to_model( - self.pipe.denoising_model(), - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_target_modules=lora_target_modules, - init_lora_weights=init_lora_weights, - pretrained_lora_path=pretrained_lora_path, - ) - else: - self.pipe.denoising_model().requires_grad_(True) - - self.learning_rate = learning_rate - self.use_gradient_checkpointing = use_gradient_checkpointing - self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload - - - def freeze_parameters(self): - # Freeze parameters - self.pipe.requires_grad_(False) - self.pipe.eval() - self.pipe.denoising_model().train() - - - def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None): - # Add LoRA to UNet - self.lora_alpha = lora_alpha - if init_lora_weights == "kaiming": - init_lora_weights = True - - lora_config = LoraConfig( - r=lora_rank, - lora_alpha=lora_alpha, - init_lora_weights=init_lora_weights, - target_modules=lora_target_modules.split(","), - ) - model = inject_adapter_in_model(lora_config, model) - for param in model.parameters(): - # Upcast LoRA parameters into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) - - # Lora pretrained lora weights - if pretrained_lora_path is not None: - state_dict = load_state_dict(pretrained_lora_path) - if state_dict_converter is not None: - state_dict = state_dict_converter(state_dict) - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - all_keys = [i for i, _ in model.named_parameters()] - num_updated_keys = len(all_keys) - len(missing_keys) - num_unexpected_keys = len(unexpected_keys) - print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. 
{num_unexpected_keys} parameters are unexpected.") - - - def training_step(self, batch, batch_idx): - # Data - latents = batch["latents"].to(self.device) - prompt_emb = batch["prompt_emb"] - prompt_emb["context"] = prompt_emb["context"][0].to(self.device) - image_emb = batch["image_emb"] - if "clip_feature" in image_emb: - image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) - if "y" in image_emb: - image_emb["y"] = image_emb["y"][0].to(self.device) - - # Loss - self.pipe.device = self.device - noise = torch.randn_like(latents) - timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) - timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) - extra_input = self.pipe.prepare_extra_input(latents) - noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) - training_target = self.pipe.scheduler.training_target(latents, noise, timestep) - - # Compute loss - noise_pred = self.pipe.denoising_model()( - noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, - use_gradient_checkpointing=self.use_gradient_checkpointing, - use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload - ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) - loss = loss * self.pipe.scheduler.training_weight(timestep) - - # Record log - self.log("train_loss", loss, prog_bar=True) - return loss - - - def configure_optimizers(self): - trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) - optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) - return optimizer - - - def on_save_checkpoint(self, checkpoint): - checkpoint.clear() - trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) - trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) - state_dict = self.pipe.denoising_model().state_dict() - lora_state_dict = {} - for name, param in state_dict.items(): - if name in trainable_param_names: - lora_state_dict[name] = param - checkpoint.update(lora_state_dict) - - - -def parse_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--task", - type=str, - default="data_process", - required=True, - choices=["data_process", "train"], - help="Task. `data_process` or `train`.", - ) - parser.add_argument( - "--dataset_path", - type=str, - default=None, - required=True, - help="The path of the Dataset.", - ) - parser.add_argument( - "--output_path", - type=str, - default="./", - help="Path to save the model.", - ) - parser.add_argument( - "--text_encoder_path", - type=str, - default=None, - help="Path of text encoder.", - ) - parser.add_argument( - "--image_encoder_path", - type=str, - default=None, - help="Path of image encoder.", - ) - parser.add_argument( - "--vae_path", - type=str, - default=None, - help="Path of VAE.", - ) - parser.add_argument( - "--dit_path", - type=str, - default=None, - help="Path of DiT.", - ) - parser.add_argument( - "--tiled", - default=False, - action="store_true", - help="Whether enable tile encode in VAE. 
This option can reduce VRAM required.", - ) - parser.add_argument( - "--tile_size_height", - type=int, - default=34, - help="Tile size (height) in VAE.", - ) - parser.add_argument( - "--tile_size_width", - type=int, - default=34, - help="Tile size (width) in VAE.", - ) - parser.add_argument( - "--tile_stride_height", - type=int, - default=18, - help="Tile stride (height) in VAE.", - ) - parser.add_argument( - "--tile_stride_width", - type=int, - default=16, - help="Tile stride (width) in VAE.", - ) - parser.add_argument( - "--steps_per_epoch", - type=int, - default=500, - help="Number of steps per epoch.", - ) - parser.add_argument( - "--num_frames", - type=int, - default=81, - help="Number of frames.", - ) - parser.add_argument( - "--height", - type=int, - default=480, - help="Image height.", - ) - parser.add_argument( - "--width", - type=int, - default=832, - help="Image width.", - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=1, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-5, - help="Learning rate.", - ) - parser.add_argument( - "--accumulate_grad_batches", - type=int, - default=1, - help="The number of batches in gradient accumulation.", - ) - parser.add_argument( - "--max_epochs", - type=int, - default=1, - help="Number of epochs.", - ) - parser.add_argument( - "--lora_target_modules", - type=str, - default="q,k,v,o,ffn.0,ffn.2", - help="Layers with LoRA modules.", - ) - parser.add_argument( - "--init_lora_weights", - type=str, - default="kaiming", - choices=["gaussian", "kaiming"], - help="The initializing method of LoRA weight.", - ) - parser.add_argument( - "--training_strategy", - type=str, - default="auto", - choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], - help="Training strategy", - ) - parser.add_argument( - "--lora_rank", - type=int, - default=4, - help="The dimension of the LoRA update matrices.", - ) - parser.add_argument( - "--lora_alpha", - type=float, - default=4.0, - help="The weight of the LoRA update matrices.", - ) - parser.add_argument( - "--use_gradient_checkpointing", - default=False, - action="store_true", - help="Whether to use gradient checkpointing.", - ) - parser.add_argument( - "--use_gradient_checkpointing_offload", - default=False, - action="store_true", - help="Whether to use gradient checkpointing offload.", - ) - parser.add_argument( - "--train_architecture", - type=str, - default="lora", - choices=["lora", "full"], - help="Model structure to train. LoRA training or full training.", - ) - parser.add_argument( - "--pretrained_lora_path", - type=str, - default=None, - help="Pretrained LoRA path. 
Required if the training is resumed.", - ) - parser.add_argument( - "--use_swanlab", - default=False, - action="store_true", - help="Whether to use SwanLab logger.", - ) - parser.add_argument( - "--swanlab_mode", - default=None, - help="SwanLab mode (cloud or local).", - ) - args = parser.parse_args() - return args - - -def data_process(args): - dataset = TextVideoDataset( - args.dataset_path, - os.path.join(args.dataset_path, "metadata.csv"), - max_num_frames=args.num_frames, - frame_interval=1, - num_frames=args.num_frames, - height=args.height, - width=args.width, - is_i2v=args.image_encoder_path is not None - ) - dataloader = torch.utils.data.DataLoader( - dataset, - shuffle=False, - batch_size=1, - num_workers=args.dataloader_num_workers - ) - model = LightningModelForDataProcess( - text_encoder_path=args.text_encoder_path, - image_encoder_path=args.image_encoder_path, - vae_path=args.vae_path, - tiled=args.tiled, - tile_size=(args.tile_size_height, args.tile_size_width), - tile_stride=(args.tile_stride_height, args.tile_stride_width), - ) - trainer = pl.Trainer( - accelerator="gpu", - devices="auto", - default_root_dir=args.output_path, - ) - trainer.test(model, dataloader) - - -def train(args): - dataset = TensorDataset( - args.dataset_path, - os.path.join(args.dataset_path, "metadata.csv"), - steps_per_epoch=args.steps_per_epoch, - ) - dataloader = torch.utils.data.DataLoader( - dataset, - shuffle=True, - batch_size=1, - num_workers=args.dataloader_num_workers - ) - model = LightningModelForTrain( - dit_path=args.dit_path, - learning_rate=args.learning_rate, - train_architecture=args.train_architecture, - lora_rank=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_target_modules=args.lora_target_modules, - init_lora_weights=args.init_lora_weights, - use_gradient_checkpointing=args.use_gradient_checkpointing, - use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, - pretrained_lora_path=args.pretrained_lora_path, - ) - if args.use_swanlab: - from swanlab.integration.pytorch_lightning import SwanLabLogger - swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} - swanlab_config.update(vars(args)) - swanlab_logger = SwanLabLogger( - project="wan", - name="wan", - config=swanlab_config, - mode=args.swanlab_mode, - logdir=os.path.join(args.output_path, "swanlog"), - ) - logger = [swanlab_logger] - else: - logger = None - trainer = pl.Trainer( - max_epochs=args.max_epochs, - accelerator="gpu", - devices="auto", - precision="bf16", - strategy=args.training_strategy, - default_root_dir=args.output_path, - accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - logger=logger, - ) - trainer.fit(model, dataloader) - - -if __name__ == '__main__': - args = parse_args() - if args.task == "data_process": - data_process(args) - elif args.task == "train": - train(args) diff --git a/examples/wanvideo/wan_14B_flf2v.py b/examples/wanvideo/wan_14B_flf2v.py deleted file mode 100644 index 23109df..0000000 --- a/examples/wanvideo/wan_14B_flf2v.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download -from PIL import Image - - -# Download models -snapshot_download("Wan-AI/Wan2.1-FLF2V-14B-720P", local_dir="models/Wan-AI/Wan2.1-FLF2V-14B-720P") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - 
["models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], - torch_dtype=torch.float32, # Image Encoder is loaded with float32 -) -model_manager.load_models( - [ - [ - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00001-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00002-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00003-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00004-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00005-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00006-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00007-of-00007.safetensors", - ], - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/Wan2.1_VAE.pth", - ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. -) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) - -# Download example image -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] -) - -# First and last frame to video -video = pipe( - prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=30, - input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), - end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), - height=960, width=960, - seed=1, tiled=True -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14b_image_to_video.py b/examples/wanvideo/wan_14b_image_to_video.py deleted file mode 100644 index 91894ae..0000000 --- a/examples/wanvideo/wan_14b_image_to_video.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download -from PIL import Image - - -# Download models -snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-I2V-14B-480P") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - ["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], - torch_dtype=torch.float32, # Image Encoder is loaded with float32 -) -model_manager.load_models( - [ - [ - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors", - ], - 
"models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", - ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. -) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. - -# Download example image -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) -image = Image.open("data/examples/wan/input_image.jpg") - -# Image-to-video -video = pipe( - prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=image, - num_inference_steps=50, - seed=0, tiled=True -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py deleted file mode 100644 index 654565d..0000000 --- a/examples/wanvideo/wan_14b_text_to_video.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download - - -# Download models -snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - [ - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", - ], - "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", - ], - torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. -) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. 
-
-# Text-to-video
-video = pipe(
-    prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
-    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-    num_inference_steps=50,
-    seed=0, tiled=True
-)
-save_video(video, "video1.mp4", fps=25, quality=5)
diff --git a/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py b/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py
deleted file mode 100644
index 77c230c..0000000
--- a/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import torch
-import lightning as pl
-from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput
-from torch.distributed._tensor import Replicate, Shard
-from torch.distributed.tensor.parallel import parallelize_module
-from lightning.pytorch.strategies import ModelParallelStrategy
-from diffsynth import ModelManager, WanVideoPipeline, save_video
-from tqdm import tqdm
-from modelscope import snapshot_download
-
-
-
-class ToyDataset(torch.utils.data.Dataset):
-    def __init__(self, tasks=[]):
-        self.tasks = tasks
-
-    def __getitem__(self, data_id):
-        return self.tasks[data_id]
-
-    def __len__(self):
-        return len(self.tasks)
-
-
-class LitModel(pl.LightningModule):
-    def __init__(self):
-        super().__init__()
-        model_manager = ModelManager(device="cpu")
-        model_manager.load_models(
-            [
-                [
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
-                ],
-                "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
-                "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
-            ],
-            torch_dtype=torch.bfloat16,
-        )
-        self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
-
-    def configure_model(self):
-        tp_mesh = self.device_mesh["tensor_parallel"]
-        plan = {
-            "text_embedding.0": ColwiseParallel(),
-            "text_embedding.2": RowwiseParallel(),
-            "time_projection.1": ColwiseParallel(output_layouts=Replicate()),
-            "text_embedding.0": ColwiseParallel(),
-            "text_embedding.2": RowwiseParallel(),
-            "blocks.0": PrepareModuleInput(
-                input_layouts=(Replicate(), None, None, None),
-                desired_input_layouts=(Replicate(), None, None, None),
-            ),
-            "head": PrepareModuleInput(
-                input_layouts=(Replicate(), None),
-                desired_input_layouts=(Replicate(), None),
-                use_local_output=True,
-            )
-        }
-        self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
-        for block_id, block in enumerate(self.pipe.dit.blocks):
-            layer_tp_plan = {
-                "self_attn": PrepareModuleInput(
-                    input_layouts=(Shard(1), Replicate()),
-                    desired_input_layouts=(Shard(1), Shard(0)),
-                ),
-                "self_attn.q": SequenceParallel(),
-                "self_attn.k": SequenceParallel(),
-                "self_attn.v": SequenceParallel(),
-                "self_attn.norm_q": SequenceParallel(),
-                "self_attn.norm_k": SequenceParallel(),
-                "self_attn.attn": PrepareModuleInput(
-                    input_layouts=(Shard(1), Shard(1), Shard(1)),
-                    desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
-                ),
-                "self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),
-
-                "cross_attn": PrepareModuleInput(
-                    input_layouts=(Shard(1), Replicate()),
-                    desired_input_layouts=(Shard(1), Replicate()),
-                ),
-                "cross_attn.q": SequenceParallel(),
-                "cross_attn.k": SequenceParallel(),
-                "cross_attn.v": SequenceParallel(),
-                "cross_attn.norm_q": SequenceParallel(),
-                "cross_attn.norm_k": SequenceParallel(),
-                "cross_attn.attn": PrepareModuleInput(
-                    input_layouts=(Shard(1), Shard(1), Shard(1)),
-                    desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
-                ),
-                "cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),
-
-                "ffn.0": ColwiseParallel(input_layouts=Shard(1)),
-                "ffn.2": RowwiseParallel(output_layouts=Replicate()),
-
-                "norm1": SequenceParallel(use_local_output=True),
-                "norm2": SequenceParallel(use_local_output=True),
-                "norm3": SequenceParallel(use_local_output=True),
-                "gate": PrepareModuleInput(
-                    input_layouts=(Shard(1), Replicate(), Replicate()),
-                    desired_input_layouts=(Replicate(), Replicate(), Replicate()),
-                )
-            }
-            parallelize_module(
-                module=block,
-                device_mesh=tp_mesh,
-                parallelize_plan=layer_tp_plan,
-            )
-
-
-    def test_step(self, batch):
-        data = batch[0]
-        data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
-        output_path = data.pop("output_path")
-        with torch.no_grad(), torch.inference_mode(False):
-            video = self.pipe(**data)
-        if self.local_rank == 0:
-            save_video(video, output_path, fps=15, quality=5)
-
-
-if __name__ == "__main__":
-    snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
-    dataloader = torch.utils.data.DataLoader(
-        ToyDataset([
-            {
-                "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
-                "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-                "num_inference_steps": 50,
-                "seed": 0,
-                "tiled": False,
-                "output_path": "video1.mp4",
-            },
-            {
-                "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
-                "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-                "num_inference_steps": 50,
-                "seed": 1,
-                "tiled": False,
-                "output_path": "video2.mp4",
-            },
-        ]),
-        collate_fn=lambda x: x
-    )
-    model = LitModel()
-    trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy())
-    trainer.test(model, dataloader)
\ No newline at end of file
diff --git a/examples/wanvideo/wan_14b_text_to_video_usp.py b/examples/wanvideo/wan_14b_text_to_video_usp.py
deleted file mode 100644
index 8837294..0000000
--- a/examples/wanvideo/wan_14b_text_to_video_usp.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import torch
-from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
-from modelscope import snapshot_download
-import torch.distributed as dist
-
-
-# Download models
-snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
-
-# Load models
-model_manager = ModelManager(device="cpu")
-model_manager.load_models(
-    [
-        [
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
-        ],
-        "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
-        "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
-    ],
-    torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
-)
-
-dist.init_process_group(
-    backend="nccl",
-    init_method="env://",
-)
-from xfuser.core.distributed import (initialize_model_parallel,
-                                     init_distributed_environment)
-init_distributed_environment(
-    rank=dist.get_rank(), world_size=dist.get_world_size())
-
-initialize_model_parallel(
-    sequence_parallel_degree=dist.get_world_size(),
-    ring_degree=1,
-    ulysses_degree=dist.get_world_size(),
-)
-torch.cuda.set_device(dist.get_rank())
-
-pipe = WanVideoPipeline.from_model_manager(model_manager,
-                                           torch_dtype=torch.bfloat16,
-                                           device=f"cuda:{dist.get_rank()}",
-                                           use_usp=True if dist.get_world_size() > 1 else False)
-pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
-
-# Text-to-video
-video = pipe(
-    prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
-    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-    num_inference_steps=50,
-    seed=0, tiled=True
-)
-if dist.get_rank() == 0:
-    save_video(video, "video1.mp4", fps=25, quality=5)
diff --git a/examples/wanvideo/wan_fun_control.py b/examples/wanvideo/wan_fun_control.py
deleted file mode 100644
index e2c4d0c..0000000
--- a/examples/wanvideo/wan_fun_control.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import torch
-from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
-from modelscope import snapshot_download, dataset_snapshot_download
-from PIL import Image
-
-
-# Download models
-snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
-
-# Load models
-model_manager = ModelManager(device="cpu")
-model_manager.load_models(
-    [
-        "models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
-        "models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
-        "models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
-        "models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
-    ],
-    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
-)
-pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
-pipe.enable_vram_management(num_persistent_param_in_dit=None)
-
-# Download example video
-dataset_snapshot_download(
-    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
-    local_dir="./",
-    allow_file_pattern=f"data/examples/wan/control_video.mp4"
-)
-
-# Control-to-video
-control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
-video = pipe(
-    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
-    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-    num_inference_steps=50,
-    control_video=control_video, height=832, width=576, num_frames=49,
-    seed=1, tiled=True
-)
-save_video(video, "video1.mp4", fps=15, quality=5)
diff --git a/requirements.txt b/requirements.txt
index 63a871b..92d8b48 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 torch>=2.0.0
 torchvision
 cupy-cuda12x
-transformers==4.46.2
+transformers
 controlnet-aux==0.0.7
 imageio
 imageio[ffmpeg]
@@ -11,3 +11,4 @@ sentencepiece
 protobuf
 modelscope
 ftfy
+pynvml
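
# ---------------------------------------------------------------------------
# Note (not part of the diff): requirements.txt now pulls in `pynvml`, and the
# deleted examples hard-code `num_persistent_param_in_dit` (6*10**9 or None)
# with a comment that a smaller value reduces VRAM required. A minimal sketch
# of how the two could meet, assuming only pynvml's public NVML bindings; the
# helper name and the half-of-free-VRAM heuristic below are hypothetical, not
# DiffSynth-Studio API.
import pynvml

def suggest_persistent_params(device_index=0, bytes_per_param=2):
    # Query free VRAM through NVML (the bindings added in requirements.txt).
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    free_bytes = pynvml.nvmlDeviceGetMemoryInfo(handle).free
    pynvml.nvmlShutdown()
    # Heuristic (assumption): reserve half of free VRAM for activations and
    # the VAE; spend the rest on persistent DiT parameters (2 bytes in bf16).
    return (free_bytes // 2) // bytes_per_param

# Hypothetical usage with a pipeline like the ones above:
# pipe.enable_vram_management(num_persistent_param_in_dit=suggest_persistent_params())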