mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
update variable
This commit is contained in:
@@ -287,14 +287,14 @@ class WanModel(torch.nn.Module):
|
|||||||
has_ref_conv: bool = False,
|
has_ref_conv: bool = False,
|
||||||
add_control_adapter: bool = False,
|
add_control_adapter: bool = False,
|
||||||
in_dim_control_adapter: int = 24,
|
in_dim_control_adapter: int = 24,
|
||||||
is_5b: bool = False,
|
seperated_timestep: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.freq_dim = freq_dim
|
self.freq_dim = freq_dim
|
||||||
self.has_image_input = has_image_input
|
self.has_image_input = has_image_input
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.is_5b = is_5b
|
self.seperated_timestep = seperated_timestep
|
||||||
|
|
||||||
self.patch_embedding = nn.Conv3d(
|
self.patch_embedding = nn.Conv3d(
|
||||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||||
@@ -685,7 +685,7 @@ class WanModelStateDictConverter:
|
|||||||
"num_heads": 24,
|
"num_heads": 24,
|
||||||
"num_layers": 30,
|
"num_layers": 30,
|
||||||
"eps": 1e-6,
|
"eps": 1e-6,
|
||||||
"is_5b": True,
|
"seperated_timestep": True,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
config = {}
|
config = {}
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
WanVideoUnit_InputVideoEmbedder(),
|
WanVideoUnit_InputVideoEmbedder(),
|
||||||
WanVideoUnit_PromptEmbedder(),
|
WanVideoUnit_PromptEmbedder(),
|
||||||
WanVideoUnit_ImageEmbedder(),
|
WanVideoUnit_ImageEmbedder(),
|
||||||
WanVideoUnit_ImageEmbedder5B(),
|
WanVideoUnit_ImageVaeEmbedder(),
|
||||||
WanVideoUnit_FunControl(),
|
WanVideoUnit_FunControl(),
|
||||||
WanVideoUnit_FunReference(),
|
WanVideoUnit_FunReference(),
|
||||||
WanVideoUnit_FunCameraControl(),
|
WanVideoUnit_FunCameraControl(),
|
||||||
@@ -737,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
|
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
|
||||||
if input_image is None or pipe.dit.is_5b:
|
if input_image is None or pipe.dit.seperated_timestep:
|
||||||
return {}
|
return {}
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||||
@@ -766,7 +766,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
|||||||
return {"clip_feature": clip_context, "y": y}
|
return {"clip_feature": clip_context, "y": y}
|
||||||
|
|
||||||
|
|
||||||
class WanVideoUnit_ImageEmbedder5B(PipelineUnit):
|
class WanVideoUnit_ImageVaeEmbedder(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
|
input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||||
@@ -774,7 +774,7 @@ class WanVideoUnit_ImageEmbedder5B(PipelineUnit):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride):
|
def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride):
|
||||||
if input_image is None or not pipe.dit.is_5b:
|
if input_image is None or not pipe.dit.seperated_timestep:
|
||||||
return {}
|
return {}
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device)
|
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device)
|
||||||
@@ -789,7 +789,7 @@ class WanVideoUnit_ImageEmbedder5B(PipelineUnit):
|
|||||||
import math
|
import math
|
||||||
seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size
|
seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size
|
||||||
|
|
||||||
return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len}
|
return {"latents": latents, "latent_mask_for_timestep": mask2[0].unsqueeze(0), "seq_len": seq_len}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def masks_like(tensor, zero=False, generator=None, p=0.2):
|
def masks_like(tensor, zero=False, generator=None, p=0.2):
|
||||||
@@ -1162,8 +1162,8 @@ def model_fn_wan_video(
|
|||||||
get_sequence_parallel_world_size,
|
get_sequence_parallel_world_size,
|
||||||
get_sp_group)
|
get_sp_group)
|
||||||
|
|
||||||
if dit.is_5b and "mask_5b" in kwargs:
|
if dit.seperated_timestep and "latent_mask_for_timestep" in kwargs:
|
||||||
temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten()
|
temp_ts = (kwargs["latent_mask_for_timestep"][0][0][:, ::2, ::2] * timestep).flatten()
|
||||||
temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep])
|
temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep])
|
||||||
timestep = temp_ts.unsqueeze(0).flatten()
|
timestep = temp_ts.unsqueeze(0).flatten()
|
||||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"])))
|
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"])))
|
||||||
|
|||||||
Reference in New Issue
Block a user