mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
training framework
@@ -1,34 +1,26 @@
import torch, warnings, glob
import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download

import types
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch, os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional

from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm

@@ -50,6 +42,16 @@ class BasePipeline(torch.nn.Module):
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.vram_management_enabled = False

def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self

def check_resize_height_width(self, height, width, num_frames=None):
@@ -135,8 +137,20 @@ class BasePipeline(torch.nn.Module):

def enable_cpu_offload(self):
warnings.warn("enable_cpu_offload is deprecated. This feature is automatically enabled if offload_device != device")
warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.")

def get_free_vram(self):
total_memory = torch.cuda.get_device_properties(self.device).total_memory
allocated_memory = torch.cuda.device_memory_used(self.device)
return (total_memory - allocated_memory) / (1024 ** 3)

def freeze_except(self, model_names):
for name, model in self.named_children():
if name not in model_names:
model.eval()
model.requires_grad_(False)

@dataclass
@@ -146,17 +160,19 @@ class ModelConfig:
origin_file_pattern: Union[str, list[str]] = None
download_resource: str = "ModelScope"
offload_device: Optional[Union[str, torch.device]] = None
quantization_dtype: Optional[torch.dtype] = None
offload_dtype: Optional[torch.dtype] = None

def download_if_necessary(self, local_model_path="./models", skip_download=False):
if self.path is None:
if self.model_id is None or self.origin_file_pattern is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(local_model_path, self.model_id),
allow_file_pattern=self.origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
@@ -195,10 +211,36 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
]
self.model_fn = model_fn_wan_video

def train(self):
super().train()
self.scheduler.set_timesteps(1000, training=True)

def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)

inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)

noise_pred = self.model_fn(**inputs, timestep=timestep)

loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss

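For orientation, here is a hedged, self-contained sketch of the flow-matching training target that `training_loss` builds on. It only exercises the scheduler API that appears in this diff (`set_timesteps(..., training=True)`, `add_noise`, `training_target`, `training_weight`); the latent shape is illustrative, not prescribed by the commit.

```python
import torch
from diffsynth.schedulers.flow_match import FlowMatchScheduler

scheduler = FlowMatchScheduler()
scheduler.set_timesteps(1000, training=True)           # enables per-timestep loss weights

latents = torch.randn(1, 16, 21, 60, 104)              # clean video latents (illustrative shape)
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, scheduler.num_train_timesteps, (1,))
timestep = scheduler.timesteps[timestep_id]

noisy_latents = scheduler.add_noise(latents, noise, timestep)     # what the DiT sees
target = scheduler.training_target(latents, noise, timestep)      # what it should predict
weight = scheduler.training_weight(timestep)                      # scales the MSE loss
```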
def enable_vram_management(self, num_persistent_param_in_dit=None):
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_free_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
@@ -217,9 +259,11 @@ class WanVideoPipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
@@ -233,7 +277,7 @@ class WanVideoPipeline(BasePipeline):
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
@@ -246,6 +290,7 @@ class WanVideoPipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
dtype = next(iter(self.vae.parameters())).dtype
@@ -304,6 +349,7 @@ class WanVideoPipeline(BasePipeline):
),
)
if self.vace is not None:
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.vace,
module_map = {
@@ -316,10 +362,11 @@ class WanVideoPipeline(BasePipeline):
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)

@@ -330,8 +377,23 @@ class WanVideoPipeline(BasePipeline):
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
local_model_path: str = "./models",
skip_download: bool = False
skip_download: bool = False,
redirect_common_files: bool = True,
):
# Redirect model path
if redirect_common_files:
redirect_dict = {
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
}
for model_config in model_configs:
if model_config.origin_file_pattern is None or model_config.model_id is None:
continue
if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern]

# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
@@ -339,7 +401,7 @@ class WanVideoPipeline(BasePipeline):
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.quantization_dtype or torch_dtype
torch_dtype=model_config.offload_dtype or torch_dtype
)

# Initialize pipeline
@@ -356,63 +418,54 @@ class WanVideoPipeline(BasePipeline):
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
return pipe

def denoising_model(self):
return self.dit

def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents

@torch.no_grad()
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image=None,
input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
end_image=None,
end_image: Optional[Image.Image] = None,
# Video-to-video
input_video=None,
denoising_strength=1.0,
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
# ControlNet
control_video=None,
reference_image=None,
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
# VACE
vace_video=None,
vace_video_mask=None,
vace_reference_image=None,
vace_scale=1.0,
vace_video: Optional[list[Image.Image]] = None,
vace_video_mask: Optional[Image.Image] = None,
vace_reference_image: Optional[Image.Image] = None,
vace_scale: Optional[float] = 1.0,
# Randomness
seed=None,
rand_device="cpu",
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height=480,
width=832,
height: Optional[int] = 480,
width: Optional[int] = 832,
num_frames=81,
# Classifier-free guidance
cfg_scale=5.0,
cfg_merge=False,
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Scheduler
num_inference_steps=50,
sigma_shift=5.0,
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
# Speed control
motion_bucket_id=None,
motion_bucket_id: Optional[int] = None,
# VAE tiling
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Sliding window
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
# Teacache
tea_cache_l1_thresh=None,
tea_cache_model_id="",
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
):
@@ -452,12 +505,12 @@ class WanVideoPipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

# Inference
noise_pred_posi = model_fn_wan_video(**models, **inputs_shared, **inputs_posi, timestep=timestep)
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
if cfg_scale != 1.0:
if cfg_merge:
noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
else:
noise_pred_nega = model_fn_wan_video(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -467,7 +520,7 @@ class WanVideoPipeline(BasePipeline):

# VACE (TODO: remove it)
if vace_reference_image is not None:
latents = latents[:, :, 1:]
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]

# Decode
self.load_models_to_device(['vae'])
@@ -558,18 +611,21 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
onload_model_names=("vae",)
)

def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(input_video)
latents = pipe.encode_video(input_video, tiled, tile_size, tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
latents = pipe.scheduler.add_noise(latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}

@@ -639,7 +695,7 @@ class WanVideoUnit_FunControl(PipelineUnit):
return {}
pipe.load_models_to_device(self.onload_model_names)
control_video = pipe.preprocess_video(control_video)
control_latents = pipe.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
@@ -678,7 +734,7 @@ class WanVideoUnit_SpeedControl(PipelineUnit):
def process(self, pipe: WanVideoPipeline, motion_bucket_id):
if motion_bucket_id is None:
return {}
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_bucket_id": motion_bucket_id}

@@ -703,18 +759,16 @@ class WanVideoUnit_VACE(PipelineUnit):
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
else:
vace_video = pipe.preprocess_video(vace_video)
vace_video = torch.stack(vace_video, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)

if vace_mask is None:
vace_mask = torch.ones_like(vace_video)
else:
vace_mask = pipe.preprocess_video(vace_mask)
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)

inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
inactive = pipe.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_video_latents = torch.concat((inactive, reactive), dim=1)

vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
@@ -724,8 +778,7 @@ class WanVideoUnit_VACE(PipelineUnit):
pass
else:
vace_reference_image = pipe.preprocess_video([vace_reference_image])
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = pipe.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
@@ -894,6 +947,7 @@ def model_fn_wan_video(
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
cfg_merge: bool = False,
use_gradient_checkpointing: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -978,8 +1032,20 @@ def model_fn_wan_video(
if tea_cache_update:
x = tea_cache.update(x)
else:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward

for block_id, block in enumerate(dit.blocks):
x = block(x, context, t_mod, freqs)
if use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
if vace_context is not None and block_id in vace.vace_layers_mapping:
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
if tea_cache is not None:

@@ -35,6 +35,9 @@ class FlowMatchScheduler():
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
self.training = True
else:
self.training = False

def step(self, model_output, timestep, sample, to_final=False, **kwargs):

diffsynth/trainers/utils.py (new file, 190 lines)
@@ -0,0 +1,190 @@
import imageio, os, torch, warnings, torchvision
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator


class VideoDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path, metadata_path,
frame_interval=1, num_frames=81,
dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
data_file_keys=("video",),
image_file_extension=("jpg", "jpeg", "png", "webp"),
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
repeat=1,
):
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

self.base_path = base_path
self.frame_interval = frame_interval
self.num_frames = num_frames
self.dynamic_resolution = dynamic_resolution
self.max_pixels = max_pixels
self.height = height
self.width = width
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.data_file_keys = data_file_keys
self.image_file_extension = image_file_extension
self.video_file_extension = video_file_extension
self.repeat = repeat

if height is not None and width is not None and dynamic_resolution == True:
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
self.dynamic_resolution = False

def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image

def get_height_width(self, image):
if self.dynamic_resolution:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width

def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
reader = imageio.get_reader(file_path)
if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
frames.append(frame)
reader.close()
return frames

def load_image(self, file_path):
image = Image.open(file_path).convert("RGB")
image = self.crop_and_resize(image, *self.get_height_width(image))
return image

def load_video(self, file_path):
frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
return frames

def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
return file_ext_name.lower() in self.image_file_extension

def is_video(self, file_path):
file_ext_name = file_path.split(".")[-1]
return file_ext_name.lower() in self.video_file_extension

def load_data(self, file_path):
if self.is_image(file_path):
return self.load_image(file_path)
elif self.is_video(file_path):
return self.load_video(file_path)
else:
return None

def __getitem__(self, data_id):
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
path = os.path.join(self.base_path, data[key])
data[key] = self.load_data(path)
if data[key] is None:
warnings.warn(f"cannot load file {data[key]}.")
return None
return data

def __len__(self):
return len(self.data) * self.repeat


class DiffusionTrainingModule(torch.nn.Module):
def __init__(self):
super().__init__()

def to(self, *args, **kwargs):
for name, model in self.named_children():
model.to(*args, **kwargs)
return self

def trainable_modules(self):
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
return trainable_modules

def trainable_param_names(self):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
return trainable_param_names

def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
if lora_alpha is None:
lora_alpha = lora_rank
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
model = inject_adapter_in_model(lora_config, model)
return model


def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate, num_epochs, output_path, remove_prefix=None):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

accelerator = Accelerator(gradient_accumulation_steps=1)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

for epoch in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(data)
accelerator.backward(loss)
optimizer.step()
scheduler.step()
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
trainable_param_names = model.trainable_param_names()
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
if remove_prefix is not None:
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith(remove_prefix):
name = name[len(remove_prefix):]
state_dict_[name] = param
path = os.path.join(output_path, f"epoch-{epoch}")
accelerator.save(state_dict_, path, safe_serialization=True)
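The pieces above are meant to be combined: `VideoDataset` yields dict samples, a `DiffusionTrainingModule` subclass turns a sample into a scalar loss, and `launch_training_task` runs the Accelerate loop and saves the trainable parameters each epoch. A minimal sketch of that wiring follows; `ToyTrainingModule` and the data paths are hypothetical (a real module would wrap `WanVideoPipeline` and call its `training_loss`), they are not part of this commit.

```python
import os
import torch
import torchvision.transforms.functional as TF
from diffsynth.trainers.utils import VideoDataset, DiffusionTrainingModule, launch_training_task

class ToyTrainingModule(DiffusionTrainingModule):
    # Illustrative only: regress the mean pixel value so the sketch stays self-contained.
    def __init__(self):
        super().__init__()
        self.head = torch.nn.Linear(1, 1)

    def forward(self, data):
        frames = data["video"]                       # list of PIL frames from VideoDataset
        pixels = torch.stack([TF.to_tensor(frame) for frame in frames])
        mean_value = pixels.mean().reshape(1, 1)
        return (self.head(mean_value) ** 2).mean()   # scalar loss

# Assumed layout: metadata.csv has a "video" column with paths relative to base_path.
dataset = VideoDataset(
    base_path="data/example_dataset",
    metadata_path="data/example_dataset/metadata.csv",
    num_frames=81, height=480, width=832,
)
model = ToyTrainingModule()
os.makedirs("./ckpt", exist_ok=True)
# Pass an empty prefix so the checkpoint keeps full parameter names.
launch_training_task(model, dataset, learning_rate=1e-4, num_epochs=1, output_path="./ckpt", remove_prefix="")
```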
@@ -8,8 +8,32 @@ def cast_to(weight, dtype, device):
return r

class AutoWrappedModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class AutoTorchModule(torch.nn.Module):
def __init__(self):
super().__init__()

def check_free_vram(self):
used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
return used_memory < self.vram_limit

def offload(self):
if self.state != 0:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0

def onload(self):
if self.state != 1:
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1

def keep(self):
if self.state != 2:
self.to(dtype=self.computation_dtype, device=self.computation_device)
self.state = 2

class AutoWrappedModule(AutoTorchModule):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
super().__init__()
self.module = module.to(dtype=offload_dtype, device=offload_device)
self.offload_dtype = offload_dtype
@@ -18,28 +42,25 @@ class AutoWrappedModule(torch.nn.Module):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0

def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0

def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1

def forward(self, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
module = self.module
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
return module(*args, **kwargs)

class WanAutoCastLayerNorm(torch.nn.LayerNorm):
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
@@ -50,31 +71,28 @@ class WanAutoCastLayerNorm(torch.nn.LayerNorm):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0

def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0

def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1

def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
weight, bias = self.weight, self.bias
else:
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
weight, bias = self.weight, self.bias
else:
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
with torch.amp.autocast(device_type=x.device.type):
x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
return x

class AutoWrappedLinear(torch.nn.Linear):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
@@ -85,28 +103,25 @@ class AutoWrappedLinear(torch.nn.Linear):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0

def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0

def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1

def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
return torch.nn.functional.linear(x, weight, bias)

def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None):
for name, module in model.named_children():
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
@@ -115,16 +130,16 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict,
module_config_ = overflow_module_config
else:
module_config_ = module_config
module_ = target_module(module, **module_config_)
module_ = target_module(module, **module_config_, vram_limit=vram_limit)
setattr(model, name, module_)
total_num_param += num_param
break
else:
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit)
return total_num_param

def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit)
model.vram_management_enabled = True

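A hedged standalone sketch of what the new `vram_limit` plumbing does outside the Wan pipeline: Linear layers are parked on the CPU in their offload dtype and, at forward time, are either kept on the GPU (while measured usage stays under the limit, in GiB) or cast per call. The toy model, dtypes, and the 6 GiB budget are assumptions for illustration; the pipelines above derive the budget from `get_free_vram()`. Requires a CUDA device.

```python
import torch
from diffsynth.vram_management import enable_vram_management, AutoWrappedLinear

# Toy model in bfloat16 so offload/onload/computation dtypes agree.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
).to(torch.bfloat16)

enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear},
    module_config=dict(
        offload_dtype=torch.bfloat16, offload_device="cpu",
        onload_dtype=torch.bfloat16, onload_device="cpu",
        computation_dtype=torch.bfloat16, computation_device="cuda",
    ),
    vram_limit=6.0,  # assumed budget in GiB
)

with torch.no_grad():
    y = model(torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"))
```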
examples/wanvideo/model_inference/wan_1.3b_speed_control.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"),
],
)
pipe.enable_vram_management()

# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=1, tiled=True,
motion_bucket_id=0
)
save_video(video, "video_slow.mp4", fps=15, quality=5)

video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=1, tiled=True,
motion_bucket_id=100
)
save_video(video, "video_fast.mp4", fps=15, quality=5)
examples/wanvideo/model_inference/wan_1.3b_text_to_video.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=0, tiled=True,
)
save_video(video, "video1.mp4", fps=15, quality=5)

# Video-to-video
video = VideoData("video1.mp4", height=480, width=832)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
input_video=video, denoising_strength=0.7,
seed=1, tiled=True
)
save_video(video, "video2.mp4", fps=15, quality=5)
examples/wanvideo/model_inference/wan_1.3b_vace.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
)

# Depth video -> Video
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
vace_video=control_video,
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

# Reference image -> Video
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
seed=1, tiled=True
)
save_video(video, "video2.mp4", fps=15, quality=5)

# Depth video + Reference image -> Video
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
vace_video=control_video,
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
seed=1, tiled=True
)
save_video(video, "video3.mp4", fps=15, quality=5)
examples/wanvideo/model_inference/wan_14b_flf2v.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"]
)

# First and last frame to video
video = pipe(
prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)),
end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)),
seed=0, tiled=True,
height=960, width=960, num_frames=33,
sigma_shift=16,
)
save_video(video, "video.mp4", fps=15, quality=5)
@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")

# Image-to-video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
input_image=image,
seed=0, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)
@@ -0,0 +1,35 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")

# Image-to-video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
input_image=image,
seed=0, tiled=True,
height=720, width=1280,
)
save_video(video, "video.mp4", fps=15, quality=5)
examples/wanvideo/model_inference/wan_14b_text_to_video.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

# Text-to-video
video = pipe(
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=0, tiled=True,
)
save_video(video, "video1.mp4", fps=15, quality=5)
examples/wanvideo/model_inference/wan_fun_1.3b_InP.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")

# First and last frame to video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
input_image=image,
seed=0, tiled=True
# You can input `end_image=xxx` to control the last frame of the video.
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
)
save_video(video, "video.mp4", fps=15, quality=5)
examples/wanvideo/model_inference/wan_fun_1.3b_control.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/control_video.mp4"
)

# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
control_video=control_video, height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)
36
examples/wanvideo/model_inference/wan_fun_14b_InP.py
Normal file
36
examples/wanvideo/model_inference/wan_fun_14b_InP.py
Normal file
@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")

# First and last frame to video
video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    input_image=image,
    # You can pass `end_image=xxx` to control the last frame of the video.
    # The model will automatically generate the dynamic content between `input_image` and `end_image`.
    seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
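The 14B checkpoint is much heavier than the 1.3B one. If plain `enable_vram_management()` still exhausts GPU memory, the test.py removed later in this commit passes an explicit budget for how many DiT parameters stay resident on the GPU; the same call, as a sketch to tune for your hardware:

# Keep roughly 6 billion DiT parameters on the GPU and stream the rest from CPU memory.
# The 6*10**9 figure comes from the removed test.py; adjust it to your GPU.
pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9)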
34  examples/wanvideo/model_inference/wan_fun_14b_control.py  Normal file
@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/control_video.mp4"
)

# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    control_video=control_video, height=832, width=576, num_frames=49,
    seed=1, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
)

# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
video = pipe(
    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    control_video=control_video, reference_image=reference_image,
    height=832, width=576, num_frames=49,
    seed=1, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
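The reference image above is stretched to 576x832 with a plain resize, which distorts it when the aspect ratio differs. The test.py removed later in this commit uses the repository's crop_and_resize helper instead, which crops to the target aspect ratio before resizing; a sketch of the same idea (the height-then-width argument order follows that usage and is an assumption here):

from diffsynth.data.video import crop_and_resize

# Crop, then resize the reference portrait to the 832x576 output resolution
# instead of stretching it with PIL's plain resize.
reference_image = crop_and_resize(Image.open("data/examples/wan/reference_image_girl.png"), 832, 576)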
@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
)

# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
video = pipe(
    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    control_video=control_video, reference_image=reference_image,
    height=832, width=576, num_frames=49,
    seed=1, tiled=True,
)
save_video(video, "video1.mp4", fps=15, quality=5)
@@ -1,7 +1,7 @@
torch>=2.0.0
torchvision
cupy-cuda12x
transformers==4.46.2
transformers
controlnet-aux==0.0.7
imageio
imageio[ffmpeg]
@@ -11,3 +11,4 @@ sentencepiece
protobuf
modelscope
ftfy
pynvml
46  test.py
@@ -1,46 +0,0 @@
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)
from diffsynth import ModelManager, save_video, VideoData, save_frames, save_video, download_models
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig, model_fn_wan_video
from diffsynth.controlnets.processors import Annotator
from diffsynth.data.video import crop_and_resize
from modelscope import snapshot_download
from tqdm import tqdm
from PIL import Image


# Load models
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        # ModelConfig("D:\projects\VideoX-Fun\models\Wan2.1-Fun-V1.1-1.3B-Control\diffusion_pytorch_model.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9)

video = VideoData(rf"D:\pr_projects\20250503_dance\data\双马尾竖屏暴击!你的微笑就是彩虹的微笑♥ - 1.双马尾竖屏暴击!你的微笑就是彩虹的微笑♥(Av114086629088385,P1).mp4", height=832, width=480)
annotator = Annotator("openpose")
video = [video[i] for i in tqdm(range(450, 450+1*81, 1))]
save_video(video, "video_input.mp4", fps=60, quality=5)
control_video = [annotator(f) for f in tqdm(video)]
save_video(control_video, "video_control.mp4", fps=60, quality=5)
reference_image = crop_and_resize(Image.open(rf"D:\pr_projects\20250503_dance\data\marmot4.png"), 832, 480)

with torch.amp.autocast("cuda", torch.bfloat16):
    video = pipe(
        prompt="微距摄影风格特写画面,一只憨态可掬的土拨鼠正用后腿站立在碎石堆上,它在挥舞着双臂。金棕色的绒毛在阳光下泛着丝绸般的光泽,腹部毛发呈现浅杏色渐变,每根毛尖都闪烁着细密的光晕。两只黑曜石般的眼睛透出机警而温顺的光芒,鼻梁两侧的白色触须微微颤动,捕捉着空气中的气息。背景是虚化的灰绿色渐变,几簇嫩绿苔藓从画面右下角探出头来,与前景散落的鹅卵石形成微妙的景深对比。土拨鼠圆润的身形在逆光中勾勒出柔和的轮廓,耳朵紧贴头部的姿态流露出戒备中的天真,整个画面洋溢着自然界生灵特有的灵动与纯真。",
        negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
        seed=43, tiled=True,
        height=832, width=480, num_frames=len(control_video),
        control_video=control_video, reference_image=reference_image,
        # sliding_window_size=5, sliding_window_stride=2,
        # num_inference_steps=100,
        # cfg_merge=True,
        sigma_shift=16,
    )
save_video(video, "video1.mp4", fps=60, quality=5)
75  train.py  Normal file
@@ -0,0 +1,75 @@
import torch, os
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class WanTrainingModule(DiffusionTrainingModule):
    def __init__(self, model_paths):
        super().__init__()
        self.pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device="cpu",
            model_configs=[ModelConfig(path=path) for path in model_paths],
        )
        # Freeze every loaded model, then attach trainable LoRA layers to the DiT's
        # attention and feed-forward projections.
        self.pipe.freeze_except([])
        self.pipe.dit = self.add_lora_to_model(self.pipe.dit, target_modules="q,k,v,o,ffn.0,ffn.2".split(","), lora_alpha=16)

    def forward_preprocess(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}
        inputs_shared = {
            "input_video": data["video"],
            "height": data["video"][0].size[1],
            "width": data["video"][0].size[0],
            "num_frames": len(data["video"]),
            # Please do not modify the following parameters.
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": True,
            "cfg_merge": False,
        }
        # Run each of the pipeline's preprocessing units to turn the raw sample into model inputs.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        return {**inputs_shared, **inputs_posi}

    def forward(self, data):
        inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(**models, **inputs)
        return loss


def add_general_parsers(parser):
    parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.")
    parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the dataset.")
    parser.add_argument("--height", type=int, default=None, help="Image or video height. Leave `height` and `width` None to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Image or video width. Leave `height` and `width` None to enable dynamic resolution.")
    parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.")
    parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.")
    parser.add_argument("--model_paths", type=str, default="", help="Model paths to be loaded. Separated by commas.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    return parser


if __name__ == "__main__":
    dataset = VideoDataset(
        base_path="data/pixabay100/train",
        metadata_path="data/pixabay100/metadata_example.csv",
        height=480, width=832,
        data_file_keys=["video"],
        repeat=400,
    )
    model = WanTrainingModule([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    launch_training_task(model, dataset)
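forward_preprocess reads data["prompt"] and data["video"], so the metadata CSV handed to VideoDataset is expected to pair each training clip with a caption. A minimal sketch of such a metadata_example.csv (the column names are an assumption inferred from those keys, not copied from the repository):

video,prompt
videos/clip_001.mp4,"A small boat sailing through stormy waves"
videos/clip_002.mp4,"A long-haired girl dancing in a flat anime style"

With the Wan2.1-T2V-1.3B weights placed under models/Wan-AI/Wan2.1-T2V-1.3B and the metadata in place, the script can be run directly (e.g. `python train.py`), assuming launch_training_task drives the training loop on the current device.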