Merge pull request #704 from modelscope/wan2.2

Wan2.2
This commit is contained in:
Zhongjie Duan
2025-07-28 15:06:01 +08:00
committed by GitHub
24 changed files with 1283 additions and 23 deletions

View File

@@ -58,7 +58,7 @@ from ..models.stepvideo_dit import StepVideoModel
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_vace import VaceWanModel
@@ -141,6 +141,8 @@ model_loader_configs = [
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
@@ -148,6 +150,7 @@ model_loader_configs = [
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"),
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
(None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),

View File

@@ -426,7 +426,7 @@ class ModelManager:
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
@@ -440,12 +440,25 @@ class ModelManager:
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}.")
model = fetched_models[0]
path = fetched_model_paths[0]
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
if index is None:
model = fetched_models[0]
path = fetched_model_paths[0]
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
elif isinstance(index, int):
model = fetched_models[:index]
path = fetched_model_paths[:index]
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
else:
model = fetched_models
path = fetched_model_paths
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
return model, path
else:
return fetched_models[0]
return model
def to(self, device):

View File

@@ -212,9 +212,16 @@ class DiTBlock(nn.Module):
self.gate = GateModule()
def forward(self, x, context, t_mod, freqs):
has_seq = len(t_mod.shape) == 4
chunk_dim = 2 if has_seq else 1
# msa: multi-head self-attention mlp: multi-layer perceptron
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
if has_seq:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
)
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
x = x + self.cross_attn(self.norm3(x), context)
@@ -253,8 +260,12 @@ class Head(nn.Module):
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, t_mod):
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + scale) + shift))
if len(t_mod.shape) == 3:
shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)
x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
else:
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + scale) + shift))
return x
@@ -276,12 +287,20 @@ class WanModel(torch.nn.Module):
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
seperated_timestep: bool = False,
require_vae_embedding: bool = True,
require_clip_embedding: bool = True,
fuse_vae_embedding_in_latents: bool = False,
):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.has_image_input = has_image_input
self.patch_size = patch_size
self.seperated_timestep = seperated_timestep
self.require_vae_embedding = require_vae_embedding
self.require_clip_embedding = require_clip_embedding
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
@@ -659,6 +678,41 @@ class WanModelStateDictConverter:
"add_control_adapter": True,
"in_dim_control_adapter": 24,
}
elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
# Wan-AI/Wan2.2-TI2V-5B
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 3072,
"ffn_dim": 14336,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 48,
"num_heads": 24,
"num_layers": 30,
"eps": 1e-6,
"seperated_timestep": True,
"require_clip_embedding": False,
"require_vae_embedding": False,
"fuse_vae_embedding_in_latents": True,
}
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
# Wan-AI/Wan2.2-I2V-A14B
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"require_clip_embedding": False,
}
else:
config = {}
return state_dict, config

View File

@@ -195,6 +195,75 @@ class Resample(nn.Module):
nn.init.zeros_(conv.bias.data)
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size)
return x
class Resample38(Resample):
def __init__(self, dim, mode):
assert mode in (
"none",
"upsample2d",
"upsample3d",
"downsample2d",
"downsample3d",
)
super(Resample, self).__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim, 3, padding=1),
)
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else:
self.resample = nn.Identity()
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
@@ -273,6 +342,178 @@ class AttentionBlock(nn.Module):
return x + identity
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1 :, :, :]
return x
class Down_ResidualBlock(nn.Module):
def __init__(
self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False
):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
downsamples = []
for _ in range(mult):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
downsamples.append(Resample38(out_dim, mode=mode))
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class Up_ResidualBlock(nn.Module):
def __init__(
self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False
):
super().__init__()
# Shortcut path with upsample
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
else:
self.avg_shortcut = None
# Main path with residual blocks and upsample
upsamples = []
for _ in range(mult):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final upsample block
if up_flag:
mode = "upsample3d" if temperal_upsample else "upsample2d"
upsamples.append(Resample38(out_dim, mode=mode))
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x.clone()
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
class Encoder3d(nn.Module):
def __init__(self,
@@ -376,6 +617,122 @@ class Encoder3d(nn.Module):
return x
class Encoder3d_38(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_down_flag = (
temperal_downsample[i] if i < len(temperal_downsample) else False
)
downsamples.append(
Down_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks,
temperal_downsample=t_down_flag,
down_flag=i != len(dim_mult) - 1,
)
)
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout),
)
# # output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
@@ -481,10 +838,112 @@ class Decoder3d(nn.Module):
return x
class Decoder3d_38(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
upsamples.append(
Up_ResidualBlock(in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks + 1,
temperal_upsample=t_up_flag,
up_flag=i != len(dim_mult) - 1))
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 12, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if check_is_instance(m, CausalConv3d):
if isinstance(m, CausalConv3d):
count += 1
return count
@@ -798,3 +1257,118 @@ class WanVideoVAEStateDictConverter:
for name in state_dict:
state_dict_['model.' + name] = state_dict[name]
return state_dict_
class VideoVAE38_(VideoVAE_):
def __init__(self,
dim=160,
z_dim=48,
dec_dim=256,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0):
super(VideoVAE_, self).__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x, scale):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=mu.dtype, device=mu.device)
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
def decode(self, z, scale):
self.clear_cache()
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
class WanVideoVAE38(WanVideoVAE):
def __init__(self, z_dim=48, dim=160):
super(WanVideoVAE, self).__init__()
mean = [
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667
]
std = [
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]
self.mean = torch.tensor(mean)
self.std = torch.tensor(std)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
self.upsampling_factor = 16

