From dbef6122e9738a09a747556d10f07df9bb52d398 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 5 May 2025 23:23:06 +0800 Subject: [PATCH] ... --- diffsynth/pipelines/wan_video_new.py | 38 ++++++++++++++++++++-------- diffsynth/vram_management/layers.py | 35 +++++++++++++++++++++++++ test.py | 34 +++++++++++++++---------- 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index dcc4485..de05e50 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -24,7 +24,7 @@ from PIL import Image from tqdm import tqdm from typing import Optional -from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear +from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample @@ -188,8 +188,8 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), - WanVideoUnit_FunReference(), WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), WanVideoUnit_SpeedControl(), WanVideoUnit_VACE(), WanVideoUnit_TeaCache(), @@ -225,7 +225,7 @@ class WanVideoPipeline(BasePipeline): module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, - torch.nn.LayerNorm: AutoWrappedModule, + torch.nn.LayerNorm: WanAutoCastLayerNorm, RMSNorm: AutoWrappedModule, torch.nn.Conv2d: AutoWrappedModule, }, @@ -654,7 +654,7 @@ class WanVideoUnit_FunControl(PipelineUnit): class WanVideoUnit_FunReference(PipelineUnit): def __init__(self): super().__init__( - input_params=("reference_image", "height", "width"), + input_params=("reference_image", "height", "width", "reference_image"), onload_model_names=("vae") ) @@ -663,9 +663,11 @@ class WanVideoUnit_FunReference(PipelineUnit): return {} pipe.load_models_to_device(["vae"]) reference_image = reference_image.resize((width, height)) - reference_image = pipe.preprocess_video([reference_image]) - reference_latents = pipe.vae.encode(reference_image, device=pipe.device) - return {"reference_latents": reference_latents} + reference_latents = pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + clip_feature = pipe.preprocess_image(reference_image) + clip_feature = pipe.image_encoder.encode_image([clip_feature]) + return {"reference_latents": reference_latents, "clip_feature": clip_feature} @@ -753,11 +755,19 @@ class WanVideoUnit_TeaCache(PipelineUnit): class WanVideoUnit_CfgMerger(PipelineUnit): def __init__(self): super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): if not inputs_shared["cfg_merge"]: return inputs_shared, inputs_posi, inputs_nega - inputs_shared["context"] = torch.concat((inputs_posi["context"], inputs_nega["context"]), dim=0) + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) inputs_posi.clear() inputs_nega.clear() return inputs_shared, inputs_posi, inputs_nega @@ -835,10 +845,12 @@ class TemporalTiler_BCTHW: mask = repeat(t, "T -> 1 1 T 1 1") return mask - def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names): + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) @@ -881,6 +893,7 @@ def model_fn_wan_video( motion_bucket_id: Optional[torch.Tensor] = None, sliding_window_size: Optional[int] = None, sliding_window_stride: Optional[int] = None, + cfg_merge: bool = False, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is not None: @@ -905,7 +918,8 @@ def model_fn_wan_video( sliding_window_size, sliding_window_stride, latents.device, latents.dtype, model_kwargs=model_kwargs, - tensor_names=["latents", "y"] + tensor_names=["latents", "y"], + batch_size=2 if cfg_merge else 1 ) if use_unified_sequence_parallel: @@ -936,7 +950,9 @@ def model_fn_wan_video( # Reference image if reference_latents is not None: - reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2) + if len(reference_latents.shape) == 5: + reference_latents = reference_latents[:, :, 0] + reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) x = torch.concat([reference_latents, x], dim=1) f += 1 diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index a9df39e..aa2bda2 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -38,6 +38,41 @@ class AutoWrappedModule(torch.nn.Module): return module(*args, **kwargs) +class WanAutoCastLayerNorm(torch.nn.LayerNorm): + def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): + with init_weights_on_device(device=torch.device("meta")): + super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) + self.weight = module.weight + self.bias = module.bias + self.offload_dtype = offload_dtype + self.offload_device = offload_device + self.onload_dtype = onload_dtype + self.onload_device = onload_device + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.state = 0 + + def offload(self): + if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def forward(self, x, *args, **kwargs): + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + weight, bias = self.weight, self.bias + else: + weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) + with torch.amp.autocast(device_type=x.device.type): + x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x) + return x + + class AutoWrappedLinear(torch.nn.Linear): def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): with init_weights_on_device(device=torch.device("meta")): diff --git a/test.py b/test.py index f16ae0e..f7959ee 100644 --- a/test.py +++ b/test.py @@ -1,7 +1,9 @@ import torch +torch.cuda.set_per_process_memory_fraction(0.999, 0) from diffsynth import ModelManager, save_video, VideoData, save_frames, save_video, download_models -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig, model_fn_wan_video from diffsynth.controlnets.processors import Annotator +from diffsynth.data.video import crop_and_resize from modelscope import snapshot_download from tqdm import tqdm from PIL import Image @@ -13,28 +15,32 @@ pipe = WanVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + # ModelConfig("D:\projects\VideoX-Fun\models\Wan2.1-Fun-V1.1-1.3B-Control\diffusion_pytorch_model.safetensors", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), ], ) -pipe.enable_vram_management(num_persistent_param_in_dit=10*10**9) - +pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) video = VideoData(rf"D:\pr_projects\20250503_dance\data\双马尾竖屏暴击!你的微笑就是彩虹的微笑♥ - 1.双马尾竖屏暴击!你的微笑就是彩虹的微笑♥(Av114086629088385,P1).mp4", height=832, width=480) annotator = Annotator("openpose") -video = [video[i] for i in tqdm(range(450, 450+1*17, 1))] +video = [video[i] for i in tqdm(range(450, 450+1*81, 1))] save_video(video, "video_input.mp4", fps=60, quality=5) control_video = [annotator(f) for f in tqdm(video)] save_video(control_video, "video_control.mp4", fps=60, quality=5) -reference_image = Image.open(rf"D:\pr_projects\20250503_dance\data\marmot.png").resize((480, 832)) +reference_image = crop_and_resize(Image.open(rf"D:\pr_projects\20250503_dance\data\marmot4.png"), 832, 480) -video = pipe( - prompt="微距摄影风格特写画面,一只憨态可掬的土拨鼠正用后腿站立在碎石堆上,它在挥舞着双臂。金棕色的绒毛在阳光下泛着丝绸般的光泽,腹部毛发呈现浅杏色渐变,每根毛尖都闪烁着细密的光晕。两只黑曜石般的眼睛透出机警而温顺的光芒,鼻梁两侧的白色触须微微颤动,捕捉着空气中的气息。背景是虚化的灰绿色渐变,几簇嫩绿苔藓从画面右下角探出头来,与前景散落的鹅卵石形成微妙的景深对比。土拨鼠圆润的身形在逆光中勾勒出柔和的轮廓,耳朵紧贴头部的姿态流露出戒备中的天真,整个画面洋溢着自然界生灵特有的灵动与纯真。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=True, - height=832, width=480, num_frames=len(control_video), - control_video=control_video, reference_image=reference_image, - # num_inference_steps=30, cfg_scale=1, -) -save_video(video, "video1.mp4", fps=60, quality=5) +with torch.amp.autocast("cuda", torch.bfloat16): + video = pipe( + prompt="微距摄影风格特写画面,一只憨态可掬的土拨鼠正用后腿站立在碎石堆上,它在挥舞着双臂。金棕色的绒毛在阳光下泛着丝绸般的光泽,腹部毛发呈现浅杏色渐变,每根毛尖都闪烁着细密的光晕。两只黑曜石般的眼睛透出机警而温顺的光芒,鼻梁两侧的白色触须微微颤动,捕捉着空气中的气息。背景是虚化的灰绿色渐变,几簇嫩绿苔藓从画面右下角探出头来,与前景散落的鹅卵石形成微妙的景深对比。土拨鼠圆润的身形在逆光中勾勒出柔和的轮廓,耳朵紧贴头部的姿态流露出戒备中的天真,整个画面洋溢着自然界生灵特有的灵动与纯真。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=43, tiled=True, + height=832, width=480, num_frames=len(control_video), + control_video=control_video, reference_image=reference_image, + # sliding_window_size=5, sliding_window_stride=2, + # num_inference_steps=100, + # cfg_merge=True, + sigma_shift=16, + ) + save_video(video, "video1.mp4", fps=60, quality=5)