mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
...
This commit is contained in:
@@ -24,7 +24,7 @@ from PIL import Image
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
|
||||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||||
@@ -188,8 +188,8 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
WanVideoUnit_InputVideoEmbedder(),
|
WanVideoUnit_InputVideoEmbedder(),
|
||||||
WanVideoUnit_PromptEmbedder(),
|
WanVideoUnit_PromptEmbedder(),
|
||||||
WanVideoUnit_ImageEmbedder(),
|
WanVideoUnit_ImageEmbedder(),
|
||||||
WanVideoUnit_FunReference(),
|
|
||||||
WanVideoUnit_FunControl(),
|
WanVideoUnit_FunControl(),
|
||||||
|
WanVideoUnit_FunReference(),
|
||||||
WanVideoUnit_SpeedControl(),
|
WanVideoUnit_SpeedControl(),
|
||||||
WanVideoUnit_VACE(),
|
WanVideoUnit_VACE(),
|
||||||
WanVideoUnit_TeaCache(),
|
WanVideoUnit_TeaCache(),
|
||||||
@@ -225,7 +225,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
module_map = {
|
module_map = {
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
torch.nn.Conv3d: AutoWrappedModule,
|
torch.nn.Conv3d: AutoWrappedModule,
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||||
RMSNorm: AutoWrappedModule,
|
RMSNorm: AutoWrappedModule,
|
||||||
torch.nn.Conv2d: AutoWrappedModule,
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
},
|
},
|
||||||
@@ -654,7 +654,7 @@ class WanVideoUnit_FunControl(PipelineUnit):
|
|||||||
class WanVideoUnit_FunReference(PipelineUnit):
|
class WanVideoUnit_FunReference(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("reference_image", "height", "width"),
|
input_params=("reference_image", "height", "width", "reference_image"),
|
||||||
onload_model_names=("vae")
|
onload_model_names=("vae")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -663,9 +663,11 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
|||||||
return {}
|
return {}
|
||||||
pipe.load_models_to_device(["vae"])
|
pipe.load_models_to_device(["vae"])
|
||||||
reference_image = reference_image.resize((width, height))
|
reference_image = reference_image.resize((width, height))
|
||||||
reference_image = pipe.preprocess_video([reference_image])
|
reference_latents = pipe.preprocess_video([reference_image])
|
||||||
reference_latents = pipe.vae.encode(reference_image, device=pipe.device)
|
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
|
||||||
return {"reference_latents": reference_latents}
|
clip_feature = pipe.preprocess_image(reference_image)
|
||||||
|
clip_feature = pipe.image_encoder.encode_image([clip_feature])
|
||||||
|
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -753,11 +755,19 @@ class WanVideoUnit_TeaCache(PipelineUnit):
|
|||||||
class WanVideoUnit_CfgMerger(PipelineUnit):
|
class WanVideoUnit_CfgMerger(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(take_over=True)
|
super().__init__(take_over=True)
|
||||||
|
self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]
|
||||||
|
|
||||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||||
if not inputs_shared["cfg_merge"]:
|
if not inputs_shared["cfg_merge"]:
|
||||||
return inputs_shared, inputs_posi, inputs_nega
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
inputs_shared["context"] = torch.concat((inputs_posi["context"], inputs_nega["context"]), dim=0)
|
for name in self.concat_tensor_names:
|
||||||
|
tensor_posi = inputs_posi.get(name)
|
||||||
|
tensor_nega = inputs_nega.get(name)
|
||||||
|
tensor_shared = inputs_shared.get(name)
|
||||||
|
if tensor_posi is not None and tensor_nega is not None:
|
||||||
|
inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
|
||||||
|
elif tensor_shared is not None:
|
||||||
|
inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0)
|
||||||
inputs_posi.clear()
|
inputs_posi.clear()
|
||||||
inputs_nega.clear()
|
inputs_nega.clear()
|
||||||
return inputs_shared, inputs_posi, inputs_nega
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
@@ -835,10 +845,12 @@ class TemporalTiler_BCTHW:
|
|||||||
mask = repeat(t, "T -> 1 1 T 1 1")
|
mask = repeat(t, "T -> 1 1 T 1 1")
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names):
|
def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None):
|
||||||
tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None]
|
tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None]
|
||||||
tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names}
|
tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names}
|
||||||
B, C, T, H, W = tensor_dict[tensor_names[0]].shape
|
B, C, T, H, W = tensor_dict[tensor_names[0]].shape
|
||||||
|
if batch_size is not None:
|
||||||
|
B *= batch_size
|
||||||
data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype
|
data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype
|
||||||
value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
|
value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
|
||||||
weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
|
weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
|
||||||
@@ -881,6 +893,7 @@ def model_fn_wan_video(
|
|||||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||||
sliding_window_size: Optional[int] = None,
|
sliding_window_size: Optional[int] = None,
|
||||||
sliding_window_stride: Optional[int] = None,
|
sliding_window_stride: Optional[int] = None,
|
||||||
|
cfg_merge: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if sliding_window_size is not None and sliding_window_stride is not None:
|
if sliding_window_size is not None and sliding_window_stride is not None:
|
||||||
@@ -905,7 +918,8 @@ def model_fn_wan_video(
|
|||||||
sliding_window_size, sliding_window_stride,
|
sliding_window_size, sliding_window_stride,
|
||||||
latents.device, latents.dtype,
|
latents.device, latents.dtype,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
tensor_names=["latents", "y"]
|
tensor_names=["latents", "y"],
|
||||||
|
batch_size=2 if cfg_merge else 1
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
@@ -936,7 +950,9 @@ def model_fn_wan_video(
|
|||||||
|
|
||||||
# Reference image
|
# Reference image
|
||||||
if reference_latents is not None:
|
if reference_latents is not None:
|
||||||
reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
|
if len(reference_latents.shape) == 5:
|
||||||
|
reference_latents = reference_latents[:, :, 0]
|
||||||
|
reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
|
||||||
x = torch.concat([reference_latents, x], dim=1)
|
x = torch.concat([reference_latents, x], dim=1)
|
||||||
f += 1
|
f += 1
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,41 @@ class AutoWrappedModule(torch.nn.Module):
|
|||||||
return module(*args, **kwargs)
|
return module(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class WanAutoCastLayerNorm(torch.nn.LayerNorm):
    """LayerNorm replacement that always normalizes a float32 copy of its input.

    The instance adopts the ``weight`` and ``bias`` of the wrapped module and
    remembers three dtype/device placements: offload, onload and computation.
    ``state`` is 0 after construction and after a transfer done by
    ``offload()``; it is 1 after a transfer done by ``onload()``.
    """

    def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        """Mirror ``module``'s configuration and take over its parameters.

        Args:
            module: The LayerNorm whose shape, eps and parameters are reused.
            offload_dtype, offload_device: Placement applied by ``offload()``.
            onload_dtype, onload_device: Placement applied by ``onload()``.
            computation_dtype, computation_device: Placement the affine
                parameters are cast to in ``forward()`` when it differs from
                the onload placement.
        """
        # NOTE(review): init_weights_on_device is defined elsewhere; presumably
        # it keeps the parent constructor from allocating real parameter
        # storage, since the parameters are replaced right below - confirm.
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(
                module.normalized_shape,
                eps=module.eps,
                elementwise_affine=module.elementwise_affine,
                bias=module.bias is not None,
                dtype=offload_dtype,
                device=offload_device,
            )
        # Share the wrapped module's parameters instead of copying them.
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        """Move the module to the offload placement if it is currently onloaded."""
        if self.state != 1:
            return
        # Nothing to move (and no state change) when both placements coincide.
        if self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device:
            self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        """Move the module to the onload placement if it is currently offloaded."""
        if self.state != 0:
            return
        # Nothing to move (and no state change) when both placements coincide.
        if self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device:
            self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, x, *args, **kwargs):
        """Layer-normalize ``x`` in float32 and return the result in ``x``'s dtype.

        Extra positional and keyword arguments are accepted and ignored.
        """
        target_dtype = self.computation_dtype
        target_device = self.computation_device
        if self.onload_dtype == target_dtype and self.onload_device == target_device:
            # Parameters already sit in the computation placement.
            gamma, beta = self.weight, self.bias
        else:
            # cast_to is defined elsewhere in this module; missing parameters
            # stay None.
            gamma = None if self.weight is None else cast_to(self.weight, target_dtype, target_device)
            beta = None if self.bias is None else cast_to(self.bias, target_dtype, target_device)
        with torch.amp.autocast(device_type=x.device.type):
            normalized = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, gamma, beta, self.eps)
            x = normalized.type_as(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
class AutoWrappedLinear(torch.nn.Linear):
|
class AutoWrappedLinear(torch.nn.Linear):
|
||||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||||
with init_weights_on_device(device=torch.device("meta")):
|
with init_weights_on_device(device=torch.device("meta")):
|
||||||
|
|||||||
34
test.py
34
test.py
@@ -1,7 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
|
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
||||||
from diffsynth import ModelManager, save_video, VideoData, save_frames, save_video, download_models
|
from diffsynth import ModelManager, save_video, VideoData, save_frames, save_video, download_models
|
||||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig, model_fn_wan_video
|
||||||
from diffsynth.controlnets.processors import Annotator
|
from diffsynth.controlnets.processors import Annotator
|
||||||
|
from diffsynth.data.video import crop_and_resize
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -13,28 +15,32 @@ pipe = WanVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
# ModelConfig("D:\projects\VideoX-Fun\models\Wan2.1-Fun-V1.1-1.3B-Control\diffusion_pytorch_model.safetensors", offload_device="cpu"),
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
pipe.enable_vram_management(num_persistent_param_in_dit=10*10**9)
|
pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9)
|
||||||
|
|
||||||
|
|
||||||
video = VideoData(rf"D:\pr_projects\20250503_dance\data\双马尾竖屏暴击!你的微笑就是彩虹的微笑♥ - 1.双马尾竖屏暴击!你的微笑就是彩虹的微笑♥(Av114086629088385,P1).mp4", height=832, width=480)
|
video = VideoData(rf"D:\pr_projects\20250503_dance\data\双马尾竖屏暴击!你的微笑就是彩虹的微笑♥ - 1.双马尾竖屏暴击!你的微笑就是彩虹的微笑♥(Av114086629088385,P1).mp4", height=832, width=480)
|
||||||
annotator = Annotator("openpose")
|
annotator = Annotator("openpose")
|
||||||
video = [video[i] for i in tqdm(range(450, 450+1*17, 1))]
|
video = [video[i] for i in tqdm(range(450, 450+1*81, 1))]
|
||||||
save_video(video, "video_input.mp4", fps=60, quality=5)
|
save_video(video, "video_input.mp4", fps=60, quality=5)
|
||||||
control_video = [annotator(f) for f in tqdm(video)]
|
control_video = [annotator(f) for f in tqdm(video)]
|
||||||
save_video(control_video, "video_control.mp4", fps=60, quality=5)
|
save_video(control_video, "video_control.mp4", fps=60, quality=5)
|
||||||
reference_image = Image.open(rf"D:\pr_projects\20250503_dance\data\marmot.png").resize((480, 832))
|
reference_image = crop_and_resize(Image.open(rf"D:\pr_projects\20250503_dance\data\marmot4.png"), 832, 480)
|
||||||
|
|
||||||
video = pipe(
|
with torch.amp.autocast("cuda", torch.bfloat16):
|
||||||
prompt="微距摄影风格特写画面,一只憨态可掬的土拨鼠正用后腿站立在碎石堆上,它在挥舞着双臂。金棕色的绒毛在阳光下泛着丝绸般的光泽,腹部毛发呈现浅杏色渐变,每根毛尖都闪烁着细密的光晕。两只黑曜石般的眼睛透出机警而温顺的光芒,鼻梁两侧的白色触须微微颤动,捕捉着空气中的气息。背景是虚化的灰绿色渐变,几簇嫩绿苔藓从画面右下角探出头来,与前景散落的鹅卵石形成微妙的景深对比。土拨鼠圆润的身形在逆光中勾勒出柔和的轮廓,耳朵紧贴头部的姿态流露出戒备中的天真,整个画面洋溢着自然界生灵特有的灵动与纯真。",
|
video = pipe(
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
prompt="微距摄影风格特写画面,一只憨态可掬的土拨鼠正用后腿站立在碎石堆上,它在挥舞着双臂。金棕色的绒毛在阳光下泛着丝绸般的光泽,腹部毛发呈现浅杏色渐变,每根毛尖都闪烁着细密的光晕。两只黑曜石般的眼睛透出机警而温顺的光芒,鼻梁两侧的白色触须微微颤动,捕捉着空气中的气息。背景是虚化的灰绿色渐变,几簇嫩绿苔藓从画面右下角探出头来,与前景散落的鹅卵石形成微妙的景深对比。土拨鼠圆润的身形在逆光中勾勒出柔和的轮廓,耳朵紧贴头部的姿态流露出戒备中的天真,整个画面洋溢着自然界生灵特有的灵动与纯真。",
|
||||||
seed=0, tiled=True,
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
height=832, width=480, num_frames=len(control_video),
|
seed=43, tiled=True,
|
||||||
control_video=control_video, reference_image=reference_image,
|
height=832, width=480, num_frames=len(control_video),
|
||||||
# num_inference_steps=30, cfg_scale=1,
|
control_video=control_video, reference_image=reference_image,
|
||||||
)
|
# sliding_window_size=5, sliding_window_stride=2,
|
||||||
save_video(video, "video1.mp4", fps=60, quality=5)
|
# num_inference_steps=100,
|
||||||
|
# cfg_merge=True,
|
||||||
|
sigma_shift=16,
|
||||||
|
)
|
||||||
|
save_video(video, "video1.mp4", fps=60, quality=5)
|
||||||
|
|||||||
Reference in New Issue
Block a user