mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
...
This commit is contained in:
@@ -24,7 +24,7 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
@@ -188,8 +188,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_PromptEmbedder(),
|
||||
WanVideoUnit_ImageEmbedder(),
|
||||
WanVideoUnit_FunReference(),
|
||||
WanVideoUnit_FunControl(),
|
||||
WanVideoUnit_FunReference(),
|
||||
WanVideoUnit_SpeedControl(),
|
||||
WanVideoUnit_VACE(),
|
||||
WanVideoUnit_TeaCache(),
|
||||
@@ -225,7 +225,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
},
|
||||
@@ -654,7 +654,7 @@ class WanVideoUnit_FunControl(PipelineUnit):
|
||||
class WanVideoUnit_FunReference(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("reference_image", "height", "width"),
|
||||
input_params=("reference_image", "height", "width", "reference_image"),
|
||||
onload_model_names=("vae")
|
||||
)
|
||||
|
||||
@@ -663,9 +663,11 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
return {}
|
||||
pipe.load_models_to_device(["vae"])
|
||||
reference_image = reference_image.resize((width, height))
|
||||
reference_image = pipe.preprocess_video([reference_image])
|
||||
reference_latents = pipe.vae.encode(reference_image, device=pipe.device)
|
||||
return {"reference_latents": reference_latents}
|
||||
reference_latents = pipe.preprocess_video([reference_image])
|
||||
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
|
||||
clip_feature = pipe.preprocess_image(reference_image)
|
||||
clip_feature = pipe.image_encoder.encode_image([clip_feature])
|
||||
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
|
||||
|
||||
|
||||
|
||||
@@ -753,11 +755,19 @@ class WanVideoUnit_TeaCache(PipelineUnit):
|
||||
class WanVideoUnit_CfgMerger(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True)
|
||||
self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if not inputs_shared["cfg_merge"]:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
inputs_shared["context"] = torch.concat((inputs_posi["context"], inputs_nega["context"]), dim=0)
|
||||
for name in self.concat_tensor_names:
|
||||
tensor_posi = inputs_posi.get(name)
|
||||
tensor_nega = inputs_nega.get(name)
|
||||
tensor_shared = inputs_shared.get(name)
|
||||
if tensor_posi is not None and tensor_nega is not None:
|
||||
inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
|
||||
elif tensor_shared is not None:
|
||||
inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0)
|
||||
inputs_posi.clear()
|
||||
inputs_nega.clear()
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
@@ -835,10 +845,12 @@ class TemporalTiler_BCTHW:
|
||||
mask = repeat(t, "T -> 1 1 T 1 1")
|
||||
return mask
|
||||
|
||||
def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names):
|
||||
def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None):
|
||||
tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None]
|
||||
tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names}
|
||||
B, C, T, H, W = tensor_dict[tensor_names[0]].shape
|
||||
if batch_size is not None:
|
||||
B *= batch_size
|
||||
data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype
|
||||
value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
|
||||
weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
|
||||
@@ -881,6 +893,7 @@ def model_fn_wan_video(
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
sliding_window_stride: Optional[int] = None,
|
||||
cfg_merge: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if sliding_window_size is not None and sliding_window_stride is not None:
|
||||
@@ -905,7 +918,8 @@ def model_fn_wan_video(
|
||||
sliding_window_size, sliding_window_stride,
|
||||
latents.device, latents.dtype,
|
||||
model_kwargs=model_kwargs,
|
||||
tensor_names=["latents", "y"]
|
||||
tensor_names=["latents", "y"],
|
||||
batch_size=2 if cfg_merge else 1
|
||||
)
|
||||
|
||||
if use_unified_sequence_parallel:
|
||||
@@ -936,7 +950,9 @@ def model_fn_wan_video(
|
||||
|
||||
# Reference image
|
||||
if reference_latents is not None:
|
||||
reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
|
||||
if len(reference_latents.shape) == 5:
|
||||
reference_latents = reference_latents[:, :, 0]
|
||||
reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
|
||||
x = torch.concat([reference_latents, x], dim=1)
|
||||
f += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user