From 3aed244c6f119c997679e305c8f02c606ae9460b Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 23 Jul 2025 11:20:06 +0800 Subject: [PATCH] update variable --- diffsynth/models/wan_video_dit.py | 6 +++--- diffsynth/pipelines/wan_video_new.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index b7df3ff..1106669 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -287,14 +287,14 @@ class WanModel(torch.nn.Module): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, - is_5b: bool = False, + seperated_timestep: bool = False, ): super().__init__() self.dim = dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size - self.is_5b = is_5b + self.seperated_timestep = seperated_timestep self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -685,7 +685,7 @@ class WanModelStateDictConverter: "num_heads": 24, "num_layers": 30, "eps": 1e-6, - "is_5b": True, + "seperated_timestep": True, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 1136403..7ea0e3c 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -237,7 +237,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), - WanVideoUnit_ImageEmbedder5B(), + WanVideoUnit_ImageVaeEmbedder(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -737,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.dit.is_5b: + if input_image is None or pipe.dit.seperated_timestep: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -766,7 +766,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): return {"clip_feature": clip_context, "y": y} -class WanVideoUnit_ImageEmbedder5B(PipelineUnit): +class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -774,7 +774,7 @@ class WanVideoUnit_ImageEmbedder5B(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or not pipe.dit.is_5b: + if input_image is None or not pipe.dit.seperated_timestep: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) @@ -789,7 +789,7 @@ class WanVideoUnit_ImageEmbedder5B(PipelineUnit): import math seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size - return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len} + return {"latents": latents, "latent_mask_for_timestep": mask2[0].unsqueeze(0), "seq_len": seq_len} @staticmethod def masks_like(tensor, zero=False, generator=None, p=0.2): @@ -1162,8 +1162,8 @@ def model_fn_wan_video( get_sequence_parallel_world_size, get_sp_group) - if dit.is_5b and "mask_5b" in kwargs: - temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten() + if dit.seperated_timestep and "latent_mask_for_timestep" in kwargs: + temp_ts = (kwargs["latent_mask_for_timestep"][0][0][:, ::2, ::2] * timestep).flatten() temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) timestep = temp_ts.unsqueeze(0).flatten() t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"])))