mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
uodate wan2.2-VACE-Fun
This commit is contained in:
@@ -44,8 +44,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.vae: WanVideoVAE = None
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.vace: VaceWanModel = None
|
||||
self.vace2: VaceWanModel = None
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2")
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
@@ -359,6 +360,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
||||
if isinstance(vace, list):
|
||||
pipe.vace, pipe.vace2 = vace
|
||||
else:
|
||||
pipe.vace = vace
|
||||
pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
|
||||
|
||||
# Size division factor
|
||||
@@ -481,6 +486,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
|
||||
self.load_models_to_device(self.in_iteration_models_2)
|
||||
models["dit"] = self.dit2
|
||||
models["vace"] = self.vace2
|
||||
|
||||
# Timestep
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
@@ -534,11 +540,12 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
|
||||
length = (num_frames - 1) // 4 + 1
|
||||
if vace_reference_image is not None:
|
||||
length += 1
|
||||
f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1
|
||||
length += f
|
||||
shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
|
||||
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
|
||||
if vace_reference_image is not None:
|
||||
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
|
||||
noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
@@ -849,11 +856,22 @@ class WanVideoUnit_VACE(PipelineUnit):
|
||||
if vace_reference_image is None:
|
||||
pass
|
||||
else:
|
||||
if not isinstance(vace_reference_image,list):
|
||||
vace_reference_image = [vace_reference_image]
|
||||
|
||||
vace_reference_image = pipe.preprocess_video(vace_reference_image)
|
||||
|
||||
bs, c, f, h, w = vace_reference_image.shape
|
||||
new_vace_ref_images = []
|
||||
for j in range(f):
|
||||
new_vace_ref_images.append(vace_reference_image[0, :, j:j+1])
|
||||
vace_reference_image = new_vace_ref_images
|
||||
|
||||
vace_reference_image = pipe.preprocess_video([vace_reference_image])
|
||||
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
|
||||
vace_reference_latents = torch.concat((*vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
|
||||
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
|
||||
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
|
||||
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)
|
||||
|
||||
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
|
||||
return {"vace_context": vace_context, "vace_scale": vace_scale}
|
||||
|
||||
Reference in New Issue
Block a user