Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-22 16:50:47 +00:00
support LongCat-Video
@@ -22,6 +22,7 @@ from ..models.wan_video_image_encoder import WanImageEncoder
 from ..models.wan_video_vace import VaceWanModel
 from ..models.wan_video_motion_controller import WanMotionControllerModel
 from ..models.wan_video_animate_adapter import WanAnimateAdapter
+from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
 from ..schedulers.flow_match import FlowMatchScheduler
 from ..prompters import WanPrompter
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
@@ -71,6 +72,7 @@ class WanVideoPipeline(BasePipeline):
             WanVideoUnit_UnifiedSequenceParallel(),
             WanVideoUnit_TeaCache(),
             WanVideoUnit_CfgMerger(),
+            WanVideoUnit_LongCatVideo(),
         ]
         self.post_units = [
             WanVideoPostUnit_S2V(),
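For orientation: the pipeline's preprocessing is a list of PipelineUnit objects executed in order, each consuming named entries from a shared inputs dict and merging its outputs back in, so registering WanVideoUnit_LongCatVideo here is all that is needed for it to run. A conceptual sketch of that contract as assumed in this commit (the real loop lives in DiffSynth-Studio's pipeline base class and may differ in detail):

inputs = {"longcat_video": frames}                  # `frames` supplied by the caller
for unit in pipe.units:                             # WanVideoUnit_LongCatVideo is one of these
    picked = {k: inputs.get(k) for k in unit.input_params}
    inputs.update(unit.process(pipe, **picked))     # e.g. adds "longcat_latents"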
@@ -150,6 +152,7 @@ class WanVideoPipeline(BasePipeline):
                 vram_limit=vram_limit,
             )
         if self.dit is not None:
+            from ..models.longcat_video_dit import LayerNorm_FP32, RMSNorm_FP32
             dtype = next(iter(self.dit.parameters())).dtype
             device = "cpu" if vram_limit is not None else self.device
             enable_vram_management(
@@ -162,6 +165,8 @@ class WanVideoPipeline(BasePipeline):
                     torch.nn.Conv2d: AutoWrappedModule,
                     torch.nn.Conv1d: AutoWrappedModule,
                     torch.nn.Embedding: AutoWrappedModule,
+                    LayerNorm_FP32: AutoWrappedModule,
+                    RMSNorm_FP32: AutoWrappedModule,
                 },
                 module_config = dict(
                     offload_dtype=dtype,
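Registering the two FP32 norm classes in the module map matters because enable_vram_management only manages the layer types it has been told about; unlisted modules keep their parameters resident on the compute device. The sketch below shows the general idea behind this kind of wrapping, as an assumption about AutoWrappedModule's role rather than DiffSynth-Studio's actual implementation: parameters rest on an offload device and are brought in only around the forward call.

import torch

class OffloadWrapper(torch.nn.Module):
    """Hypothetical stand-in for AutoWrappedModule; illustration only."""
    def __init__(self, module, offload_device="cpu", onload_device="cuda"):
        super().__init__()
        self.module = module.to(offload_device)
        self.offload_device = offload_device
        self.onload_device = onload_device

    def forward(self, *args, **kwargs):
        self.module.to(self.onload_device)    # fetch weights for this call
        out = self.module(*args, **kwargs)
        self.module.to(self.offload_device)   # free VRAM again
        return out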
@@ -467,6 +472,8 @@ class WanVideoPipeline(BasePipeline):
         sigma_shift: Optional[float] = 5.0,
         # Speed control
         motion_bucket_id: Optional[int] = None,
+        # LongCat-Video
+        longcat_video: Optional[list[Image.Image]] = None,
         # VAE tiling
         tiled: Optional[bool] = True,
         tile_size: Optional[tuple[int, int]] = (30, 52),
@@ -504,6 +511,7 @@ class WanVideoPipeline(BasePipeline):
             "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
             "sigma_shift": sigma_shift,
             "motion_bucket_id": motion_bucket_id,
+            "longcat_video": longcat_video,
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
             "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
             "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
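Taken together, these two hunks expose LongCat conditioning at the user level: pass a list of PIL frames and the pipeline threads them through to the new unit. A hedged usage sketch, in which the pipeline setup and every name other than longcat_video are illustrative assumptions, not part of this diff:

from PIL import Image

# `pipe` is an already-initialized WanVideoPipeline with a LongCat-Video DiT.
# Feed the tail frames of a previous clip back in to continue the video.
frames: list[Image.Image] = previous_clip[-13:]   # hypothetical earlier output
video = pipe(
    prompt="a cat walking along the beach at sunset",
    longcat_video=frames,   # VAE-encoded and written over the head of the latents
)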
@@ -1151,6 +1159,22 @@ class WanVideoPostUnit_AnimateInpaint(PipelineUnit):
         return {"y": y}
 
 
+class WanVideoUnit_LongCatVideo(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("longcat_video",),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: WanVideoPipeline, longcat_video):
+        if longcat_video is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        longcat_video = pipe.preprocess_video(longcat_video)
+        longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"longcat_latents": longcat_latents}
+
+
 class TeaCache:
     def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
         self.num_inference_steps = num_inference_steps
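The unit encodes the conditioning clip exactly once, on the VAE, so everything downstream works purely in latent space; longcat_latents.shape[2] is the number of conditioning latent frames that model_fn_longcat_video later overwrites. The arithmetic below assumes the usual Wan VAE compression (4x temporal with the first frame kept, 8x spatial, 16 latent channels), which this diff does not itself state:

def latent_shape(num_frames: int, height: int, width: int) -> tuple[int, ...]:
    t = (num_frames - 1) // 4 + 1        # temporal: the first frame stays uncompressed
    return (1, 16, t, height // 8, width // 8)

print(latent_shape(93, 480, 832))        # -> (1, 16, 24, 60, 104)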
@@ -1279,6 +1303,7 @@ def model_fn_wan_video(
     motion_bucket_id: Optional[torch.Tensor] = None,
     pose_latents=None,
     face_pixel_values=None,
+    longcat_latents=None,
     sliding_window_size: Optional[int] = None,
     sliding_window_stride: Optional[int] = None,
     cfg_merge: bool = False,
@@ -1313,6 +1338,18 @@
         tensor_names=["latents", "y"],
         batch_size=2 if cfg_merge else 1
     )
+    # LongCat-Video
+    if isinstance(dit, LongCatVideoTransformer3DModel):
+        return model_fn_longcat_video(
+            dit=dit,
+            latents=latents,
+            timestep=timestep,
+            context=context,
+            longcat_latents=longcat_latents,
+            use_gradient_checkpointing=use_gradient_checkpointing,
+            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+        )
+
     # wan2.2 s2v
     if audio_embeds is not None:
         return model_fn_wans2v(
@@ -1468,6 +1505,36 @@ def model_fn_wan_video(
     return x
 
 
+def model_fn_longcat_video(
+    dit: LongCatVideoTransformer3DModel,
+    latents: torch.Tensor = None,
+    timestep: torch.Tensor = None,
+    context: torch.Tensor = None,
+    longcat_latents: torch.Tensor = None,
+    use_gradient_checkpointing=False,
+    use_gradient_checkpointing_offload=False,
+):
+    if longcat_latents is not None:
+        latents[:, :, :longcat_latents.shape[2]] = longcat_latents
+        num_cond_latents = longcat_latents.shape[2]
+    else:
+        num_cond_latents = 0
+    context = context.unsqueeze(0)
+    encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64)
+    output = dit(
+        latents,
+        timestep,
+        context,
+        encoder_attention_mask,
+        num_cond_latents=num_cond_latents,
+        use_gradient_checkpointing=use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+    )
+    output = -output
+    output = output.to(latents.dtype)
+    return output
+
+
 def model_fn_wans2v(
     dit,
     latents,
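Two details in model_fn_longcat_video deserve a note. The conditioning latents are copied over the head of the latent sequence and their count is passed as num_cond_latents, which presumably lets the DiT treat those positions as known context rather than denoising targets; the encoder_attention_mask is recovered from the padded text embeddings by flagging rows with any non-zero entry. Finally, output = -output flips the prediction's sign; a plausible reading, offered here as an assumption rather than something the diff states, is that the LongCat DiT was trained with the opposite flow-matching velocity convention to the scheduler's. A self-contained illustration of that sign relationship:

import torch

# Under linear flow matching, x_t = (1 - t) * x_data + t * x_noise. A scheduler
# stepping along v = x_noise - x_data needs the negation of a model trained to
# predict x_data - x_noise. The conventions here are illustrative assumptions.
x_data, x_noise = torch.randn(4), torch.randn(4)
v_scheduler = x_noise - x_data    # direction the sampler steps along
v_model = x_data - x_noise        # the opposite training convention
assert torch.allclose(v_scheduler, -v_model)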