From 55951590f5d13ef7d1d5e399e80d983018be2108 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Sun, 20 Jul 2025 18:13:50 +0800 Subject: [PATCH 1/7] support wan2.2 5B T2V --- diffsynth/configs/model_config.py | 4 +- diffsynth/models/wan_video_dit.py | 14 + diffsynth/models/wan_video_vae.py | 576 +++++++++++++++++- diffsynth/pipelines/wan_video_new.py | 3 +- .../model_inference/Wan2.2-TI2V-5B.py | 34 ++ 5 files changed, 628 insertions(+), 3 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 7a0b72b..0903f79 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -58,7 +58,7 @@ from ..models.stepvideo_dit import StepVideoModel from ..models.wan_video_dit import WanModel from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_image_encoder import WanImageEncoder -from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38 from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_vace import VaceWanModel @@ -140,6 +140,7 @@ model_loader_configs = [ (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"), (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), + (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), @@ -147,6 +148,7 @@ model_loader_configs = [ (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), + (None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"), (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"), (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 50c06bf..2daf1b4 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -659,6 +659,20 @@ class WanModelStateDictConverter: "add_control_adapter": True, "in_dim_control_adapter": 24, } + elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316": + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 3072, + "ffn_dim": 14336, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 48, + "num_heads": 24, + "num_layers": 30, + "eps": 1e-6, + } else: config = {} return state_dict, config diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 137fd28..d737e2f 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -195,6 +195,75 @@ class Resample(nn.Module): nn.init.zeros_(conv.bias.data) + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size) + return x + + +class Resample38(Resample): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super(Resample, self).__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + else: + self.resample = nn.Identity() + class ResidualBlock(nn.Module): def __init__(self, in_dim, out_dim, dropout=0.0): @@ -273,6 +342,178 @@ class AttentionBlock(nn.Module): return x + identity +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False + ): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample38(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False + ): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample38(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + class Encoder3d(nn.Module): def __init__(self, @@ -376,6 +617,122 @@ class Encoder3d(nn.Module): return x +class Encoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] if i < len(temperal_downsample) else False + ) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + ) + ) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + class Decoder3d(nn.Module): def __init__(self, @@ -481,10 +838,112 @@ class Decoder3d(nn.Module): return x + +class Decoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False + upsamples.append( + Up_ResidualBlock(in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1)) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1)) + + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + def count_conv3d(model): count = 0 for m in model.modules(): - if check_is_instance(m, CausalConv3d): + if isinstance(m, CausalConv3d): count += 1 return count @@ -798,3 +1257,118 @@ class WanVideoVAEStateDictConverter: for name in state_dict: state_dict_['model.' + name] = state_dict[name] return state_dict_ + + +class VideoVAE38_(VideoVAE_): + + def __init__(self, + dim=160, + z_dim=48, + dec_dim=256, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super(VideoVAE_, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + + def encode(self, x, scale): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + +class WanVideoVAE38(WanVideoVAE): + + def __init__(self, z_dim=48, dim=160): + super(WanVideoVAE, self).__init__() + + mean = [ + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667 + ] + std = [ + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) + self.upsampling_factor = 16 diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 9f52ddc..91a6f7b 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -679,7 +679,8 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit): length = (num_frames - 1) // 4 + 1 if vace_reference_image is not None: length += 1 - noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device) + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) if vace_reference_image is not None: noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) return {"noise": noise} diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..93ac975 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import snapshot_download +from diffsynth.models.utils import load_state_dict, hash_state_dict_keys +from modelscope import dataset_snapshot_download + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="model_shards/model-*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=704, width=1280, +) +save_video(video, "video1.mp4", fps=15, quality=5) From f1f00c425521209eff96f43b092563be908a172a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 21 Jul 2025 14:47:58 +0800 Subject: [PATCH 2/7] support wan2.2 5B I2V --- diffsynth/models/wan_video_dit.py | 20 +++++- diffsynth/pipelines/wan_video_new.py | 66 +++++++++++++++++-- .../model_inference/Wan2.2-TI2V-5B.py | 31 ++++++--- 3 files changed, 99 insertions(+), 18 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 2daf1b4..b7df3ff 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -212,9 +212,16 @@ class DiTBlock(nn.Module): self.gate = GateModule() def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 # msa: multi-head self-attention mlp: multi-layer perceptron shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) input_x = modulate(self.norm1(x), shift_msa, scale_msa) x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) x = x + self.cross_attn(self.norm3(x), context) @@ -253,8 +260,12 @@ class Head(nn.Module): self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) def forward(self, x, t_mod): - shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + scale) + shift)) + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) return x @@ -276,12 +287,14 @@ class WanModel(torch.nn.Module): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, + is_5b: bool = False, ): super().__init__() self.dim = dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size + self.is_5b = is_5b self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -672,6 +685,7 @@ class WanModelStateDictConverter: "num_heads": 24, "num_layers": 30, "eps": 1e-6, + "is_5b": True, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 91a6f7b..1136403 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -237,6 +237,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), + WanVideoUnit_ImageEmbedder5B(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -736,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None: + if input_image is None or pipe.dit.is_5b: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -764,7 +765,55 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"clip_feature": clip_context, "y": y} - + +class WanVideoUnit_ImageEmbedder5B(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae") + ) + + def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.is_5b: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) + z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + + _, mask2 = self.masks_like([noise.squeeze(0)], zero=True) + latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0) + latents = latents.unsqueeze(0) + + seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2) + if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: + import math + seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size + + return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len} + + @staticmethod + def masks_like(tensor, zero=False, generator=None, p=0.2): + assert isinstance(tensor, list) + out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] + out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] + + if zero: + if generator is not None: + for u, v in zip(out1, out2): + random_num = torch.rand(1, generator=generator, device=generator.device).item() + if random_num < p: + u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp() + v[:, 0] = torch.zeros_like(v[:, 0]) + else: + u[:, 0] = u[:, 0] + v[:, 0] = v[:, 0] + else: + for u, v in zip(out1, out2): + u[:, 0] = torch.zeros_like(u[:, 0]) + v[:, 0] = torch.zeros_like(v[:, 0]) + + return out1, out2 + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): @@ -1112,9 +1161,16 @@ def model_fn_wan_video( from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) - - t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) - t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + if dit.is_5b and "mask_5b" in kwargs: + temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten() + temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) + timestep = temp_ts.unsqueeze(0).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"]))) + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) if motion_bucket_id is not None and motion_controller is not None: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index 93ac975..b737b47 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -6,13 +6,6 @@ from modelscope import snapshot_download from diffsynth.models.utils import load_state_dict, hash_state_dict_keys from modelscope import dataset_snapshot_download -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] -) - - pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -26,9 +19,27 @@ pipe.enable_vram_management() # Text-to-video video = pipe( - prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=True, - height=704, width=1280, + seed=0, tiled=False, + height=704, width=1248, + num_frames=121, ) save_video(video, "video1.mp4", fps=15, quality=5) + +# Image-to-video +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704)) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=False, + height=704, width=1248, + input_image=input_image, + num_frames=121, +) +save_video(video, "video2.mp4", fps=15, quality=5) From 3aed244c6f119c997679e305c8f02c606ae9460b Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 23 Jul 2025 11:20:06 +0800 Subject: [PATCH 3/7] update variable --- diffsynth/models/wan_video_dit.py | 6 +++--- diffsynth/pipelines/wan_video_new.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index b7df3ff..1106669 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -287,14 +287,14 @@ class WanModel(torch.nn.Module): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, - is_5b: bool = False, + seperated_timestep: bool = False, ): super().__init__() self.dim = dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size - self.is_5b = is_5b + self.seperated_timestep = seperated_timestep self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -685,7 +685,7 @@ class WanModelStateDictConverter: "num_heads": 24, "num_layers": 30, "eps": 1e-6, - "is_5b": True, + "seperated_timestep": True, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 1136403..7ea0e3c 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -237,7 +237,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), - WanVideoUnit_ImageEmbedder5B(), + WanVideoUnit_ImageVaeEmbedder(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -737,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.dit.is_5b: + if input_image is None or pipe.dit.seperated_timestep: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -766,7 +766,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): return {"clip_feature": clip_context, "y": y} -class WanVideoUnit_ImageEmbedder5B(PipelineUnit): +class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -774,7 +774,7 @@ class WanVideoUnit_ImageEmbedder5B(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or not pipe.dit.is_5b: + if input_image is None or not pipe.dit.seperated_timestep: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) @@ -789,7 +789,7 @@ class WanVideoUnit_ImageEmbedder5B(PipelineUnit): import math seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size - return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len} + return {"latents": latents, "latent_mask_for_timestep": mask2[0].unsqueeze(0), "seq_len": seq_len} @staticmethod def masks_like(tensor, zero=False, generator=None, p=0.2): @@ -1162,8 +1162,8 @@ def model_fn_wan_video( get_sequence_parallel_world_size, get_sp_group) - if dit.is_5b and "mask_5b" in kwargs: - temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten() + if dit.seperated_timestep and "latent_mask_for_timestep" in kwargs: + temp_ts = (kwargs["latent_mask_for_timestep"][0][0][:, ::2, ::2] * timestep).flatten() temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) timestep = temp_ts.unsqueeze(0).flatten() t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"]))) From 9015d08927331a7b5b559ed17412558279690c33 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 25 Jul 2025 17:09:53 +0800 Subject: [PATCH 4/7] support wan2.2 A14B I2V&T2V --- diffsynth/configs/model_config.py | 2 + diffsynth/models/wan_video_dit.py | 19 ++++ diffsynth/pipelines/wan_video_new.py | 96 ++++++++++++++++++- .../model_inference/Wan2.2-I2V-A14B.py | 32 +++++++ .../model_inference/Wan2.2-T2V-A14B.py | 27 ++++++ .../model_inference/Wan2.2-TI2V-5B.py | 8 +- 6 files changed, 175 insertions(+), 9 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 0903f79..6448fbf 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -141,6 +141,8 @@ model_loader_configs = [ (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), + (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), + (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 1106669..3262057 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -352,6 +352,7 @@ class WanModel(torch.nn.Module): context: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, + fused_y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, **kwargs, @@ -365,6 +366,8 @@ class WanModel(torch.nn.Module): x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + if fused_y is not None: + x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) x, (f, h, w) = self.patchify(x) @@ -673,6 +676,7 @@ class WanModelStateDictConverter: "in_dim_control_adapter": 24, } elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316": + # Wan-AI/Wan2.2-TI2V-5B config = { "has_image_input": False, "patch_size": [1, 2, 2], @@ -687,6 +691,21 @@ class WanModelStateDictConverter: "eps": 1e-6, "seperated_timestep": True, } + elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": + # Wan-AI/Wan2.2-I2V-A14B + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + } else: config = {} return state_dict, config diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 7ea0e3c..d108ad4 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -226,10 +226,11 @@ class WanVideoPipeline(BasePipeline): self.text_encoder: WanTextEncoder = None self.image_encoder: WanImageEncoder = None self.dit: WanModel = None + self.dit2: WanModel = None self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None - self.in_iteration_models = ("dit", "motion_controller", "vace") + self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), @@ -238,6 +239,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), WanVideoUnit_ImageVaeEmbedder(), + WanVideoUnit_ImageEmbedderNoClip(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -329,6 +331,37 @@ class WanVideoPipeline(BasePipeline): ), vram_limit=vram_limit, ) + if self.dit2 is not None: + dtype = next(iter(self.dit2.parameters())).dtype + device = "cpu" if vram_limit is not None else self.device + enable_vram_management( + self.dit2, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: WanAutoCastLayerNorm, + RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) if self.vae is not None: dtype = next(iter(self.vae.parameters())).dtype enable_vram_management( @@ -427,6 +460,10 @@ class WanVideoPipeline(BasePipeline): for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) self.sp_size = get_sequence_parallel_world_size() self.use_unified_sequence_parallel = True @@ -473,6 +510,9 @@ class WanVideoPipeline(BasePipeline): # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") pipe.dit = model_manager.fetch_model("wan_video_dit") + num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"]) + if num_dits == 2: + pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1] pipe.vae = model_manager.fetch_model("wan_video_vae") pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") @@ -523,6 +563,8 @@ class WanVideoPipeline(BasePipeline): # Classifier-free guidance cfg_scale: Optional[float] = 5.0, cfg_merge: Optional[bool] = False, + # Boundary + boundary: Optional[float] = 0.875, # Scheduler num_inference_steps: Optional[int] = 50, sigma_shift: Optional[float] = 5.0, @@ -575,8 +617,12 @@ class WanVideoPipeline(BasePipeline): self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # switch high_noise DiT to low_noise DiT + if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps: + print("switching to low noise DiT") + self.load_models_to_device(["dit2", "motion_controller", "vace"]) + models["dit"] = models.pop("dit2") timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - # Inference noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) if cfg_scale != 1.0: @@ -737,7 +783,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.dit.seperated_timestep: + if input_image is None or pipe.image_encoder is None: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -767,6 +813,9 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ def __init__(self): super().__init__( input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -815,6 +864,42 @@ class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): return out1, out2 +class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): + """ + Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae") + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"fused_y": y} + + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( @@ -1116,6 +1201,7 @@ def model_fn_wan_video( context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, + fused_y: Optional[torch.Tensor] = None, reference_latents = None, vace_context = None, vace_scale = 1.0, @@ -1181,11 +1267,13 @@ def model_fn_wan_video( x = torch.concat([x] * context.shape[0], dim=0) if timestep.shape[0] != context.shape[0]: timestep = torch.concat([timestep] * context.shape[0], dim=0) - + if dit.has_image_input: x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + if fused_y is not None: + x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) # Add camera control x, (f, h, w) = dit.patchify(x, control_camera_latents_input) diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..9782a2c --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..de9ae5f --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py @@ -0,0 +1,27 @@ +import torch +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import snapshot_download + +snapshot_download("Wan-AI/Wan2.2-T2V-A14B", local_dir="models/Wan-AI/Wan2.2-T2V-A14B") + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index b737b47..f41a941 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -1,17 +1,15 @@ import torch from PIL import Image -from diffsynth import save_video, VideoData +from diffsynth import save_video from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import snapshot_download -from diffsynth.models.utils import load_state_dict, hash_state_dict_keys from modelscope import dataset_snapshot_download pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="model_shards/model-*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"), ], ) From bba44173d2a838c0c5d7df1a33729aedc1ac5839 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 25 Jul 2025 17:24:42 +0800 Subject: [PATCH 5/7] minor fix --- diffsynth/configs/model_config.py | 1 - diffsynth/pipelines/wan_video_new.py | 1 - 2 files changed, 2 deletions(-) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 6448fbf..6665599 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -141,7 +141,6 @@ model_loader_configs = [ (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), - (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index d108ad4..f27382d 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -619,7 +619,6 @@ class WanVideoPipeline(BasePipeline): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): # switch high_noise DiT to low_noise DiT if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps: - print("switching to low noise DiT") self.load_models_to_device(["dit2", "motion_controller", "vace"]) models["dit"] = models.pop("dit2") timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) From 5f68727ad31c07ea3ce45f6dab2787854a711a19 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 28 Jul 2025 11:00:54 +0800 Subject: [PATCH 6/7] refine code --- diffsynth/models/model_manager.py | 21 ++- diffsynth/models/wan_video_dit.py | 13 +- diffsynth/pipelines/wan_video_new.py | 155 +++++++++--------- diffsynth/trainers/utils.py | 2 + examples/wanvideo/README.md | 3 + examples/wanvideo/README_zh.md | 3 + .../model_inference/Wan2.2-I2V-A14B.py | 2 +- .../model_inference/Wan2.2-T2V-A14B.py | 3 - .../model_inference/Wan2.2-TI2V-5B.py | 2 +- .../model_training/full/Wan2.2-I2V-A14B.sh | 35 ++++ .../model_training/full/Wan2.2-T2V-A14B.sh | 31 ++++ .../model_training/full/Wan2.2-TI2V-5B.sh | 14 ++ .../model_training/lora/Wan2.2-I2V-A14B.sh | 37 +++++ .../model_training/lora/Wan2.2-T2V-A14B.sh | 36 ++++ .../model_training/lora/Wan2.2-TI2V-5B.sh | 16 ++ examples/wanvideo/model_training/train.py | 8 + .../validate_full/Wan2.2-I2V-A14B.py | 33 ++++ .../validate_full/Wan2.2-T2V-A14B.py | 28 ++++ .../validate_full/Wan2.2-TI2V-5B.py | 30 ++++ .../validate_lora/Wan2.2-I2V-A14B.py | 31 ++++ .../validate_lora/Wan2.2-T2V-A14B.py | 28 ++++ .../validate_lora/Wan2.2-TI2V-5B.py | 29 ++++ 22 files changed, 474 insertions(+), 86 deletions(-) create mode 100644 examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 7ae3c50..d46eedf 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -426,7 +426,7 @@ class ModelManager: self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype) - def fetch_model(self, model_name, file_path=None, require_model_path=False): + def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None): fetched_models = [] fetched_model_paths = [] for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): @@ -440,12 +440,25 @@ class ModelManager: return None if len(fetched_models) == 1: print(f"Using {model_name} from {fetched_model_paths[0]}.") + model = fetched_models[0] + path = fetched_model_paths[0] else: - print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.") + if index is None: + model = fetched_models[0] + path = fetched_model_paths[0] + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.") + elif isinstance(index, int): + model = fetched_models[:index] + path = fetched_model_paths[:index] + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.") + else: + model = fetched_models + path = fetched_model_paths + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.") if require_model_path: - return fetched_models[0], fetched_model_paths[0] + return model, path else: - return fetched_models[0] + return model def to(self, device): diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 3262057..ea473b0 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -288,6 +288,9 @@ class WanModel(torch.nn.Module): add_control_adapter: bool = False, in_dim_control_adapter: int = 24, seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, ): super().__init__() self.dim = dim @@ -295,6 +298,9 @@ class WanModel(torch.nn.Module): self.has_image_input = has_image_input self.patch_size = patch_size self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -352,7 +358,6 @@ class WanModel(torch.nn.Module): context: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, - fused_y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, **kwargs, @@ -366,8 +371,6 @@ class WanModel(torch.nn.Module): x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - if fused_y is not None: - x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) x, (f, h, w) = self.patchify(x) @@ -690,6 +693,9 @@ class WanModelStateDictConverter: "num_layers": 30, "eps": 1e-6, "seperated_timestep": True, + "require_clip_embedding": False, + "require_vae_embedding": False, + "fuse_vae_embedding_in_latents": True, } elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": # Wan-AI/Wan2.2-I2V-A14B @@ -705,6 +711,7 @@ class WanModelStateDictConverter: "num_heads": 40, "num_layers": 40, "eps": 1e-6, + "require_clip_embedding": False, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index f27382d..9963aa1 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -230,16 +230,17 @@ class WanVideoPipeline(BasePipeline): self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None - self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace") + self.in_iteration_models = ("dit", "motion_controller", "vace") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), WanVideoUnit_NoiseInitializer(), WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), - WanVideoUnit_ImageEmbedder(), - WanVideoUnit_ImageVaeEmbedder(), - WanVideoUnit_ImageEmbedderNoClip(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -259,7 +260,9 @@ class WanVideoPipeline(BasePipeline): def training_loss(self, **inputs): - timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps) + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) @@ -517,6 +520,11 @@ class WanVideoPipeline(BasePipeline): pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") pipe.vace = model_manager.fetch_model("wan_video_vace") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 # Initialize tokenizer tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download) @@ -564,7 +572,7 @@ class WanVideoPipeline(BasePipeline): cfg_scale: Optional[float] = 5.0, cfg_merge: Optional[bool] = False, # Boundary - boundary: Optional[float] = 0.875, + switch_DiT_boundary: Optional[float] = 0.875, # Scheduler num_inference_steps: Optional[int] = 50, sigma_shift: Optional[float] = 5.0, @@ -617,11 +625,14 @@ class WanVideoPipeline(BasePipeline): self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - # switch high_noise DiT to low_noise DiT - if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps: - self.load_models_to_device(["dit2", "motion_controller", "vace"]) - models["dit"] = models.pop("dit2") + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + + # Timestep timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + # Inference noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) if cfg_scale != 1.0: @@ -775,6 +786,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit): class WanVideoUnit_ImageEmbedder(PipelineUnit): + """ + Deprecated + """ def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -811,70 +825,38 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): return {"clip_feature": clip_context, "y": y} -class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): - """ - Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. - """ + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae") + input_params=("input_image", "end_image", "height", "width"), + onload_model_names=("image_encoder",) ) - def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or not pipe.dit.seperated_timestep: + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: return {} pipe.load_models_to_device(self.onload_model_names) - image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) - z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + - _, mask2 = self.masks_like([noise.squeeze(0)], zero=True) - latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0) - latents = latents.unsqueeze(0) - seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2) - if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: - import math - seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size - - return {"latents": latents, "latent_mask_for_timestep": mask2[0].unsqueeze(0), "seq_len": seq_len} - - @staticmethod - def masks_like(tensor, zero=False, generator=None, p=0.2): - assert isinstance(tensor, list) - out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] - out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] - - if zero: - if generator is not None: - for u, v in zip(out1, out2): - random_num = torch.rand(1, generator=generator, device=generator.device).item() - if random_num < p: - u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp() - v[:, 0] = torch.zeros_like(v[:, 0]) - else: - u[:, 0] = u[:, 0] - v[:, 0] = v[:, 0] - else: - for u, v in zip(out1, out2): - u[:, 0] = torch.zeros_like(u[:, 0]) - v[:, 0] = torch.zeros_like(v[:, 0]) - - return out1, out2 - - -class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): - """ - Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B. - """ +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep: + if input_image is None or not pipe.dit.require_vae_embedding: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -896,14 +878,36 @@ class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): y = torch.concat([msk, y]) y = y.unsqueeze(0) y = y.to(dtype=pipe.torch_dtype, device=pipe.device) - return {"fused_y": y} + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True} + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): @@ -927,7 +931,7 @@ class WanVideoUnit_FunReference(PipelineUnit): def __init__(self): super().__init__( input_params=("reference_image", "height", "width", "reference_image"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, reference_image, height, width): @@ -1200,7 +1204,6 @@ def model_fn_wan_video( context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, - fused_y: Optional[torch.Tensor] = None, reference_latents = None, vace_context = None, vace_scale = 1.0, @@ -1213,6 +1216,7 @@ def model_fn_wan_video( use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is not None: @@ -1247,15 +1251,19 @@ def model_fn_wan_video( get_sequence_parallel_world_size, get_sp_group) - if dit.seperated_timestep and "latent_mask_for_timestep" in kwargs: - temp_ts = (kwargs["latent_mask_for_timestep"][0][0][:, ::2, ::2] * timestep).flatten() - temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) - timestep = temp_ts.unsqueeze(0).flatten() - t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"]))) + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) else: t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller if motion_bucket_id is not None and motion_controller is not None: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) @@ -1267,16 +1275,15 @@ def model_fn_wan_video( if timestep.shape[0] != context.shape[0]: timestep = torch.concat([timestep] * context.shape[0], dim=0) - if dit.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - if fused_y is not None: - x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) # Add camera control x, (f, h, w) = dit.patchify(x, control_camera_latents_input) - # Reference image if reference_latents is not None: diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index b171857..b27fc34 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -434,6 +434,8 @@ def wan_parser(): parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") return parser diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 7b2ebbd..c75e518 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -65,6 +65,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| ## Model Inference diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 860ff83..822c8e6 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -65,6 +65,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| ## 模型推理 diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py index 9782a2c..0c1be54 100644 --- a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py @@ -22,7 +22,7 @@ dataset_snapshot_download( allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] ) input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) -# Text-to-video + video = pipe( prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", diff --git a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py index de9ae5f..27b10d0 100644 --- a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py @@ -1,9 +1,6 @@ import torch from diffsynth import save_video from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import snapshot_download - -snapshot_download("Wan-AI/Wan2.2-T2V-A14B", local_dir="models/Wan-AI/Wan2.2-T2V-A14B") pipe = WanVideoPipeline.from_pretrained( diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index f41a941..50d81c2 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -10,7 +10,7 @@ pipe = WanVideoPipeline.from_pretrained( model_configs=[ ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), ], ) pipe.enable_vram_management() diff --git a/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000..2f531e7 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh @@ -0,0 +1,35 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000..f634117 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh @@ -0,0 +1,31 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000..def9f89 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-TI2V-5B_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000..4201b47 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh @@ -0,0 +1,37 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000..737896c --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh @@ -0,0 +1,36 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000..6a33b57 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-TI2V-5B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 877c5b8..93ec0bd 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule): use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, + max_timestep_boundary=1.0, + min_timestep_boundary=0.0, ): super().__init__() # Load models @@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule): self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.max_timestep_boundary = max_timestep_boundary + self.min_timestep_boundary = min_timestep_boundary def forward_preprocess(self, data): @@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule): "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, "cfg_merge": False, "vace_scale": 1, + "max_timestep_boundary": self.max_timestep_boundary, + "min_timestep_boundary": self.min_timestep_boundary, } # Extra inputs @@ -106,6 +112,8 @@ if __name__ == "__main__": lora_rank=args.lora_rank, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, + max_timestep_boundary=args.max_timestep_boundary, + min_timestep_boundary=args.min_timestep_boundary, ) model_logger = ModelLogger( args.output_path, diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..ddcdf5c --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_high_noise_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_low_noise_full/epoch-1.safetensors") +pipe.dit2.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=False, +) +save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..be0e000 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_high_noise_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_low_noise_full/epoch-1.safetensors") +pipe.dit2.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..0f0ea5d --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-TI2V-5B_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=False, +) +save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..4a6bd9c --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-I2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1) +pipe.load_lora(pipe.dit2, "models/train/Wan2.2-I2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=False, +) +save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..ab43927 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-T2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1) +pipe.load_lora(pipe.dit2, "models/train/Wan2.2-T2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..d5a9229 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-TI2V-5B_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=False, +) +save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) From 05dba91f79f4b4195100b2a61e407c263ae21265 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 28 Jul 2025 13:38:01 +0800 Subject: [PATCH 7/7] fix wan2.2 5B --- diffsynth/pipelines/wan_video_new.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 9963aa1..1fa930d 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -646,6 +646,8 @@ class WanVideoPipeline(BasePipeline): # Scheduler inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] # VACE (TODO: remove it) if vace_reference_image is not None: @@ -899,7 +901,7 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit): image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) latents[:, :, 0: 1] = z - return {"latents": latents, "fuse_vae_embedding_in_latents": True} + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z}