View File

@@ -39,17 +39,21 @@ class WanVideoPipeline(BasePipeline):
self.text_encoder: WanTextEncoder = None
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.dit2: WanModel = None
self.vae: WanVideoVAE = None
self.motion_controller: WanMotionControllerModel = None
self.vace: VaceWanModel = None
self.in_iteration_models = ("dit", "motion_controller", "vace")
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
self.unit_runner = PipelineUnitRunner()
self.units = [
WanVideoUnit_ShapeChecker(),
WanVideoUnit_NoiseInitializer(),
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
WanVideoUnit_ImageEmbedderVAE(),
WanVideoUnit_ImageEmbedderCLIP(),
WanVideoUnit_ImageEmbedderFused(),
WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(),
WanVideoUnit_FunCameraControl(),
@@ -69,7 +73,9 @@ class WanVideoPipeline(BasePipeline):
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
@@ -141,6 +147,37 @@ class WanVideoPipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.dit2 is not None:
dtype = next(iter(self.dit2.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: WanAutoCastLayerNorm,
RMSNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
@@ -239,6 +276,10 @@ class WanVideoPipeline(BasePipeline):
for block in self.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
if self.dit2 is not None:
for block in self.dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@@ -283,10 +324,18 @@ class WanVideoPipeline(BasePipeline):
# Load models
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
pipe.dit = model_manager.fetch_model("wan_video_dit")
num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"])
if num_dits == 2:
pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1]
pipe.vae = model_manager.fetch_model("wan_video_vae")
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
pipe.vace = model_manager.fetch_model("wan_video_vace")
# Size division factor
if pipe.vae is not None:
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
pipe.width_division_factor = pipe.vae.upsampling_factor * 2
# Initialize tokenizer
tokenizer_config.download_if_necessary(use_usp=use_usp)
@@ -333,6 +382,8 @@ class WanVideoPipeline(BasePipeline):
# Classifier-free guidance
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Boundary
switch_DiT_boundary: Optional[float] = 0.875,
# Scheduler
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
@@ -385,8 +436,14 @@ class WanVideoPipeline(BasePipeline):
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# Switch DiT if necessary
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
self.load_models_to_device(self.in_iteration_models_2)
models["dit"] = self.dit2
# Timestep
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
if cfg_scale != 1.0:
@@ -400,6 +457,8 @@ class WanVideoPipeline(BasePipeline):
# Scheduler
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
if "first_frame_latents" in inputs_shared:
inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
# VACE (TODO: remove it)
if vace_reference_image is not None:
@@ -433,7 +492,8 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
length = (num_frames - 1) // 4 + 1
if vace_reference_image is not None:
length += 1
noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device)
shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
if vace_reference_image is not None:
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
return {"noise": noise}
@@ -482,6 +542,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit):
class WanVideoUnit_ImageEmbedder(PipelineUnit):
"""
Deprecated
"""
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
@@ -489,7 +552,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None:
if input_image is None or pipe.image_encoder is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
@@ -517,13 +580,90 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context, "y": y}
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "height", "width"),
onload_model_names=("image_encoder",)
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
clip_context = pipe.image_encoder.encode_image([image])
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
if pipe.dit.has_image_pos_emb:
clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context}
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.dit.require_vae_embedding:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"y": y}
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
"""
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
"""
def __init__(self):
super().__init__(
input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)
z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents[:, :, 0: 1] = z
return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z}
class WanVideoUnit_FunControl(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
onload_model_names=("vae")
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
@@ -547,7 +687,7 @@ class WanVideoUnit_FunReference(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("reference_image", "height", "width", "reference_image"),
onload_model_names=("vae")
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
@@ -832,6 +972,7 @@ def model_fn_wan_video(
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
control_camera_latents_input = None,
fuse_vae_embedding_in_latents: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -865,9 +1006,20 @@ def model_fn_wan_video(
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
# Timestep
if dit.seperated_timestep and fuse_vae_embedding_in_latents:
timestep = torch.concat([
torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
]).flatten()
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
else:
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
# Motion Controller
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
@@ -878,15 +1030,16 @@ def model_fn_wan_video(
x = torch.concat([x] * context.shape[0], dim=0)
if timestep.shape[0] != context.shape[0]:
timestep = torch.concat([timestep] * context.shape[0], dim=0)
if dit.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
# Image Embedding
if y is not None and dit.require_vae_embedding:
x = torch.cat([x, y], dim=1)
if clip_feature is not None and dit.require_clip_embedding:
clip_embdding = dit.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
# Add camera control
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
# Reference image
if reference_latents is not None:

View File

@@ -434,6 +434,8 @@ def wan_parser():
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
return parser

View File

@@ -67,6 +67,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
## Model Inference

View File

@@ -67,6 +67,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
## 模型推理

View File

@@ -0,0 +1,32 @@
import torch
from PIL import Image
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
)
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480))
video = pipe(
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
input_image=input_image,
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,24 @@
import torch
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
# Text-to-video
video = pipe(
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,43 @@
import torch
from PIL import Image
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
# Text-to-video
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=False,
height=704, width=1248,
num_frames=121,
)
save_video(video, "video1.mp4", fps=15, quality=5)
# Image-to-video
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
)
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704))
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=False,
height=704, width=1248,
input_image=input_image,
num_frames=121,
)
save_video(video, "video2.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,35 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload \
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -0,0 +1,31 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -0,0 +1,14 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-TI2V-5B_full" \
--trainable_models "dit" \
--extra_inputs "input_image"

View File

@@ -0,0 +1,37 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -0,0 +1,36 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -0,0 +1,16 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-TI2V-5B_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image"

View File

@@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule):
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
max_timestep_boundary=1.0,
min_timestep_boundary=0.0,
):
super().__init__()
# Load models
@@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule):
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.max_timestep_boundary = max_timestep_boundary
self.min_timestep_boundary = min_timestep_boundary
def forward_preprocess(self, data):
@@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule):
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"cfg_merge": False,
"vace_scale": 1,
"max_timestep_boundary": self.max_timestep_boundary,
"min_timestep_boundary": self.min_timestep_boundary,
}
# Extra inputs
@@ -106,6 +112,8 @@ if __name__ == "__main__":
lora_rank=args.lora_rank,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
max_timestep_boundary=args.max_timestep_boundary,
min_timestep_boundary=args.min_timestep_boundary,
)
model_logger = ModelLogger(
args.output_path,

View File

@@ -0,0 +1,33 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_high_noise_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_low_noise_full/epoch-1.safetensors")
pipe.dit2.load_state_dict(state_dict)
pipe.enable_vram_management()
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
num_frames=49,
seed=1, tiled=False,
)
save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,28 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_high_noise_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_low_noise_full/epoch-1.safetensors")
pipe.dit2.load_state_dict(state_dict)
pipe.enable_vram_management()
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=1, tiled=True
)
save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,30 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
],
)
state_dict = load_state_dict("models/train/Wan2.2-TI2V-5B_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
pipe.enable_vram_management()
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
num_frames=49,
seed=1, tiled=False,
)
save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,31 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.load_lora(pipe.dit, "models/train/Wan2.2-I2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-I2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
num_frames=49,
seed=1, tiled=False,
)
save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,28 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.load_lora(pipe.dit, "models/train/Wan2.2-T2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-T2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_frames=49,
seed=1, tiled=True
)
save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,29 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
],
)
pipe.load_lora(pipe.dit, "models/train/Wan2.2-TI2V-5B_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
num_frames=49,
seed=1, tiled=False,
)
save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)