mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
Merge pull request #1034 from modelscope/video_as_prompt
Video as prompt
This commit is contained in:
@@ -22,6 +22,7 @@ from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_animate_adapter import WanAnimateAdapter
|
||||
from ..models.wan_video_mot import MotWanModel
|
||||
from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from ..prompters import WanPrompter
|
||||
@@ -47,9 +48,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.vace: VaceWanModel = None
|
||||
self.vace2: VaceWanModel = None
|
||||
self.vap: MotWanModel = None
|
||||
self.animate_adapter: WanAnimateAdapter = None
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap")
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
@@ -69,6 +71,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoPostUnit_AnimatePoseLatents(),
|
||||
WanVideoPostUnit_AnimateFacePixelValues(),
|
||||
WanVideoPostUnit_AnimateInpaint(),
|
||||
WanVideoUnit_VAP(),
|
||||
WanVideoUnit_UnifiedSequenceParallel(),
|
||||
WanVideoUnit_TeaCache(),
|
||||
WanVideoUnit_CfgMerger(),
|
||||
@@ -392,6 +395,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
vace = model_manager.fetch_model("wan_video_vace", index=2)
|
||||
pipe.vap = model_manager.fetch_model("wan_video_vap")
|
||||
if isinstance(vace, list):
|
||||
pipe.vace, pipe.vace2 = vace
|
||||
else:
|
||||
@@ -455,6 +459,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
animate_face_video: Optional[list[Image.Image]] = None,
|
||||
animate_inpaint_video: Optional[list[Image.Image]] = None,
|
||||
animate_mask_video: Optional[list[Image.Image]] = None,
|
||||
# VAP
|
||||
vap_video: Optional[list[Image.Image]] = None,
|
||||
vap_prompt: Optional[str] = " ",
|
||||
negative_vap_prompt: Optional[str] = " ",
|
||||
# Randomness
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
@@ -493,10 +501,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Inputs
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
"vap_prompt": vap_prompt,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
"negative_vap_prompt": negative_vap_prompt,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
|
||||
}
|
||||
inputs_shared = {
|
||||
@@ -516,6 +526,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
||||
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
|
||||
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
|
||||
"vap_video": vap_video,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -927,6 +938,71 @@ class WanVideoUnit_VACE(PipelineUnit):
|
||||
else:
|
||||
return {"vace_context": None, "vace_scale": vace_scale}
|
||||
|
||||
class WanVideoUnit_VAP(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("text_encoder", "vae", "image_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("vap_video") is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
else:
|
||||
# 1. encode vap prompt
|
||||
pipe.load_models_to_device(["text_encoder"])
|
||||
vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "")
|
||||
vap_prompt_emb = pipe.prompter.encode_prompt(vap_prompt, positive=inputs_posi.get('positive',None), device=pipe.device)
|
||||
negative_vap_prompt_emb = pipe.prompter.encode_prompt(negative_vap_prompt, positive=inputs_nega.get('positive',None), device=pipe.device)
|
||||
inputs_posi.update({"context_vap":vap_prompt_emb})
|
||||
inputs_nega.update({"context_vap":negative_vap_prompt_emb})
|
||||
# 2. prepare vap image clip embedding
|
||||
pipe.load_models_to_device(["vae", "image_encoder"])
|
||||
vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image")
|
||||
|
||||
num_frames, height, width, mot_num = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num",1)
|
||||
|
||||
image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device)
|
||||
|
||||
vap_clip_context = pipe.image_encoder.encode_image([image_vap])
|
||||
if end_image is not None:
|
||||
vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
|
||||
if pipe.dit.has_image_pos_emb:
|
||||
vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1)
|
||||
vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inputs_shared.update({"vap_clip_feature":vap_clip_context})
|
||||
|
||||
# 3. prepare vap latents
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
|
||||
msk[:, 1:] = 0
|
||||
if end_image is not None:
|
||||
msk[:, -1:] = 1
|
||||
last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
|
||||
vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1)
|
||||
else:
|
||||
vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1)
|
||||
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
|
||||
|
||||
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
vap_video = pipe.preprocess_video(vap_video)
|
||||
vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inputs_shared.update({"vap_hidden_state":vap_latent})
|
||||
pipe.load_models_to_device([])
|
||||
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
|
||||
@@ -1285,6 +1361,7 @@ def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
vace: VaceWanModel = None,
|
||||
vap: MotWanModel = None,
|
||||
animate_adapter: WanAnimateAdapter = None,
|
||||
latents: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
@@ -1297,6 +1374,9 @@ def model_fn_wan_video(
|
||||
audio_embeds: Optional[torch.Tensor] = None,
|
||||
motion_latents: Optional[torch.Tensor] = None,
|
||||
s2v_pose_latents: Optional[torch.Tensor] = None,
|
||||
vap_hidden_state = None,
|
||||
vap_clip_feature = None,
|
||||
context_vap = None,
|
||||
drop_motion_frames: bool = True,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
@@ -1406,7 +1486,7 @@ def model_fn_wan_video(
|
||||
if clip_feature is not None and dit.require_clip_embedding:
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
|
||||
# Camera control
|
||||
x = dit.patchify(x, control_camera_latents_input)
|
||||
|
||||
@@ -1431,6 +1511,25 @@ def model_fn_wan_video(
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# VAP
|
||||
if vap is not None:
|
||||
# hidden state
|
||||
x_vap = vap_hidden_state
|
||||
x_vap = vap.patchify(x_vap)
|
||||
x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous()
|
||||
# Timestep
|
||||
clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype)
|
||||
t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep))
|
||||
t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim))
|
||||
|
||||
# rope
|
||||
freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device)
|
||||
|
||||
# context
|
||||
vap_clip_embedding = vap.img_emb(vap_clip_feature)
|
||||
context_vap = vap.text_embedding(context_vap)
|
||||
context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
@@ -1460,23 +1559,45 @@ def model_fn_wan_video(
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
def create_custom_forward_vap(block, vap):
|
||||
def custom_forward(*inputs):
|
||||
return vap(block, *inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
# Block
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
if vap is not None and block_id in vap.mot_layers_mapping:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward_vap(block, vap),
|
||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward_vap(block, vap),
|
||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
|
||||
else:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
# VACE
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
|
||||
Reference in New Issue
Block a user