mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
@@ -58,7 +58,7 @@ from ..models.stepvideo_dit import StepVideoModel
|
|||||||
from ..models.wan_video_dit import WanModel
|
from ..models.wan_video_dit import WanModel
|
||||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||||
from ..models.wan_video_vae import WanVideoVAE
|
from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
|
||||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||||
from ..models.wan_video_vace import VaceWanModel
|
from ..models.wan_video_vace import VaceWanModel
|
||||||
|
|
||||||
@@ -141,6 +141,8 @@ model_loader_configs = [
|
|||||||
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||||
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||||
@@ -148,6 +150,7 @@ model_loader_configs = [
|
|||||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
|
(None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"),
|
||||||
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
||||||
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
||||||
(None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
|
(None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
|
||||||
|
|||||||
@@ -426,7 +426,7 @@ class ModelManager:
|
|||||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
|
||||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
|
||||||
fetched_models = []
|
fetched_models = []
|
||||||
fetched_model_paths = []
|
fetched_model_paths = []
|
||||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||||
@@ -440,12 +440,25 @@ class ModelManager:
|
|||||||
return None
|
return None
|
||||||
if len(fetched_models) == 1:
|
if len(fetched_models) == 1:
|
||||||
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
||||||
|
model = fetched_models[0]
|
||||||
|
path = fetched_model_paths[0]
|
||||||
else:
|
else:
|
||||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
if index is None:
|
||||||
|
model = fetched_models[0]
|
||||||
|
path = fetched_model_paths[0]
|
||||||
|
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||||
|
elif isinstance(index, int):
|
||||||
|
model = fetched_models[:index]
|
||||||
|
path = fetched_model_paths[:index]
|
||||||
|
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
|
||||||
|
else:
|
||||||
|
model = fetched_models
|
||||||
|
path = fetched_model_paths
|
||||||
|
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
|
||||||
if require_model_path:
|
if require_model_path:
|
||||||
return fetched_models[0], fetched_model_paths[0]
|
return model, path
|
||||||
else:
|
else:
|
||||||
return fetched_models[0]
|
return model
|
||||||
|
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
|
|||||||
@@ -212,9 +212,16 @@ class DiTBlock(nn.Module):
|
|||||||
self.gate = GateModule()
|
self.gate = GateModule()
|
||||||
|
|
||||||
def forward(self, x, context, t_mod, freqs):
|
def forward(self, x, context, t_mod, freqs):
|
||||||
|
has_seq = len(t_mod.shape) == 4
|
||||||
|
chunk_dim = 2 if has_seq else 1
|
||||||
# msa: multi-head self-attention mlp: multi-layer perceptron
|
# msa: multi-head self-attention mlp: multi-layer perceptron
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
|
||||||
|
if has_seq:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||||
|
shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
|
||||||
|
shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
|
||||||
|
)
|
||||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||||
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
||||||
x = x + self.cross_attn(self.norm3(x), context)
|
x = x + self.cross_attn(self.norm3(x), context)
|
||||||
@@ -253,8 +260,12 @@ class Head(nn.Module):
|
|||||||
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||||
|
|
||||||
def forward(self, x, t_mod):
    """Apply the learned shift/scale modulation to ``norm(x)`` and project it.

    Two layouts of ``t_mod`` are handled:

    * rank 3: ``t_mod`` gets an extra axis at position 2 and is added to
      ``self.modulation`` (with an extra leading axis); shift and scale are
      the two slices along axis 2, squeezed back afterwards.
    * any other rank: ``t_mod`` is added to ``self.modulation`` directly and
      shift and scale are the two slices along axis 1.

    Returns ``self.head(self.norm(x) * (1 + scale) + shift)``.
    """
    # Move the parameter to the dtype/device of the conditioning signal.
    base = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)

    if len(t_mod.shape) == 3:
        combined = base.unsqueeze(0) + t_mod.unsqueeze(2)
        shift, scale = combined.chunk(2, dim=2)
        shift = shift.squeeze(2)
        scale = scale.squeeze(2)
    else:
        shift, scale = (base + t_mod).chunk(2, dim=1)

    return self.head(self.norm(x) * (1 + scale) + shift)
|
||||||
|
|
||||||
|
|
||||||
@@ -276,12 +287,20 @@ class WanModel(torch.nn.Module):
|
|||||||
has_ref_conv: bool = False,
|
has_ref_conv: bool = False,
|
||||||
add_control_adapter: bool = False,
|
add_control_adapter: bool = False,
|
||||||
in_dim_control_adapter: int = 24,
|
in_dim_control_adapter: int = 24,
|
||||||
|
seperated_timestep: bool = False,
|
||||||
|
require_vae_embedding: bool = True,
|
||||||
|
require_clip_embedding: bool = True,
|
||||||
|
fuse_vae_embedding_in_latents: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.freq_dim = freq_dim
|
self.freq_dim = freq_dim
|
||||||
self.has_image_input = has_image_input
|
self.has_image_input = has_image_input
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.seperated_timestep = seperated_timestep
|
||||||
|
self.require_vae_embedding = require_vae_embedding
|
||||||
|
self.require_clip_embedding = require_clip_embedding
|
||||||
|
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||||
|
|
||||||
self.patch_embedding = nn.Conv3d(
|
self.patch_embedding = nn.Conv3d(
|
||||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||||
@@ -659,6 +678,41 @@ class WanModelStateDictConverter:
|
|||||||
"add_control_adapter": True,
|
"add_control_adapter": True,
|
||||||
"in_dim_control_adapter": 24,
|
"in_dim_control_adapter": 24,
|
||||||
}
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
|
||||||
|
# Wan-AI/Wan2.2-TI2V-5B
|
||||||
|
config = {
|
||||||
|
"has_image_input": False,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 48,
|
||||||
|
"dim": 3072,
|
||||||
|
"ffn_dim": 14336,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 48,
|
||||||
|
"num_heads": 24,
|
||||||
|
"num_layers": 30,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"seperated_timestep": True,
|
||||||
|
"require_clip_embedding": False,
|
||||||
|
"require_vae_embedding": False,
|
||||||
|
"fuse_vae_embedding_in_latents": True,
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
|
||||||
|
# Wan-AI/Wan2.2-I2V-A14B
|
||||||
|
config = {
|
||||||
|
"has_image_input": False,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 36,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"require_clip_embedding": False,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
config = {}
|
config = {}
|
||||||
return state_dict, config
|
return state_dict, config
|
||||||
|
|||||||
@@ -195,6 +195,75 @@ class Resample(nn.Module):
|
|||||||
nn.init.zeros_(conv.bias.data)
|
nn.init.zeros_(conv.bias.data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def patchify(x, patch_size):
    """Fold ``patch_size`` x ``patch_size`` spatial patches into the channel axis.

    Args:
        x: 4-D tensor laid out as ``b c h w`` or 5-D tensor laid out as
            ``b c f h w`` (axis names as used in the rearrange patterns).
        patch_size: edge length of the square patch; ``1`` is a no-op.

    Returns:
        ``x`` itself when ``patch_size == 1``; otherwise a tensor whose
        channel axis is multiplied by ``patch_size ** 2`` and whose two
        spatial axes are divided by ``patch_size``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D.
    """
    if patch_size == 1:
        return x

    rank = x.dim()
    if rank not in (4, 5):
        raise ValueError(f"Invalid input shape: {x.shape}")

    if rank == 4:
        pattern = "b c (h q) (w r) -> b (c r q) h w"
    else:
        pattern = "b c f (h q) (w r) -> b (c r q) f h w"
    return rearrange(x, pattern, q=patch_size, r=patch_size)
|
||||||
|
|
||||||
|
|
||||||
|
def unpatchify(x, patch_size):
    """Unfold channel-packed patches back onto the two spatial axes.

    Args:
        x: 4-D tensor laid out as ``b (c r q) h w`` or 5-D tensor laid out as
            ``b (c r q) f h w`` (axis names as used in the rearrange patterns).
        patch_size: edge length of the square patch; ``1`` is a no-op.

    Returns:
        ``x`` itself when ``patch_size == 1``; otherwise a tensor whose
        channel axis is divided by ``patch_size ** 2`` and whose two spatial
        axes are multiplied by ``patch_size``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D. Previously such input was
            returned unchanged, which hid shape errors.
    """
    if patch_size == 1:
        return x

    if x.dim() == 4:
        return rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size
        )
    if x.dim() == 5:
        return rearrange(
            x,
            "b (c r q) f h w -> b c f (h q) (w r)",
            q=patch_size,
            r=patch_size,
        )
    # Reject unsupported ranks explicitly instead of passing them through.
    raise ValueError(f"Invalid input shape: {x.shape}")
|
||||||
|
|
||||||
|
|
||||||
|
class Resample38(Resample):
    """Resample variant whose layers are built here instead of in ``Resample``.

    Only ``__init__`` is overridden; every other method is inherited from
    ``Resample``.
    """

    def __init__(self, dim, mode):
        """Create the resampling layers for ``mode``.

        Args:
            dim: channel count used for every convolution built here.
            mode: one of ``"none"``, ``"upsample2d"``, ``"upsample3d"``,
                ``"downsample2d"``, ``"downsample3d"``.
        """
        assert mode in (
            "none",
            "upsample2d",
            "upsample3d",
            "downsample2d",
            "downsample3d",
        )
        # super(Resample, self) starts the MRO lookup *after* Resample, so
        # Resample.__init__ is not executed and this method builds all layers.
        super(Resample, self).__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode in ("upsample2d", "upsample3d"):
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim, 3, padding=1),
            )
            if mode == "upsample3d":
                # The temporal conv emits twice as many channels as it takes in.
                self.time_conv = CausalConv3d(
                    dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
                )
        elif mode in ("downsample2d", "downsample3d"):
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)),
            )
            if mode == "downsample3d":
                self.time_conv = CausalConv3d(
                    dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
                )
        else:
            self.resample = nn.Identity()
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||||
@@ -273,6 +342,178 @@ class AttentionBlock(nn.Module):
|
|||||||
return x + identity
|
return x + identity
|
||||||
|
|
||||||
|
|
||||||
|
class AvgDown3D(nn.Module):
    """Parameter-free downsampling: fold time/space blocks into channels, then average.

    Each ``factor_t`` x ``factor_s`` x ``factor_s`` block of the input is moved
    into the channel axis, and the resulting ``in_channels * factor`` channels
    are averaged in consecutive groups to produce ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s

        # The folded channels must split evenly into out_channels groups.
        folded_channels = in_channels * self.factor
        assert folded_channels % out_channels == 0
        self.group_size = folded_channels // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample a 5-D tensor unpacked below as ``(B, C, T, H, W)``."""
        ft, fs = self.factor_t, self.factor_s

        # Pad the *front* of the time axis so T becomes a multiple of factor_t.
        front_pad = (ft - x.shape[2] % ft) % ft
        x = F.pad(x, (0, 0, 0, 0, front_pad, 0))

        batch, channels, frames, height, width = x.shape
        t_out, h_out, w_out = frames // ft, height // fs, width // fs

        # Split every axis into (coarse, within-block) and move the
        # within-block axes next to the channel axis.
        x = x.view(batch, channels, t_out, ft, h_out, fs, w_out, fs)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()

        # Group the folded channels and average inside each group.
        x = x.view(batch, self.out_channels, self.group_size, t_out, h_out, w_out)
        return x.mean(dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
class DupUp3D(nn.Module):
    """Parameter-free upsampling: duplicate channels, then unfold them into time/space.

    Every input channel is repeated ``repeats`` times and the repeated channels
    are redistributed into ``factor_t`` x ``factor_s`` x ``factor_s`` blocks of
    the output.
    """

    def __init__(self, in_channels: int, out_channels: int, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s

        # The unfolded output must be an integer multiple of the input channels.
        unfolded_channels = out_channels * self.factor
        assert unfolded_channels % in_channels == 0
        self.repeats = unfolded_channels // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        """Upsample a 5-D tensor.

        When ``first_chunk`` is true the first ``factor_t - 1`` output frames
        are dropped.
        """
        ft, fs = self.factor_t, self.factor_s

        x = x.repeat_interleave(self.repeats, dim=1)
        batch, frames, height, width = x.size(0), x.size(2), x.size(3), x.size(4)

        # Split channels into (out_channels, ft, fs, fs) and interleave each
        # factor with the axis it expands.
        x = x.view(batch, self.out_channels, ft, fs, fs, frames, height, width)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(batch, self.out_channels, frames * ft, height * fs, width * fs)

        if first_chunk:
            x = x[:, :, ft - 1:, :, :]
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Down_ResidualBlock(nn.Module):
    """Residual down block: a stack of ResidualBlocks plus an averaging shortcut.

    The main path is ``mult`` ResidualBlocks, optionally followed by a
    ``Resample38`` downsampler; the shortcut is an ``AvgDown3D`` applied to the
    block input. The two results are summed.
    """

    def __init__(
        self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False
    ):
        """Build the main path and the shortcut.

        Args:
            in_dim: channels entering the block.
            out_dim: channels leaving the block.
            dropout: dropout value forwarded to every ResidualBlock.
            mult: number of ResidualBlocks on the main path.
            temperal_downsample: if true the shortcut halves the time axis and,
                when ``down_flag`` is set, the main path uses "downsample3d".
            down_flag: if true the block downsamples spatially by 2.
        """
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path with residual blocks and downsample
        downsamples = []
        for _ in range(mult):
            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            downsamples.append(Resample38(out_dim, mode=mode))

        self.downsamples = nn.Sequential(*downsamples)

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Run the main path and add the averaged shortcut.

        Args:
            x: input tensor.
            feat_cache: optional cache forwarded unchanged to each sub-module.
            feat_idx: optional one-element list forwarded to each sub-module.
                Defaults to a fresh ``[0]`` per call so that no counter state
                is shared between calls through a mutable default.
        """
        if feat_idx is None:
            feat_idx = [0]

        # Keep an untouched copy for the shortcut branch.
        x_copy = x.clone()
        for module in self.downsamples:
            x = module(x, feat_cache, feat_idx)

        return x + self.avg_shortcut(x_copy)
|
||||||
|
|
||||||
|
|
||||||
|
class Up_ResidualBlock(nn.Module):
    """Residual up block: a stack of ResidualBlocks plus an optional duplicating shortcut.

    The main path is ``mult`` ResidualBlocks, optionally followed by a
    ``Resample38`` upsampler. When ``up_flag`` is set, a ``DupUp3D`` shortcut
    of the block input is added to the main-path output.
    """

    def __init__(
        self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False
    ):
        """Build the main path and the shortcut.

        Args:
            in_dim: channels entering the block.
            out_dim: channels leaving the block.
            dropout: dropout value forwarded to every ResidualBlock.
            mult: number of ResidualBlocks on the main path.
            temperal_upsample: if true (and ``up_flag`` is set) the block also
                doubles the time axis.
            up_flag: if true the block upsamples spatially by 2 and gets a
                shortcut; otherwise there is no shortcut.
        """
        super().__init__()

        # Shortcut path with upsample
        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,  # up_flag is known to be true in this branch
            )
        else:
            self.avg_shortcut = None

        # Main path with residual blocks and upsample
        upsamples = []
        for _ in range(mult):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final upsample block
        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            upsamples.append(Resample38(out_dim, mode=mode))

        self.upsamples = nn.Sequential(*upsamples)

    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
        """Run the main path and add the shortcut when one exists.

        Args:
            x: input tensor.
            feat_cache: optional cache forwarded unchanged to each sub-module.
            feat_idx: optional one-element list forwarded to each sub-module.
                Defaults to a fresh ``[0]`` per call so that no counter state
                is shared between calls through a mutable default.
            first_chunk: forwarded to the shortcut only.
        """
        if feat_idx is None:
            feat_idx = [0]

        x_main = x.clone()
        for module in self.upsamples:
            x_main = module(x_main, feat_cache, feat_idx)

        if self.avg_shortcut is None:
            return x_main
        return x_main + self.avg_shortcut(x, first_chunk)
|
||||||
|
|
||||||
|
|
||||||
class Encoder3d(nn.Module):
|
class Encoder3d(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -376,6 +617,122 @@ class Encoder3d(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder3d_38(nn.Module):
    """Encoder built from a causal input conv, Down_ResidualBlocks, a middle stack and a head.

    ``forward`` can run with a per-convolution feature cache (``feat_cache``)
    so that a clip can be processed in temporal chunks.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[False, True, True],
                 dropout=0.0):
        """Build the encoder.

        Args:
            dim: base channel width; stage widths are ``dim * dim_mult[i]``.
            z_dim: output channels of the head convolution.
            dim_mult: per-stage width multipliers.
            num_res_blocks: ResidualBlocks per down stage.
            attn_scales: stored on the instance; not otherwise used here.
            temperal_downsample: per-stage flags enabling temporal downsampling;
                stages beyond its length get ``False``.
            dropout: dropout value forwarded to the residual blocks.

        The list defaults are only read and stored by this constructor.
        """
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]

        # init block
        # NOTE(review): 12 input channels presumably = 3 image channels after a
        # 2x2 patchify -- confirm against the caller.
        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_down_flag = (
                temperal_downsample[i] if i < len(temperal_downsample) else False
            )
            downsamples.append(
                Down_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks,
                    temperal_downsample=t_down_flag,
                    # every stage but the last one downsamples spatially
                    down_flag=i != len(dim_mult) - 1,
                )
            )
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks (out_dim is the width of the last stage)
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout),
            AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout),
        )

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1),
        )

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run ``conv`` on ``x`` using and refreshing cache slot ``feat_idx[0]``."""
        idx = feat_idx[0]
        # Keep the trailing CACHE_T frames of the input for the next chunk.
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # Fewer than two frames available: prepend the last cached frame.
            cache_x = torch.cat(
                [
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                    cache_x,
                ],
                dim=2,
            )
        # The conv sees the *previous* cache entry; the slot is updated after.
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Encode ``x``.

        Args:
            x: input tensor, sliced here as a 5-D tensor with time on axis 2.
            feat_cache: optional indexable cache with one slot per cached
                convolution; ``None`` disables caching.
            feat_idx: optional one-element list holding the next cache slot;
                it is advanced in place. Defaults to a fresh ``[0]`` per call
                (the previous shared ``[0]`` default kept its increments
                between calls).
        """
        if feat_idx is None:
            feat_idx = [0]
        use_cache = feat_cache is not None

        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            x = layer(x, feat_cache, feat_idx) if use_cache else layer(x)

        ## middle
        for layer in self.middle:
            if use_cache and isinstance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if use_cache and isinstance(layer, CausalConv3d):
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)

        return x
|
||||||
|
|
||||||
|
|
||||||
class Decoder3d(nn.Module):
|
class Decoder3d(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -481,10 +838,112 @@ class Decoder3d(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder3d_38(nn.Module):
    """Decoder built from a causal input conv, a middle stack, Up_ResidualBlocks and a head.

    ``forward`` can run with a per-convolution feature cache (``feat_cache``)
    so that latents can be decoded in temporal chunks.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        """Build the decoder.

        Args:
            dim: base channel width; stage widths are ``dim * dim_mult[i]``
                visited in reverse order.
            z_dim: input channels of the first convolution.
            dim_mult: per-stage width multipliers.
            num_res_blocks: each up stage gets ``num_res_blocks + 1``
                ResidualBlocks.
            attn_scales: stored on the instance; not otherwise used here.
            temperal_upsample: per-stage flags enabling temporal upsampling;
                stages beyond its length get ``False``.
            dropout: dropout value forwarded to the residual blocks.

        The list defaults are only read and stored by this constructor.
        """
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout),
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
            upsamples.append(
                Up_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks + 1,
                    temperal_upsample=t_up_flag,
                    # every stage but the last one upsamples spatially
                    up_flag=i != len(dim_mult) - 1,
                )
            )
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks (out_dim is the width of the last stage)
        # NOTE(review): 12 output channels presumably = 3 image channels before
        # a 2x2 unpatchify -- confirm against the caller.
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, 12, 3, padding=1),
        )

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run ``conv`` on ``x`` using and refreshing cache slot ``feat_idx[0]``."""
        idx = feat_idx[0]
        # Keep the trailing CACHE_T frames of the input for the next chunk.
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # Fewer than two frames available: prepend the last cached frame.
            cache_x = torch.cat(
                [
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                    cache_x,
                ],
                dim=2,
            )
        # The conv sees the *previous* cache entry; the slot is updated after.
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
        """Decode ``x``.

        Args:
            x: input tensor, sliced here as a 5-D tensor with time on axis 2.
            feat_cache: optional indexable cache with one slot per cached
                convolution; ``None`` disables caching.
            feat_idx: optional one-element list holding the next cache slot;
                it is advanced in place. Defaults to a fresh ``[0]`` per call
                (the previous shared ``[0]`` default kept its increments
                between calls).
            first_chunk: forwarded to the up blocks, but only when a cache is
                in use.
        """
        if feat_idx is None:
            feat_idx = [0]
        use_cache = feat_cache is not None

        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if use_cache and check_is_instance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if use_cache:
                x = layer(x, feat_cache, feat_idx, first_chunk)
            else:
                # NOTE(review): first_chunk is not forwarded on the uncached
                # path; kept as in the original.
                x = layer(x)

        ## head
        for layer in self.head:
            if use_cache and check_is_instance(layer, CausalConv3d):
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)

        return x
|
||||||
|
|
||||||
|
|
||||||
def count_conv3d(model):
    """Return how many of the modules yielded by ``model.modules()`` are CausalConv3d."""
    return sum(1 for module in model.modules() if isinstance(module, CausalConv3d))
|
||||||
|
|
||||||
@@ -798,3 +1257,118 @@ class WanVideoVAEStateDictConverter:
|
|||||||
for name in state_dict:
|
for name in state_dict:
|
||||||
state_dict_['model.' + name] = state_dict[name]
|
state_dict_['model.' + name] = state_dict[name]
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
class VideoVAE38_(VideoVAE_):
    """VideoVAE_ variant wired with ``Encoder3d_38`` / ``Decoder3d_38`` and 2x2 patchify.

    ``clear_cache`` and the ``_feat_map`` / ``_enc_feat_map`` cache attributes
    used below are not defined in this class (presumably inherited from
    ``VideoVAE_`` -- confirm).
    """

    def __init__(self,
                 dim=160,
                 z_dim=48,
                 dec_dim=256,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[False, True, True],
                 dropout=0.0):
        # super(VideoVAE_, self) starts the MRO lookup *after* VideoVAE_, so
        # VideoVAE_.__init__ is not executed and the sub-modules are built here.
        super(VideoVAE_, self).__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        # The decoder mirrors the encoder's temporal schedule.
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d_38(
            dim, z_dim * 2, dim_mult, num_res_blocks,
            attn_scales, self.temperal_downsample, dropout,
        )
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d_38(
            dec_dim, z_dim, dim_mult, num_res_blocks,
            attn_scales, self.temperal_upsample, dropout,
        )

    def encode(self, x, scale):
        """Encode ``x`` chunk by chunk and return the normalized mean.

        The first chunk is the first frame alone; every following chunk is a
        slice of up to four frames. ``scale`` is a pair ``(shift, factor)``
        applied as ``(mu - shift) * factor``.
        """
        self.clear_cache()
        x = patchify(x, patch_size=2)
        num_frames = x.shape[2]
        num_chunks = 1 + (num_frames - 1) // 4

        for chunk_id in range(num_chunks):
            # Restart the cache slot counter for every chunk.
            self._enc_conv_idx = [0]
            if chunk_id == 0:
                frames = x[:, :, :1, :, :]
            else:
                start = 1 + 4 * (chunk_id - 1)
                frames = x[:, :, start:start + 4, :, :]
            encoded = self.encoder(
                frames,
                feat_cache=self._enc_feat_map,
                feat_idx=self._enc_conv_idx,
            )
            out = encoded if chunk_id == 0 else torch.cat([out, encoded], 2)

        # Only the first half of the channels (mu) is used.
        mu, _ = self.conv1(out).chunk(2, dim=1)

        if isinstance(scale[0], torch.Tensor):
            scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
            shift = scale[0].view(1, self.z_dim, 1, 1, 1)
            factor = scale[1].view(1, self.z_dim, 1, 1, 1)
            mu = (mu - shift) * factor
        else:
            scale = scale.to(dtype=mu.dtype, device=mu.device)
            mu = (mu - scale[0]) * scale[1]

        self.clear_cache()
        return mu

    def decode(self, z, scale):
        """Undo the latent normalization and decode ``z`` one latent frame at a time.

        ``scale`` is a pair ``(shift, factor)`` applied as
        ``z / factor + shift``.
        """
        self.clear_cache()

        if isinstance(scale[0], torch.Tensor):
            scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
            shift = scale[0].view(1, self.z_dim, 1, 1, 1)
            factor = scale[1].view(1, self.z_dim, 1, 1, 1)
            z = z / factor + shift
        else:
            scale = scale.to(dtype=z.dtype, device=z.device)
            z = z / scale[1] + scale[0]

        num_latent_frames = z.shape[2]
        x = self.conv2(z)

        for frame_id in range(num_latent_frames):
            # Restart the cache slot counter for every latent frame.
            self._conv_idx = [0]
            latent = x[:, :, frame_id:frame_id + 1, :, :]
            if frame_id == 0:
                out = self.decoder(
                    latent,
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=True,
                )
            else:
                decoded = self.decoder(
                    latent,
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                )
                out = torch.cat([out, decoded], 2)

        out = unpatchify(out, patch_size=2)
        self.clear_cache()
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class WanVideoVAE38(WanVideoVAE):
    """WanVideoVAE variant backed by ``VideoVAE38_`` with 48 latent statistics."""

    def __init__(self, z_dim=48, dim=160):
        # super(WanVideoVAE, self) starts the MRO lookup *after* WanVideoVAE,
        # so WanVideoVAE.__init__ is not executed.
        super(WanVideoVAE, self).__init__()

        # Latent statistics, 48 entries each.
        # NOTE(review): presumably one entry per latent channel, so a z_dim
        # other than 48 would not match these tables -- confirm.
        latent_mean = [
            -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
            -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
            -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
            -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
            -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
            0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
        ]
        latent_std = [
            0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
            0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
            0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
            0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
            0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
            0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
        ]
        self.mean = torch.tensor(latent_mean)
        self.std = torch.tensor(latent_std)
        # Stored as (shift, multiplicative factor).
        self.scale = [self.mean, 1.0 / self.std]

        # init model: inference mode, gradients disabled
        vae = VideoVAE38_(z_dim=z_dim, dim=dim)
        self.model = vae.eval().requires_grad_(False)
        self.upsampling_factor = 16
|
||||||
|
|||||||
@@ -39,17 +39,21 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.text_encoder: WanTextEncoder = None
|
self.text_encoder: WanTextEncoder = None
|
||||||
self.image_encoder: WanImageEncoder = None
|
self.image_encoder: WanImageEncoder = None
|
||||||
self.dit: WanModel = None
|
self.dit: WanModel = None
|
||||||
|
self.dit2: WanModel = None
|
||||||
self.vae: WanVideoVAE = None
|
self.vae: WanVideoVAE = None
|
||||||
self.motion_controller: WanMotionControllerModel = None
|
self.motion_controller: WanMotionControllerModel = None
|
||||||
self.vace: VaceWanModel = None
|
self.vace: VaceWanModel = None
|
||||||
self.in_iteration_models = ("dit", "motion_controller", "vace")
|
self.in_iteration_models = ("dit", "motion_controller", "vace")
|
||||||
|
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
|
||||||
self.unit_runner = PipelineUnitRunner()
|
self.unit_runner = PipelineUnitRunner()
|
||||||
self.units = [
|
self.units = [
|
||||||
WanVideoUnit_ShapeChecker(),
|
WanVideoUnit_ShapeChecker(),
|
||||||
WanVideoUnit_NoiseInitializer(),
|
WanVideoUnit_NoiseInitializer(),
|
||||||
WanVideoUnit_InputVideoEmbedder(),
|
WanVideoUnit_InputVideoEmbedder(),
|
||||||
WanVideoUnit_PromptEmbedder(),
|
WanVideoUnit_PromptEmbedder(),
|
||||||
WanVideoUnit_ImageEmbedder(),
|
WanVideoUnit_ImageEmbedderVAE(),
|
||||||
|
WanVideoUnit_ImageEmbedderCLIP(),
|
||||||
|
WanVideoUnit_ImageEmbedderFused(),
|
||||||
WanVideoUnit_FunControl(),
|
WanVideoUnit_FunControl(),
|
||||||
WanVideoUnit_FunReference(),
|
WanVideoUnit_FunReference(),
|
||||||
WanVideoUnit_FunCameraControl(),
|
WanVideoUnit_FunCameraControl(),
|
||||||
@@ -69,7 +73,9 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
def training_loss(self, **inputs):
|
def training_loss(self, **inputs):
|
||||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
|
||||||
|
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
|
||||||
|
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
||||||
@@ -141,6 +147,37 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
),
|
),
|
||||||
vram_limit=vram_limit,
|
vram_limit=vram_limit,
|
||||||
)
|
)
|
||||||
|
if self.dit2 is not None:
|
||||||
|
dtype = next(iter(self.dit2.parameters())).dtype
|
||||||
|
device = "cpu" if vram_limit is not None else self.device
|
||||||
|
enable_vram_management(
|
||||||
|
self.dit2,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Conv3d: AutoWrappedModule,
|
||||||
|
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||||
|
RMSNorm: AutoWrappedModule,
|
||||||
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device=device,
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
max_num_param=num_persistent_param_in_dit,
|
||||||
|
overflow_module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
vram_limit=vram_limit,
|
||||||
|
)
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
dtype = next(iter(self.vae.parameters())).dtype
|
dtype = next(iter(self.vae.parameters())).dtype
|
||||||
enable_vram_management(
|
enable_vram_management(
|
||||||
@@ -239,6 +276,10 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
for block in self.dit.blocks:
|
for block in self.dit.blocks:
|
||||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
|
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
|
||||||
|
if self.dit2 is not None:
|
||||||
|
for block in self.dit2.blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
|
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
||||||
self.sp_size = get_sequence_parallel_world_size()
|
self.sp_size = get_sequence_parallel_world_size()
|
||||||
self.use_unified_sequence_parallel = True
|
self.use_unified_sequence_parallel = True
|
||||||
|
|
||||||
@@ -283,10 +324,18 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
# Load models
|
# Load models
|
||||||
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
||||||
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
||||||
|
num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"])
|
||||||
|
if num_dits == 2:
|
||||||
|
pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1]
|
||||||
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
||||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||||
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
||||||
|
|
||||||
|
# Size division factor
|
||||||
|
if pipe.vae is not None:
|
||||||
|
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
||||||
|
pipe.width_division_factor = pipe.vae.upsampling_factor * 2
|
||||||
|
|
||||||
# Initialize tokenizer
|
# Initialize tokenizer
|
||||||
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
||||||
@@ -333,6 +382,8 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
cfg_scale: Optional[float] = 5.0,
|
cfg_scale: Optional[float] = 5.0,
|
||||||
cfg_merge: Optional[bool] = False,
|
cfg_merge: Optional[bool] = False,
|
||||||
|
# Boundary
|
||||||
|
switch_DiT_boundary: Optional[float] = 0.875,
|
||||||
# Scheduler
|
# Scheduler
|
||||||
num_inference_steps: Optional[int] = 50,
|
num_inference_steps: Optional[int] = 50,
|
||||||
sigma_shift: Optional[float] = 5.0,
|
sigma_shift: Optional[float] = 5.0,
|
||||||
@@ -385,8 +436,14 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.load_models_to_device(self.in_iteration_models)
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
# Switch DiT if necessary
|
||||||
|
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
|
||||||
|
self.load_models_to_device(self.in_iteration_models_2)
|
||||||
|
models["dit"] = self.dit2
|
||||||
|
|
||||||
|
# Timestep
|
||||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
|
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
@@ -400,6 +457,8 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
|
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
|
||||||
|
if "first_frame_latents" in inputs_shared:
|
||||||
|
inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
|
||||||
|
|
||||||
# VACE (TODO: remove it)
|
# VACE (TODO: remove it)
|
||||||
if vace_reference_image is not None:
|
if vace_reference_image is not None:
|
||||||
@@ -433,7 +492,8 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
|
|||||||
length = (num_frames - 1) // 4 + 1
|
length = (num_frames - 1) // 4 + 1
|
||||||
if vace_reference_image is not None:
|
if vace_reference_image is not None:
|
||||||
length += 1
|
length += 1
|
||||||
noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device)
|
shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
|
||||||
|
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
|
||||||
if vace_reference_image is not None:
|
if vace_reference_image is not None:
|
||||||
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
|
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
|
||||||
return {"noise": noise}
|
return {"noise": noise}
|
||||||
@@ -482,6 +542,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit):
|
|||||||
|
|
||||||
|
|
||||||
class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
||||||
|
"""
|
||||||
|
Deprecated
|
||||||
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
|
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||||
@@ -489,7 +552,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
|
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
|
||||||
if input_image is None:
|
if input_image is None or pipe.image_encoder is None:
|
||||||
return {}
|
return {}
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||||
@@ -517,13 +580,90 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
|||||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
return {"clip_feature": clip_context, "y": y}
|
return {"clip_feature": clip_context, "y": y}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
    """Pipeline unit that turns the input image into ``clip_feature`` with ``pipe.image_encoder``."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "height", "width"),
            onload_model_names=("image_encoder",),
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
        """Return ``{"clip_feature": ...}``, or ``{}`` when the unit does not apply.

        The unit is skipped when there is no input image, no image encoder is
        loaded, or the DiT does not require a CLIP embedding.
        """
        if input_image is None:
            return {}
        if pipe.image_encoder is None:
            return {}
        if not pipe.dit.require_clip_embedding:
            return {}

        pipe.load_models_to_device(self.onload_model_names)
        target_size = (width, height)

        first_frame = pipe.preprocess_image(input_image.resize(target_size)).to(pipe.device)
        features = pipe.image_encoder.encode_image([first_frame])

        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize(target_size)).to(pipe.device)
            # The end-frame embedding is only appended when the DiT reports
            # ``has_image_pos_emb``.
            if pipe.dit.has_image_pos_emb:
                end_features = pipe.image_encoder.encode_image([end_image])
                features = torch.concat([features, end_features], dim=1)

        features = features.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": features}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
    """Pipeline unit that builds the ``y`` conditioning from the input image via the VAE.

    ``y`` is the frame mask stacked on top of the VAE latents of a video whose
    first frame (and optionally last frame) is the given image and whose other
    frames are zeros.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        """Return ``{"y": ...}``, or ``{}`` when there is no image or the DiT does not require it."""
        if input_image is None or not pipe.dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)

        # The mask is concatenated with the encoded latents below, so it has to
        # use the VAE's own spatial reduction factor. Falls back to the former
        # hard-coded 8 when the VAE does not expose ``upsampling_factor``.
        spatial_factor = getattr(pipe.vae, "upsampling_factor", 8)
        mask_h, mask_w = height // spatial_factor, width // spatial_factor

        # 1 marks frames that are provided (first, and optionally last); 0 marks the rest.
        msk = torch.ones(1, num_frames, mask_h, mask_w, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
            msk[:, -1:] = 1
        else:
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

        # Repeat the first frame's mask 4 times, then fold the frame axis into
        # groups of 4 so the mask becomes 4 channels per folded frame.
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, mask_h, mask_w)
        msk = msk.transpose(1, 2)[0]

        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        # Cast again: ``msk`` was created without an explicit dtype, so the
        # concatenation above may not be in ``pipe.torch_dtype``.
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"y": y}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
    """
    Encode the input image with the VAE and write it into the first latent frame.
    This unit is for Wan-AI/Wan2.2-TI2V-5B.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
        """Overwrite ``latents[:, :, 0:1]`` with the encoded image and return the updated inputs.

        Returns ``{}`` when there is no input image or the DiT does not fuse
        the VAE embedding into the latents.
        """
        if input_image is None:
            return {}
        if not pipe.dit.fuse_vae_embedding_in_latents:
            return {}

        pipe.load_models_to_device(self.onload_model_names)

        resized = input_image.resize((width, height))
        image = pipe.preprocess_image(resized).transpose(0, 1)
        z = pipe.vae.encode(
            [image],
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )

        # In-place write: the caller's ``latents`` tensor is modified.
        latents[:, :, 0: 1] = z
        return {
            "latents": latents,
            "fuse_vae_embedding_in_latents": True,
            "first_frame_latents": z,
        }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WanVideoUnit_FunControl(PipelineUnit):
|
class WanVideoUnit_FunControl(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
|
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
|
||||||
onload_model_names=("vae")
|
onload_model_names=("vae",)
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
|
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
|
||||||
@@ -547,7 +687,7 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("reference_image", "height", "width", "reference_image"),
|
input_params=("reference_image", "height", "width", "reference_image"),
|
||||||
onload_model_names=("vae")
|
onload_model_names=("vae",)
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
|
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
|
||||||
@@ -832,6 +972,7 @@ def model_fn_wan_video(
|
|||||||
use_gradient_checkpointing: bool = False,
|
use_gradient_checkpointing: bool = False,
|
||||||
use_gradient_checkpointing_offload: bool = False,
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
control_camera_latents_input = None,
|
control_camera_latents_input = None,
|
||||||
|
fuse_vae_embedding_in_latents: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if sliding_window_size is not None and sliding_window_stride is not None:
|
if sliding_window_size is not None and sliding_window_stride is not None:
|
||||||
@@ -865,9 +1006,20 @@ def model_fn_wan_video(
|
|||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||||
get_sequence_parallel_world_size,
|
get_sequence_parallel_world_size,
|
||||||
get_sp_group)
|
get_sp_group)
|
||||||
|
|
||||||
|
# Timestep
|
||||||
|
if dit.seperated_timestep and fuse_vae_embedding_in_latents:
|
||||||
|
timestep = torch.concat([
|
||||||
|
torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
|
||||||
|
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
|
||||||
|
]).flatten()
|
||||||
|
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
|
||||||
|
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
|
||||||
|
else:
|
||||||
|
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||||
|
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||||
|
|
||||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
# Motion Controller
|
||||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
|
||||||
if motion_bucket_id is not None and motion_controller is not None:
|
if motion_bucket_id is not None and motion_controller is not None:
|
||||||
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
||||||
context = dit.text_embedding(context)
|
context = dit.text_embedding(context)
|
||||||
@@ -878,15 +1030,16 @@ def model_fn_wan_video(
|
|||||||
x = torch.concat([x] * context.shape[0], dim=0)
|
x = torch.concat([x] * context.shape[0], dim=0)
|
||||||
if timestep.shape[0] != context.shape[0]:
|
if timestep.shape[0] != context.shape[0]:
|
||||||
timestep = torch.concat([timestep] * context.shape[0], dim=0)
|
timestep = torch.concat([timestep] * context.shape[0], dim=0)
|
||||||
|
|
||||||
if dit.has_image_input:
|
# Image Embedding
|
||||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
if y is not None and dit.require_vae_embedding:
|
||||||
|
x = torch.cat([x, y], dim=1)
|
||||||
|
if clip_feature is not None and dit.require_clip_embedding:
|
||||||
clip_embdding = dit.img_emb(clip_feature)
|
clip_embdding = dit.img_emb(clip_feature)
|
||||||
context = torch.cat([clip_embdding, context], dim=1)
|
context = torch.cat([clip_embdding, context], dim=1)
|
||||||
|
|
||||||
# Add camera control
|
# Add camera control
|
||||||
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
|
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
|
||||||
|
|
||||||
|
|
||||||
# Reference image
|
# Reference image
|
||||||
if reference_latents is not None:
|
if reference_latents is not None:
|
||||||
|
|||||||
@@ -434,6 +434,8 @@ def wan_parser():
|
|||||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||||
|
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||||
|
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
|||||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||||
|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||||
|
|
||||||
|
|
||||||
## Model Inference
|
## Model Inference
|
||||||
|
|||||||
@@ -67,6 +67,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
|||||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||||
|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||||
|
|
||||||
## 模型推理
|
## 模型推理
|
||||||
|
|
||||||
|
|||||||
32
examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py
Normal file
32
examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
|
||||||
|
)
|
||||||
|
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480))
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
seed=0, tiled=True,
|
||||||
|
input_image=input_image,
|
||||||
|
)
|
||||||
|
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||||
24
examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py
Normal file
24
examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import save_video
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
# Text-to-video
|
||||||
|
video = pipe(
|
||||||
|
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
seed=0, tiled=True,
|
||||||
|
)
|
||||||
|
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||||
43
examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py
Normal file
43
examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
# Text-to-video
|
||||||
|
video = pipe(
|
||||||
|
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
seed=0, tiled=False,
|
||||||
|
height=704, width=1248,
|
||||||
|
num_frames=121,
|
||||||
|
)
|
||||||
|
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||||
|
|
||||||
|
# Image-to-video
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
|
||||||
|
)
|
||||||
|
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704))
|
||||||
|
video = pipe(
|
||||||
|
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
seed=0, tiled=False,
|
||||||
|
height=704, width=1248,
|
||||||
|
input_image=input_image,
|
||||||
|
num_frames=121,
|
||||||
|
)
|
||||||
|
save_video(video, "video2.mp4", fps=15, quality=5)
|
||||||
35
examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh
Normal file
35
examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--extra_inputs "input_image" \
|
||||||
|
--use_gradient_checkpointing_offload \
|
||||||
|
--max_timestep_boundary 1 \
|
||||||
|
--min_timestep_boundary 0.875
|
||||||
|
|
||||||
|
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--extra_inputs "input_image" \
|
||||||
|
--use_gradient_checkpointing_offload \
|
||||||
|
--max_timestep_boundary 0.875 \
|
||||||
|
--min_timestep_boundary 0
|
||||||
31
examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh
Normal file
31
examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--max_timestep_boundary 1 \
|
||||||
|
--min_timestep_boundary 0.875
|
||||||
|
|
||||||
|
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--max_timestep_boundary 0.875 \
|
||||||
|
--min_timestep_boundary 0
|
||||||
14
examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh
Normal file
14
examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
accelerate launch examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-TI2V-5B_full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--extra_inputs "input_image"
|
||||||
37
examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh
Normal file
37
examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
accelerate launch examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--extra_inputs "input_image" \
|
||||||
|
--max_timestep_boundary 1 \
|
||||||
|
--min_timestep_boundary 0.875
|
||||||
|
|
||||||
|
accelerate launch examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--extra_inputs "input_image" \
|
||||||
|
--max_timestep_boundary 0.875 \
|
||||||
|
--min_timestep_boundary 0
|
||||||
36
examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh
Normal file
36
examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
accelerate launch examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--max_timestep_boundary 1 \
|
||||||
|
--min_timestep_boundary 0.875
|
||||||
|
|
||||||
|
|
||||||
|
accelerate launch examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--max_timestep_boundary 0.875 \
|
||||||
|
--min_timestep_boundary 0
|
||||||
16
examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh
Normal file
16
examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
accelerate launch examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.2-TI2V-5B_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--extra_inputs "input_image"
|
||||||
@@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
|||||||
use_gradient_checkpointing=True,
|
use_gradient_checkpointing=True,
|
||||||
use_gradient_checkpointing_offload=False,
|
use_gradient_checkpointing_offload=False,
|
||||||
extra_inputs=None,
|
extra_inputs=None,
|
||||||
|
max_timestep_boundary=1.0,
|
||||||
|
min_timestep_boundary=0.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Load models
|
# Load models
|
||||||
@@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
|||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
self.max_timestep_boundary = max_timestep_boundary
|
||||||
|
self.min_timestep_boundary = min_timestep_boundary
|
||||||
|
|
||||||
|
|
||||||
def forward_preprocess(self, data):
|
def forward_preprocess(self, data):
|
||||||
@@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
|||||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
"cfg_merge": False,
|
"cfg_merge": False,
|
||||||
"vace_scale": 1,
|
"vace_scale": 1,
|
||||||
|
"max_timestep_boundary": self.max_timestep_boundary,
|
||||||
|
"min_timestep_boundary": self.min_timestep_boundary,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Extra inputs
|
# Extra inputs
|
||||||
@@ -106,6 +112,8 @@ if __name__ == "__main__":
|
|||||||
lora_rank=args.lora_rank,
|
lora_rank=args.lora_rank,
|
||||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
extra_inputs=args.extra_inputs,
|
extra_inputs=args.extra_inputs,
|
||||||
|
max_timestep_boundary=args.max_timestep_boundary,
|
||||||
|
min_timestep_boundary=args.min_timestep_boundary,
|
||||||
)
|
)
|
||||||
model_logger = ModelLogger(
|
model_logger = ModelLogger(
|
||||||
args.output_path,
|
args.output_path,
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video, VideoData, load_state_dict
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_high_noise_full/epoch-1.safetensors")
|
||||||
|
pipe.dit.load_state_dict(state_dict)
|
||||||
|
state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_low_noise_full/epoch-1.safetensors")
|
||||||
|
pipe.dit2.load_state_dict(state_dict)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="from sunset to night, a small town, light, house, river",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
input_image=input_image,
|
||||||
|
num_frames=49,
|
||||||
|
seed=1, tiled=False,
|
||||||
|
)
|
||||||
|
save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video, VideoData, load_state_dict
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_high_noise_full/epoch-1.safetensors")
|
||||||
|
pipe.dit.load_state_dict(state_dict)
|
||||||
|
state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_low_noise_full/epoch-1.safetensors")
|
||||||
|
pipe.dit2.load_state_dict(state_dict)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="from sunset to night, a small town, light, house, river",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
seed=1, tiled=True
|
||||||
|
)
|
||||||
|
save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video, VideoData, load_state_dict
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state_dict = load_state_dict("models/train/Wan2.2-TI2V-5B_full/epoch-1.safetensors")
|
||||||
|
pipe.dit.load_state_dict(state_dict)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="from sunset to night, a small town, light, house, river",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
input_image=input_image,
|
||||||
|
num_frames=49,
|
||||||
|
seed=1, tiled=False,
|
||||||
|
)
|
||||||
|
save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video, VideoData
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, "models/train/Wan2.2-I2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
|
||||||
|
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-I2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="from sunset to night, a small town, light, house, river",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
input_image=input_image,
|
||||||
|
num_frames=49,
|
||||||
|
seed=1, tiled=False,
|
||||||
|
)
|
||||||
|
save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video, VideoData
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, "models/train/Wan2.2-T2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
|
||||||
|
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-T2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="from sunset to night, a small town, light, house, river",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
num_frames=49,
|
||||||
|
seed=1, tiled=True
|
||||||
|
)
|
||||||
|
save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video, VideoData, load_state_dict
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, "models/train/Wan2.2-TI2V-5B_lora/epoch-4.safetensors", alpha=1)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="from sunset to night, a small town, light, house, river",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
input_image=input_image,
|
||||||
|
num_frames=49,
|
||||||
|
seed=1, tiled=False,
|
||||||
|
)
|
||||||
|
save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)
|
||||||
Reference in New Issue
Block a user