From 55951590f5d13ef7d1d5e399e80d983018be2108 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Sun, 20 Jul 2025 18:13:50 +0800 Subject: [PATCH 01/12] support wan2.2 5B T2V --- diffsynth/configs/model_config.py | 4 +- diffsynth/models/wan_video_dit.py | 14 + diffsynth/models/wan_video_vae.py | 576 +++++++++++++++++- diffsynth/pipelines/wan_video_new.py | 3 +- .../model_inference/Wan2.2-TI2V-5B.py | 34 ++ 5 files changed, 628 insertions(+), 3 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 7a0b72b..0903f79 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -58,7 +58,7 @@ from ..models.stepvideo_dit import StepVideoModel from ..models.wan_video_dit import WanModel from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_image_encoder import WanImageEncoder -from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38 from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_vace import VaceWanModel @@ -140,6 +140,7 @@ model_loader_configs = [ (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"), (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), + (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), @@ -147,6 +148,7 @@ model_loader_configs = [ (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), + (None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"), (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"), (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 50c06bf..2daf1b4 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -659,6 +659,20 @@ class WanModelStateDictConverter: "add_control_adapter": True, "in_dim_control_adapter": 24, } + elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316": + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 3072, + "ffn_dim": 14336, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 48, + "num_heads": 24, + "num_layers": 30, + "eps": 1e-6, + } else: config = {} return state_dict, config diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 137fd28..d737e2f 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -195,6 +195,75 @@ class Resample(nn.Module): nn.init.zeros_(conv.bias.data) + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size) + return x + + +class Resample38(Resample): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super(Resample, self).__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + else: + self.resample = nn.Identity() + class ResidualBlock(nn.Module): def __init__(self, in_dim, out_dim, dropout=0.0): @@ -273,6 +342,178 @@ class AttentionBlock(nn.Module): return x + identity +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False + ): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample38(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False + ): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample38(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + class Encoder3d(nn.Module): def __init__(self, @@ -376,6 +617,122 @@ class Encoder3d(nn.Module): return x +class Encoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] if i < len(temperal_downsample) else False + ) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + ) + ) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + class Decoder3d(nn.Module): def __init__(self, @@ -481,10 +838,112 @@ class Decoder3d(nn.Module): return x + +class Decoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False + upsamples.append( + Up_ResidualBlock(in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1)) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1)) + + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + def count_conv3d(model): count = 0 for m in model.modules(): - if check_is_instance(m, CausalConv3d): + if isinstance(m, CausalConv3d): count += 1 return count @@ -798,3 +1257,118 @@ class WanVideoVAEStateDictConverter: for name in state_dict: state_dict_['model.' + name] = state_dict[name] return state_dict_ + + +class VideoVAE38_(VideoVAE_): + + def __init__(self, + dim=160, + z_dim=48, + dec_dim=256, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super(VideoVAE_, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + + def encode(self, x, scale): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + +class WanVideoVAE38(WanVideoVAE): + + def __init__(self, z_dim=48, dim=160): + super(WanVideoVAE, self).__init__() + + mean = [ + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667 + ] + std = [ + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) + self.upsampling_factor = 16 diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 9f52ddc..91a6f7b 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -679,7 +679,8 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit): length = (num_frames - 1) // 4 + 1 if vace_reference_image is not None: length += 1 - noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device) + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) if vace_reference_image is not None: noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) return {"noise": noise} diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..93ac975 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import snapshot_download +from diffsynth.models.utils import load_state_dict, hash_state_dict_keys +from modelscope import dataset_snapshot_download + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="model_shards/model-*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=704, width=1280, +) +save_video(video, "video1.mp4", fps=15, quality=5) From f1f00c425521209eff96f43b092563be908a172a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 21 Jul 2025 14:47:58 +0800 Subject: [PATCH 02/12] support wan2.2 5B I2V --- diffsynth/models/wan_video_dit.py | 20 +++++- diffsynth/pipelines/wan_video_new.py | 66 +++++++++++++++++-- .../model_inference/Wan2.2-TI2V-5B.py | 31 ++++++--- 3 files changed, 99 insertions(+), 18 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 2daf1b4..b7df3ff 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -212,9 +212,16 @@ class DiTBlock(nn.Module): self.gate = GateModule() def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 # msa: multi-head self-attention mlp: multi-layer perceptron shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) input_x = modulate(self.norm1(x), shift_msa, scale_msa) x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) x = x + self.cross_attn(self.norm3(x), context) @@ -253,8 +260,12 @@ class Head(nn.Module): self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) def forward(self, x, t_mod): - shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + scale) + shift)) + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) return x @@ -276,12 +287,14 @@ class WanModel(torch.nn.Module): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, + is_5b: bool = False, ): super().__init__() self.dim = dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size + self.is_5b = is_5b self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -672,6 +685,7 @@ class WanModelStateDictConverter: "num_heads": 24, "num_layers": 30, "eps": 1e-6, + "is_5b": True, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 91a6f7b..1136403 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -237,6 +237,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), + WanVideoUnit_ImageEmbedder5B(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -736,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None: + if input_image is None or pipe.dit.is_5b: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -764,7 +765,55 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"clip_feature": clip_context, "y": y} - + +class WanVideoUnit_ImageEmbedder5B(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae") + ) + + def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.is_5b: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) + z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + + _, mask2 = self.masks_like([noise.squeeze(0)], zero=True) + latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0) + latents = latents.unsqueeze(0) + + seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2) + if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: + import math + seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size + + return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len} + + @staticmethod + def masks_like(tensor, zero=False, generator=None, p=0.2): + assert isinstance(tensor, list) + out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] + out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] + + if zero: + if generator is not None: + for u, v in zip(out1, out2): + random_num = torch.rand(1, generator=generator, device=generator.device).item() + if random_num < p: + u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp() + v[:, 0] = torch.zeros_like(v[:, 0]) + else: + u[:, 0] = u[:, 0] + v[:, 0] = v[:, 0] + else: + for u, v in zip(out1, out2): + u[:, 0] = torch.zeros_like(u[:, 0]) + v[:, 0] = torch.zeros_like(v[:, 0]) + + return out1, out2 + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): @@ -1112,9 +1161,16 @@ def model_fn_wan_video( from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) - - t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) - t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + if dit.is_5b and "mask_5b" in kwargs: + temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten() + temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) + timestep = temp_ts.unsqueeze(0).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"]))) + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) if motion_bucket_id is not None and motion_controller is not None: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index 93ac975..b737b47 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -6,13 +6,6 @@ from modelscope import snapshot_download from diffsynth.models.utils import load_state_dict, hash_state_dict_keys from modelscope import dataset_snapshot_download -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] -) - - pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -26,9 +19,27 @@ pipe.enable_vram_management() # Text-to-video video = pipe( - prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=True, - height=704, width=1280, + seed=0, tiled=False, + height=704, width=1248, + num_frames=121, ) save_video(video, "video1.mp4", fps=15, quality=5) + +# Image-to-video +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704)) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=False, + height=704, width=1248, + input_image=input_image, + num_frames=121, +) +save_video(video, "video2.mp4", fps=15, quality=5) From 3aed244c6f119c997679e305c8f02c606ae9460b Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 23 Jul 2025 11:20:06 +0800 Subject: [PATCH 03/12] update variable --- diffsynth/models/wan_video_dit.py | 6 +++--- diffsynth/pipelines/wan_video_new.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index b7df3ff..1106669 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -287,14 +287,14 @@ class WanModel(torch.nn.Module): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, - is_5b: bool = False, + seperated_timestep: bool = False, ): super().__init__() self.dim = dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size - self.is_5b = is_5b + self.seperated_timestep = seperated_timestep self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -685,7 +685,7 @@ class WanModelStateDictConverter: "num_heads": 24, "num_layers": 30, "eps": 1e-6, - "is_5b": True, + "seperated_timestep": True, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 1136403..7ea0e3c 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -237,7 +237,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), - WanVideoUnit_ImageEmbedder5B(), + WanVideoUnit_ImageVaeEmbedder(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -737,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.dit.is_5b: + if input_image is None or pipe.dit.seperated_timestep: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -766,7 +766,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): return {"clip_feature": clip_context, "y": y} -class WanVideoUnit_ImageEmbedder5B(PipelineUnit): +class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -774,7 +774,7 @@ class WanVideoUnit_ImageEmbedder5B(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or not pipe.dit.is_5b: + if input_image is None or not pipe.dit.seperated_timestep: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) @@ -789,7 +789,7 @@ class WanVideoUnit_ImageEmbedder5B(PipelineUnit): import math seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size - return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len} + return {"latents": latents, "latent_mask_for_timestep": mask2[0].unsqueeze(0), "seq_len": seq_len} @staticmethod def masks_like(tensor, zero=False, generator=None, p=0.2): @@ -1162,8 +1162,8 @@ def model_fn_wan_video( get_sequence_parallel_world_size, get_sp_group) - if dit.is_5b and "mask_5b" in kwargs: - temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten() + if dit.seperated_timestep and "latent_mask_for_timestep" in kwargs: + temp_ts = (kwargs["latent_mask_for_timestep"][0][0][:, ::2, ::2] * timestep).flatten() temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) timestep = temp_ts.unsqueeze(0).flatten() t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"]))) From 9015d08927331a7b5b559ed17412558279690c33 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 25 Jul 2025 17:09:53 +0800 Subject: [PATCH 04/12] support wan2.2 A14B I2V&T2V --- diffsynth/configs/model_config.py | 2 + diffsynth/models/wan_video_dit.py | 19 ++++ diffsynth/pipelines/wan_video_new.py | 96 ++++++++++++++++++- .../model_inference/Wan2.2-I2V-A14B.py | 32 +++++++ .../model_inference/Wan2.2-T2V-A14B.py | 27 ++++++ .../model_inference/Wan2.2-TI2V-5B.py | 8 +- 6 files changed, 175 insertions(+), 9 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 0903f79..6448fbf 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -141,6 +141,8 @@ model_loader_configs = [ (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), + (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), + (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 1106669..3262057 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -352,6 +352,7 @@ class WanModel(torch.nn.Module): context: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, + fused_y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, **kwargs, @@ -365,6 +366,8 @@ class WanModel(torch.nn.Module): x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + if fused_y is not None: + x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) x, (f, h, w) = self.patchify(x) @@ -673,6 +676,7 @@ class WanModelStateDictConverter: "in_dim_control_adapter": 24, } elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316": + # Wan-AI/Wan2.2-TI2V-5B config = { "has_image_input": False, "patch_size": [1, 2, 2], @@ -687,6 +691,21 @@ class WanModelStateDictConverter: "eps": 1e-6, "seperated_timestep": True, } + elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": + # Wan-AI/Wan2.2-I2V-A14B + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + } else: config = {} return state_dict, config diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 7ea0e3c..d108ad4 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -226,10 +226,11 @@ class WanVideoPipeline(BasePipeline): self.text_encoder: WanTextEncoder = None self.image_encoder: WanImageEncoder = None self.dit: WanModel = None + self.dit2: WanModel = None self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None - self.in_iteration_models = ("dit", "motion_controller", "vace") + self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), @@ -238,6 +239,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), WanVideoUnit_ImageVaeEmbedder(), + WanVideoUnit_ImageEmbedderNoClip(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -329,6 +331,37 @@ class WanVideoPipeline(BasePipeline): ), vram_limit=vram_limit, ) + if self.dit2 is not None: + dtype = next(iter(self.dit2.parameters())).dtype + device = "cpu" if vram_limit is not None else self.device + enable_vram_management( + self.dit2, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: WanAutoCastLayerNorm, + RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) if self.vae is not None: dtype = next(iter(self.vae.parameters())).dtype enable_vram_management( @@ -427,6 +460,10 @@ class WanVideoPipeline(BasePipeline): for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) self.sp_size = get_sequence_parallel_world_size() self.use_unified_sequence_parallel = True @@ -473,6 +510,9 @@ class WanVideoPipeline(BasePipeline): # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") pipe.dit = model_manager.fetch_model("wan_video_dit") + num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"]) + if num_dits == 2: + pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1] pipe.vae = model_manager.fetch_model("wan_video_vae") pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") @@ -523,6 +563,8 @@ class WanVideoPipeline(BasePipeline): # Classifier-free guidance cfg_scale: Optional[float] = 5.0, cfg_merge: Optional[bool] = False, + # Boundary + boundary: Optional[float] = 0.875, # Scheduler num_inference_steps: Optional[int] = 50, sigma_shift: Optional[float] = 5.0, @@ -575,8 +617,12 @@ class WanVideoPipeline(BasePipeline): self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # switch high_noise DiT to low_noise DiT + if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps: + print("switching to low noise DiT") + self.load_models_to_device(["dit2", "motion_controller", "vace"]) + models["dit"] = models.pop("dit2") timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - # Inference noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) if cfg_scale != 1.0: @@ -737,7 +783,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.dit.seperated_timestep: + if input_image is None or pipe.image_encoder is None: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -767,6 +813,9 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ def __init__(self): super().__init__( input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -815,6 +864,42 @@ class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): return out1, out2 +class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): + """ + Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae") + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"fused_y": y} + + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( @@ -1116,6 +1201,7 @@ def model_fn_wan_video( context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, + fused_y: Optional[torch.Tensor] = None, reference_latents = None, vace_context = None, vace_scale = 1.0, @@ -1181,11 +1267,13 @@ def model_fn_wan_video( x = torch.concat([x] * context.shape[0], dim=0) if timestep.shape[0] != context.shape[0]: timestep = torch.concat([timestep] * context.shape[0], dim=0) - + if dit.has_image_input: x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + if fused_y is not None: + x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) # Add camera control x, (f, h, w) = dit.patchify(x, control_camera_latents_input) diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..9782a2c --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..de9ae5f --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py @@ -0,0 +1,27 @@ +import torch +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import snapshot_download + +snapshot_download("Wan-AI/Wan2.2-T2V-A14B", local_dir="models/Wan-AI/Wan2.2-T2V-A14B") + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index b737b47..f41a941 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -1,17 +1,15 @@ import torch from PIL import Image -from diffsynth import save_video, VideoData +from diffsynth import save_video from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import snapshot_download -from diffsynth.models.utils import load_state_dict, hash_state_dict_keys from modelscope import dataset_snapshot_download pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="model_shards/model-*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"), ], ) From bba44173d2a838c0c5d7df1a33729aedc1ac5839 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 25 Jul 2025 17:24:42 +0800 Subject: [PATCH 05/12] minor fix --- diffsynth/configs/model_config.py | 1 - diffsynth/pipelines/wan_video_new.py | 1 - 2 files changed, 2 deletions(-) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 6448fbf..6665599 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -141,7 +141,6 @@ model_loader_configs = [ (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), - (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index d108ad4..f27382d 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -619,7 +619,6 @@ class WanVideoPipeline(BasePipeline): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): # switch high_noise DiT to low_noise DiT if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps: - print("switching to low noise DiT") self.load_models_to_device(["dit2", "motion_controller", "vace"]) models["dit"] = models.pop("dit2") timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) From 5f68727ad31c07ea3ce45f6dab2787854a711a19 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 28 Jul 2025 11:00:54 +0800 Subject: [PATCH 06/12] refine code --- diffsynth/models/model_manager.py | 21 ++- diffsynth/models/wan_video_dit.py | 13 +- diffsynth/pipelines/wan_video_new.py | 155 +++++++++--------- diffsynth/trainers/utils.py | 2 + examples/wanvideo/README.md | 3 + examples/wanvideo/README_zh.md | 3 + .../model_inference/Wan2.2-I2V-A14B.py | 2 +- .../model_inference/Wan2.2-T2V-A14B.py | 3 - .../model_inference/Wan2.2-TI2V-5B.py | 2 +- .../model_training/full/Wan2.2-I2V-A14B.sh | 35 ++++ .../model_training/full/Wan2.2-T2V-A14B.sh | 31 ++++ .../model_training/full/Wan2.2-TI2V-5B.sh | 14 ++ .../model_training/lora/Wan2.2-I2V-A14B.sh | 37 +++++ .../model_training/lora/Wan2.2-T2V-A14B.sh | 36 ++++ .../model_training/lora/Wan2.2-TI2V-5B.sh | 16 ++ examples/wanvideo/model_training/train.py | 8 + .../validate_full/Wan2.2-I2V-A14B.py | 33 ++++ .../validate_full/Wan2.2-T2V-A14B.py | 28 ++++ .../validate_full/Wan2.2-TI2V-5B.py | 30 ++++ .../validate_lora/Wan2.2-I2V-A14B.py | 31 ++++ .../validate_lora/Wan2.2-T2V-A14B.py | 28 ++++ .../validate_lora/Wan2.2-TI2V-5B.py | 29 ++++ 22 files changed, 474 insertions(+), 86 deletions(-) create mode 100644 examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 7ae3c50..d46eedf 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -426,7 +426,7 @@ class ModelManager: self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype) - def fetch_model(self, model_name, file_path=None, require_model_path=False): + def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None): fetched_models = [] fetched_model_paths = [] for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): @@ -440,12 +440,25 @@ class ModelManager: return None if len(fetched_models) == 1: print(f"Using {model_name} from {fetched_model_paths[0]}.") + model = fetched_models[0] + path = fetched_model_paths[0] else: - print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.") + if index is None: + model = fetched_models[0] + path = fetched_model_paths[0] + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.") + elif isinstance(index, int): + model = fetched_models[:index] + path = fetched_model_paths[:index] + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.") + else: + model = fetched_models + path = fetched_model_paths + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.") if require_model_path: - return fetched_models[0], fetched_model_paths[0] + return model, path else: - return fetched_models[0] + return model def to(self, device): diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 3262057..ea473b0 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -288,6 +288,9 @@ class WanModel(torch.nn.Module): add_control_adapter: bool = False, in_dim_control_adapter: int = 24, seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, ): super().__init__() self.dim = dim @@ -295,6 +298,9 @@ class WanModel(torch.nn.Module): self.has_image_input = has_image_input self.patch_size = patch_size self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -352,7 +358,6 @@ class WanModel(torch.nn.Module): context: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, - fused_y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, **kwargs, @@ -366,8 +371,6 @@ class WanModel(torch.nn.Module): x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - if fused_y is not None: - x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) x, (f, h, w) = self.patchify(x) @@ -690,6 +693,9 @@ class WanModelStateDictConverter: "num_layers": 30, "eps": 1e-6, "seperated_timestep": True, + "require_clip_embedding": False, + "require_vae_embedding": False, + "fuse_vae_embedding_in_latents": True, } elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": # Wan-AI/Wan2.2-I2V-A14B @@ -705,6 +711,7 @@ class WanModelStateDictConverter: "num_heads": 40, "num_layers": 40, "eps": 1e-6, + "require_clip_embedding": False, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index f27382d..9963aa1 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -230,16 +230,17 @@ class WanVideoPipeline(BasePipeline): self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None - self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace") + self.in_iteration_models = ("dit", "motion_controller", "vace") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), WanVideoUnit_NoiseInitializer(), WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), - WanVideoUnit_ImageEmbedder(), - WanVideoUnit_ImageVaeEmbedder(), - WanVideoUnit_ImageEmbedderNoClip(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -259,7 +260,9 @@ class WanVideoPipeline(BasePipeline): def training_loss(self, **inputs): - timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps) + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) @@ -517,6 +520,11 @@ class WanVideoPipeline(BasePipeline): pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") pipe.vace = model_manager.fetch_model("wan_video_vace") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 # Initialize tokenizer tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download) @@ -564,7 +572,7 @@ class WanVideoPipeline(BasePipeline): cfg_scale: Optional[float] = 5.0, cfg_merge: Optional[bool] = False, # Boundary - boundary: Optional[float] = 0.875, + switch_DiT_boundary: Optional[float] = 0.875, # Scheduler num_inference_steps: Optional[int] = 50, sigma_shift: Optional[float] = 5.0, @@ -617,11 +625,14 @@ class WanVideoPipeline(BasePipeline): self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - # switch high_noise DiT to low_noise DiT - if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps: - self.load_models_to_device(["dit2", "motion_controller", "vace"]) - models["dit"] = models.pop("dit2") + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + + # Timestep timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + # Inference noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) if cfg_scale != 1.0: @@ -775,6 +786,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit): class WanVideoUnit_ImageEmbedder(PipelineUnit): + """ + Deprecated + """ def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -811,70 +825,38 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): return {"clip_feature": clip_context, "y": y} -class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): - """ - Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. - """ + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae") + input_params=("input_image", "end_image", "height", "width"), + onload_model_names=("image_encoder",) ) - def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or not pipe.dit.seperated_timestep: + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: return {} pipe.load_models_to_device(self.onload_model_names) - image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) - z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + - _, mask2 = self.masks_like([noise.squeeze(0)], zero=True) - latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0) - latents = latents.unsqueeze(0) - seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2) - if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: - import math - seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size - - return {"latents": latents, "latent_mask_for_timestep": mask2[0].unsqueeze(0), "seq_len": seq_len} - - @staticmethod - def masks_like(tensor, zero=False, generator=None, p=0.2): - assert isinstance(tensor, list) - out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] - out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] - - if zero: - if generator is not None: - for u, v in zip(out1, out2): - random_num = torch.rand(1, generator=generator, device=generator.device).item() - if random_num < p: - u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp() - v[:, 0] = torch.zeros_like(v[:, 0]) - else: - u[:, 0] = u[:, 0] - v[:, 0] = v[:, 0] - else: - for u, v in zip(out1, out2): - u[:, 0] = torch.zeros_like(u[:, 0]) - v[:, 0] = torch.zeros_like(v[:, 0]) - - return out1, out2 - - -class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): - """ - Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B. - """ +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep: + if input_image is None or not pipe.dit.require_vae_embedding: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -896,14 +878,36 @@ class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): y = torch.concat([msk, y]) y = y.unsqueeze(0) y = y.to(dtype=pipe.torch_dtype, device=pipe.device) - return {"fused_y": y} + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True} + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): @@ -927,7 +931,7 @@ class WanVideoUnit_FunReference(PipelineUnit): def __init__(self): super().__init__( input_params=("reference_image", "height", "width", "reference_image"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, reference_image, height, width): @@ -1200,7 +1204,6 @@ def model_fn_wan_video( context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, - fused_y: Optional[torch.Tensor] = None, reference_latents = None, vace_context = None, vace_scale = 1.0, @@ -1213,6 +1216,7 @@ def model_fn_wan_video( use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is not None: @@ -1247,15 +1251,19 @@ def model_fn_wan_video( get_sequence_parallel_world_size, get_sp_group) - if dit.seperated_timestep and "latent_mask_for_timestep" in kwargs: - temp_ts = (kwargs["latent_mask_for_timestep"][0][0][:, ::2, ::2] * timestep).flatten() - temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) - timestep = temp_ts.unsqueeze(0).flatten() - t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"]))) + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) else: t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller if motion_bucket_id is not None and motion_controller is not None: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) @@ -1267,16 +1275,15 @@ def model_fn_wan_video( if timestep.shape[0] != context.shape[0]: timestep = torch.concat([timestep] * context.shape[0], dim=0) - if dit.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - if fused_y is not None: - x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) # Add camera control x, (f, h, w) = dit.patchify(x, control_camera_latents_input) - # Reference image if reference_latents is not None: diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index b171857..b27fc34 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -434,6 +434,8 @@ def wan_parser(): parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") return parser diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 7b2ebbd..c75e518 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -65,6 +65,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| ## Model Inference diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 860ff83..822c8e6 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -65,6 +65,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| ## 模型推理 diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py index 9782a2c..0c1be54 100644 --- a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py @@ -22,7 +22,7 @@ dataset_snapshot_download( allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] ) input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) -# Text-to-video + video = pipe( prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", diff --git a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py index de9ae5f..27b10d0 100644 --- a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py @@ -1,9 +1,6 @@ import torch from diffsynth import save_video from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import snapshot_download - -snapshot_download("Wan-AI/Wan2.2-T2V-A14B", local_dir="models/Wan-AI/Wan2.2-T2V-A14B") pipe = WanVideoPipeline.from_pretrained( diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index f41a941..50d81c2 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -10,7 +10,7 @@ pipe = WanVideoPipeline.from_pretrained( model_configs=[ ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), ], ) pipe.enable_vram_management() diff --git a/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000..2f531e7 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh @@ -0,0 +1,35 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000..f634117 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh @@ -0,0 +1,31 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000..def9f89 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-TI2V-5B_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000..4201b47 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh @@ -0,0 +1,37 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000..737896c --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh @@ -0,0 +1,36 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000..6a33b57 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-TI2V-5B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 877c5b8..93ec0bd 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule): use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, + max_timestep_boundary=1.0, + min_timestep_boundary=0.0, ): super().__init__() # Load models @@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule): self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.max_timestep_boundary = max_timestep_boundary + self.min_timestep_boundary = min_timestep_boundary def forward_preprocess(self, data): @@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule): "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, "cfg_merge": False, "vace_scale": 1, + "max_timestep_boundary": self.max_timestep_boundary, + "min_timestep_boundary": self.min_timestep_boundary, } # Extra inputs @@ -106,6 +112,8 @@ if __name__ == "__main__": lora_rank=args.lora_rank, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, + max_timestep_boundary=args.max_timestep_boundary, + min_timestep_boundary=args.min_timestep_boundary, ) model_logger = ModelLogger( args.output_path, diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..ddcdf5c --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_high_noise_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_low_noise_full/epoch-1.safetensors") +pipe.dit2.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=False, +) +save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..be0e000 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_high_noise_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_low_noise_full/epoch-1.safetensors") +pipe.dit2.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..0f0ea5d --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-TI2V-5B_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=False, +) +save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..4a6bd9c --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-I2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1) +pipe.load_lora(pipe.dit2, "models/train/Wan2.2-I2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=False, +) +save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..ab43927 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-T2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1) +pipe.load_lora(pipe.dit2, "models/train/Wan2.2-T2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..d5a9229 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-TI2V-5B_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=False, +) +save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) From 05dba91f79f4b4195100b2a61e407c263ae21265 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 28 Jul 2025 13:38:01 +0800 Subject: [PATCH 07/12] fix wan2.2 5B --- diffsynth/pipelines/wan_video_new.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 9963aa1..1fa930d 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -646,6 +646,8 @@ class WanVideoPipeline(BasePipeline): # Scheduler inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] # VACE (TODO: remove it) if vace_reference_image is not None: @@ -899,7 +901,7 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit): image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) latents[:, :, 0: 1] = z - return {"latents": latents, "fuse_vae_embedding_in_latents": True} + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} From 729c512c6656ed608de85fc1b671d09855c21590 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 28 Jul 2025 15:18:47 +0800 Subject: [PATCH 08/12] bugfix --- diffsynth/pipelines/wan_video_new.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 8b47508..4f62a74 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -706,7 +706,8 @@ class WanVideoUnit_FunReference(PipelineUnit): class WanVideoUnit_FunCameraControl(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image") + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"), + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image): @@ -729,6 +730,7 @@ class WanVideoUnit_FunCameraControl(PipelineUnit): input_image = input_image.resize((width, height)) input_latents = pipe.preprocess_video([input_image]) + pipe.load_models_to_device(self.onload_model_names) input_latents = pipe.vae.encode(input_latents, device=pipe.device) y = torch.zeros_like(latents).to(pipe.device) y[:, :, :1] = input_latents From 29663b25a65d35f129a88dadbe28e74e895c0a49 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 28 Jul 2025 16:49:28 +0800 Subject: [PATCH 09/12] fix wan2.2 vae --- diffsynth/models/wan_video_vae.py | 8 +++++--- examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index d737e2f..397a2e7 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -1075,6 +1075,7 @@ class WanVideoVAE(nn.Module): # init model self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False) self.upsampling_factor = 8 + self.z_dim = z_dim def build_1d_mask(self, length, left_bound, right_bound, border_width): @@ -1170,7 +1171,7 @@ class WanVideoVAE(nn.Module): out_T = (T + 3) // 4 weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) - values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) @@ -1221,8 +1222,8 @@ class WanVideoVAE(nn.Module): for video in videos: video = video.unsqueeze(0) if tiled: - tile_size = (tile_size[0] * 8, tile_size[1] * 8) - tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8) + tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor) + tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor) hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) else: hidden_state = self.single_encode(video, device) @@ -1372,3 +1373,4 @@ class WanVideoVAE38(WanVideoVAE): # init model self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) self.upsampling_factor = 16 + self.z_dim = z_dim diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index 50d81c2..fa91965 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -19,7 +19,7 @@ pipe.enable_vram_management() video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=False, + seed=0, tiled=True, height=704, width=1248, num_frames=121, ) @@ -35,7 +35,7 @@ input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 70 video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=False, + seed=0, tiled=True, height=704, width=1248, input_image=input_image, num_frames=121, From 68aafab09e4d8a9c7b215c6f35e7a38b2dfe3108 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 28 Jul 2025 17:02:30 +0800 Subject: [PATCH 10/12] update readme --- README.md | 7 ++++++- README_zh.md | 7 ++++++- diffsynth/pipelines/wan_video_new.py | 9 +++++---- examples/wanvideo/README.md | 7 +++++-- examples/wanvideo/README_zh.md | 7 +++++-- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 11403a5..8c01734 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,9 @@ image.save("image.jpg") |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| @@ -317,7 +320,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## Update History -- **July 11, 2025** 🔥🔥🔥 We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks. +- **July 28, 2025** 🔥🔥🔥 With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/). + +- **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks. - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356) - Github Repo: https://github.com/modelscope/Nexus-Gen - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2) diff --git a/README_zh.md b/README_zh.md index 650d2ec..bfcaa46 100644 --- a/README_zh.md +++ b/README_zh.md @@ -167,6 +167,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| @@ -333,7 +336,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## 更新历史 -- **2025年7月11日** 🔥🔥🔥 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 +- **2025年7月28日** 🔥🔥🔥 Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。 + +- **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 - 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356) - Github 仓库: https://github.com/modelscope/Nexus-Gen - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 4f62a74..a425da2 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -323,10 +323,11 @@ class WanVideoPipeline(BasePipeline): # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") - pipe.dit = model_manager.fetch_model("wan_video_dit") - num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"]) - if num_dits == 2: - pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1] + dit = model_manager.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.dit, pipe.dit2 = dit + else: + pipe.dit = dit pipe.vae = model_manager.fetch_model("wan_video_vae") pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 6d967c5..8231e29 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -1,8 +1,8 @@ -# Wan 2.1 +# Wan [切换到中文](./README_zh.md) -Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba. +Wan is a collection of video synthesis models open-sourced by Alibaba. **DiffSynth-Studio has adopted a new inference and training framework. To use the previous version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).** @@ -247,6 +247,7 @@ The pipeline accepts the following input parameters during inference: * `num_frames`: Number of frames, default is 81. Must be a multiple of 4 plus 1; if not, it will be rounded up, minimum is 1. * `cfg_scale`: Classifier-free guidance scale, default is 5. Higher values increase adherence to the prompt but may cause visual artifacts. * `cfg_merge`: Whether to merge both sides of classifier-free guidance for unified inference. Default is `False`. This parameter currently only works for basic text-to-video and image-to-video models. +* `switch_DiT_boundary`: The time point for switching between DiT models. Default value is 0.875. This parameter only takes effect for mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B). * `num_inference_steps`: Number of inference steps, default is 50. * `sigma_shift`: Parameter from Rectified Flow theory, default is 5. Higher values make the model stay longer at the initial denoising stage. Increasing this may improve video quality but may also cause inconsistency between generated videos and training data due to deviation from training behavior. * `motion_bucket_id`: Motion intensity, range [0, 100], applicable to motion control modules such as [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1). Larger values indicate more intense motion. @@ -282,6 +283,8 @@ The script includes the following parameters: * Models * `--model_paths`: Paths to load models. In JSON format. * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated. + * `--max_timestep_boundary`: Maximum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B). + * `--min_timestep_boundary`: Minimum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B). * Training * `--learning_rate`: Learning rate. * `--num_epochs`: Number of epochs. diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 54b07da..9f26858 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -1,8 +1,8 @@ -# 通义万相 2.1(Wan 2.1) +# 通义万相(Wan) [Switch to English](./README.md) -Wan 2.1 是由阿里巴巴通义实验室开源的一系列视频生成模型。 +Wan 是由阿里巴巴通义实验室开源的一系列视频生成模型。 **DiffSynth-Studio 启用了新的推理和训练框架,如需使用旧版本,请点击[这里](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c)。** @@ -248,6 +248,7 @@ Pipeline 在推理阶段能够接收以下输入参数: * `num_frames`: 帧数,默认为 81。需设置为 4 的倍数 + 1,不满足时向上取整,最小值为 1。 * `cfg_scale`: Classifier-free guidance 机制的数值,默认为 5。数值越大,提示词的控制效果越强,但画面崩坏的概率越大。 * `cfg_merge`: 是否合并 Classifier-free guidance 的两侧进行统一推理,默认为 `False`。该参数目前仅在基础的文生视频和图生视频模型上生效。 +* `switch_DiT_boundary`: 切换 DiT 模型的时间点,默认值为 0.875,仅对多 DiT 的混合模型生效,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。 * `num_inference_steps`: 推理次数,默认值为 50。 * `sigma_shift`: Rectified Flow 理论中的参数,默认为 5。数值越大,模型在去噪的开始阶段停留的步骤数越多,可适当调大这个参数来提高画面质量,但会因生成过程与训练过程不一致导致生成的视频内容与训练数据存在差异。 * `motion_bucket_id`: 运动幅度,范围为 [0, 100]。适用于速度控制模块,例如 [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1),数值越大,运动幅度越大。 @@ -284,6 +285,8 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。 + * `--max_timestep_boundary`: Timestep 区间最大值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。 + * `--min_timestep_boundary`: Timestep 区间最小值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。 * 训练 * `--learning_rate`: 学习率。 * `--num_epochs`: 轮数(Epoch)。 From 283f35447a129b04e44a87a28a201c2d3f61d33a Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 28 Jul 2025 18:25:43 +0800 Subject: [PATCH 11/12] refine readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8c01734..87bca13 100644 --- a/README.md +++ b/README.md @@ -101,9 +101,6 @@ image.save("image.jpg") |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)| -|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| -|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| -|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| @@ -170,6 +167,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| From 9e683bfe25886b5de5eaefd865ceae61796cd14e Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 28 Jul 2025 18:30:04 +0800 Subject: [PATCH 12/12] fix typo --- .../wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py | 2 +- .../wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py | 2 +- .../wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py | 2 +- .../wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py index ddcdf5c..3f6d253 100644 --- a/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py @@ -28,6 +28,6 @@ video = pipe( negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", input_image=input_image, num_frames=49, - seed=1, tiled=False, + seed=1, tiled=True, ) save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py index 0f0ea5d..8a3d36c 100644 --- a/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py @@ -25,6 +25,6 @@ video = pipe( negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", input_image=input_image, num_frames=49, - seed=1, tiled=False, + seed=1, tiled=True, ) save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py index 4a6bd9c..f221ef7 100644 --- a/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py @@ -26,6 +26,6 @@ video = pipe( negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", input_image=input_image, num_frames=49, - seed=1, tiled=False, + seed=1, tiled=True, ) save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py index d5a9229..e5b16c8 100644 --- a/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py @@ -24,6 +24,6 @@ video = pipe( negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", input_image=input_image, num_frames=49, - seed=1, tiled=False, + seed=1, tiled=True, ) save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)