mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
vace
This commit is contained in:
@@ -4,6 +4,7 @@ from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import WanPrompter
|
||||
@@ -33,7 +34,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
|
||||
self.vace: VaceWanModel = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
@@ -153,6 +155,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
self.vace = model_manager.fetch_model("wan_video_vace")
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -253,6 +256,57 @@ class WanVideoPipeline(BasePipeline):
|
||||
def prepare_motion_bucket_id(self, motion_bucket_id):
|
||||
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"motion_bucket_id": motion_bucket_id}
|
||||
|
||||
|
||||
def prepare_vace_kwargs(
|
||||
self,
|
||||
latents,
|
||||
vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
|
||||
height=480, width=832, num_frames=81,
|
||||
seed=None, rand_device="cpu",
|
||||
tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
|
||||
):
|
||||
if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
|
||||
self.load_models_to_device(["vae"])
|
||||
if vace_video is None:
|
||||
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
vace_video = self.preprocess_images(vace_video)
|
||||
vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
if vace_mask is None:
|
||||
vace_mask = torch.ones_like(vace_video)
|
||||
else:
|
||||
vace_mask = self.preprocess_images(vace_mask)
|
||||
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
|
||||
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
|
||||
inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
||||
|
||||
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
||||
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
|
||||
|
||||
if vace_reference_image is None:
|
||||
pass
|
||||
else:
|
||||
vace_reference_image = self.preprocess_images([vace_reference_image])
|
||||
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
|
||||
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
|
||||
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
|
||||
|
||||
noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
|
||||
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = torch.concat((noise, latents), dim=2)
|
||||
|
||||
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
|
||||
return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
|
||||
else:
|
||||
return latents, {"vace_context": None, "vace_scale": vace_scale}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -264,6 +318,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
end_image=None,
|
||||
input_video=None,
|
||||
control_video=None,
|
||||
vace_video=None,
|
||||
vace_video_mask=None,
|
||||
vace_reference_image=None,
|
||||
vace_scale=1.0,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="cpu",
|
||||
@@ -333,6 +391,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# VACE
|
||||
latents, vace_kwargs = self.prepare_vace_kwargs(
|
||||
latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
|
||||
height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
|
||||
)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
@@ -341,23 +405,23 @@ class WanVideoPipeline(BasePipeline):
|
||||
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit", "motion_controller"])
|
||||
self.load_models_to_device(["dit", "motion_controller", "vace"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller,
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **image_emb, **extra_input,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller,
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **image_emb, **extra_input,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -365,6 +429,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
if vace_reference_image is not None:
|
||||
latents = latents[:, :, 1:]
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
@@ -432,11 +499,14 @@ class TeaCache:
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
vace: VaceWanModel = None,
|
||||
x: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
context: torch.Tensor = None,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
@@ -472,6 +542,9 @@ def model_fn_wan_video(
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if vace_context is not None:
|
||||
vace_hints = vace(x, vace_context, context, t_mod, freqs)
|
||||
|
||||
# blocks
|
||||
if use_unified_sequence_parallel:
|
||||
@@ -480,8 +553,10 @@ def model_fn_wan_video(
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
for block in dit.blocks:
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user