training framework

This commit is contained in:
Artiprocher
2025-05-12 17:48:28 +08:00
parent dbef6122e9
commit 675eefa07e
20 changed files with 939 additions and 174 deletions

View File

@@ -1,34 +1,26 @@
import torch, warnings, glob
import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
import types
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch, os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
@@ -50,6 +42,16 @@ class BasePipeline(torch.nn.Module):
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.vram_management_enabled = False
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None):
@@ -135,8 +137,20 @@ class BasePipeline(torch.nn.Module):
def enable_cpu_offload(self):
warnings.warn("enable_cpu_offload is deprecated. This feature is automatically enabled if offload_device != device")
warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.")
def get_free_vram(self):
total_memory = torch.cuda.get_device_properties(self.device).total_memory
allocated_memory = torch.cuda.device_memory_used(self.device)
return (total_memory - allocated_memory) / (1024 ** 3)
def freeze_except(self, model_names):
for name, model in self.named_children():
if name not in model_names:
model.eval()
model.requires_grad_(False)
@dataclass
@@ -146,17 +160,19 @@ class ModelConfig:
origin_file_pattern: Union[str, list[str]] = None
download_resource: str = "ModelScope"
offload_device: Optional[Union[str, torch.device]] = None
quantization_dtype: Optional[torch.dtype] = None
offload_dtype: Optional[torch.dtype] = None
def download_if_necessary(self, local_model_path="./models", skip_download=False):
if self.path is None:
if self.model_id is None or self.origin_file_pattern is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(local_model_path, self.model_id),
allow_file_pattern=self.origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
@@ -195,10 +211,36 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
]
self.model_fn = model_fn_wan_video
def train(self):
super().train()
self.scheduler.set_timesteps(1000, training=True)
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss
def enable_vram_management(self, num_persistent_param_in_dit=None):
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_free_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
@@ -217,9 +259,11 @@ class WanVideoPipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
@@ -233,7 +277,7 @@ class WanVideoPipeline(BasePipeline):
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
@@ -246,6 +290,7 @@ class WanVideoPipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
dtype = next(iter(self.vae.parameters())).dtype
@@ -304,6 +349,7 @@ class WanVideoPipeline(BasePipeline):
),
)
if self.vace is not None:
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.vace,
module_map = {
@@ -316,10 +362,11 @@ class WanVideoPipeline(BasePipeline):
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@@ -330,8 +377,23 @@ class WanVideoPipeline(BasePipeline):
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
local_model_path: str = "./models",
skip_download: bool = False
skip_download: bool = False,
redirect_common_files: bool = True,
):
# Redirect model path
if redirect_common_files:
redirect_dict = {
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
}
for model_config in model_configs:
if model_config.origin_file_pattern is None or model_config.model_id is None:
continue
if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
@@ -339,7 +401,7 @@ class WanVideoPipeline(BasePipeline):
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.quantization_dtype or torch_dtype
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
@@ -356,63 +418,54 @@ class WanVideoPipeline(BasePipeline):
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
return pipe
def denoising_model(self):
return self.dit
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
@torch.no_grad()
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image=None,
input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
end_image=None,
end_image: Optional[Image.Image] = None,
# Video-to-video
input_video=None,
denoising_strength=1.0,
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
# ControlNet
control_video=None,
reference_image=None,
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
# VACE
vace_video=None,
vace_video_mask=None,
vace_reference_image=None,
vace_scale=1.0,
vace_video: Optional[list[Image.Image]] = None,
vace_video_mask: Optional[Image.Image] = None,
vace_reference_image: Optional[Image.Image] = None,
vace_scale: Optional[float] = 1.0,
# Randomness
seed=None,
rand_device="cpu",
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height=480,
width=832,
height: Optional[int] = 480,
width: Optional[int] = 832,
num_frames=81,
# Classifier-free guidance
cfg_scale=5.0,
cfg_merge=False,
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Scheduler
num_inference_steps=50,
sigma_shift=5.0,
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
# Speed control
motion_bucket_id=None,
motion_bucket_id: Optional[int] = None,
# VAE tiling
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Sliding window
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
# Teacache
tea_cache_l1_thresh=None,
tea_cache_model_id="",
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
):
@@ -452,12 +505,12 @@ class WanVideoPipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = model_fn_wan_video(**models, **inputs_shared, **inputs_posi, timestep=timestep)
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
if cfg_scale != 1.0:
if cfg_merge:
noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
else:
noise_pred_nega = model_fn_wan_video(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -467,7 +520,7 @@ class WanVideoPipeline(BasePipeline):
# VACE (TODO: remove it)
if vace_reference_image is not None:
latents = latents[:, :, 1:]
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
# Decode
self.load_models_to_device(['vae'])
@@ -558,18 +611,21 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(input_video)
latents = pipe.encode_video(input_video, tiled, tile_size, tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
latents = pipe.scheduler.add_noise(latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
@@ -639,7 +695,7 @@ class WanVideoUnit_FunControl(PipelineUnit):
return {}
pipe.load_models_to_device(self.onload_model_names)
control_video = pipe.preprocess_video(control_video)
control_latents = pipe.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
@@ -678,7 +734,7 @@ class WanVideoUnit_SpeedControl(PipelineUnit):
def process(self, pipe: WanVideoPipeline, motion_bucket_id):
if motion_bucket_id is None:
return {}
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_bucket_id": motion_bucket_id}
@@ -703,18 +759,16 @@ class WanVideoUnit_VACE(PipelineUnit):
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
else:
vace_video = pipe.preprocess_video(vace_video)
vace_video = torch.stack(vace_video, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_mask is None:
vace_mask = torch.ones_like(vace_video)
else:
vace_mask = pipe.preprocess_video(vace_mask)
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
inactive = pipe.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_video_latents = torch.concat((inactive, reactive), dim=1)
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
@@ -724,8 +778,7 @@ class WanVideoUnit_VACE(PipelineUnit):
pass
else:
vace_reference_image = pipe.preprocess_video([vace_reference_image])
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = pipe.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
@@ -894,6 +947,7 @@ def model_fn_wan_video(
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
cfg_merge: bool = False,
use_gradient_checkpointing: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -978,8 +1032,20 @@ def model_fn_wan_video(
if tea_cache_update:
x = tea_cache.update(x)
else:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(dit.blocks):
x = block(x, context, t_mod, freqs)
if use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
if vace_context is not None and block_id in vace.vace_layers_mapping:
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
if tea_cache is not None:

View File

@@ -35,6 +35,9 @@ class FlowMatchScheduler():
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
self.training = True
else:
self.training = False
def step(self, model_output, timestep, sample, to_final=False, **kwargs):

190
diffsynth/trainers/utils.py Normal file
View File

@@ -0,0 +1,190 @@
import imageio, os, torch, warnings, torchvision
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator
class VideoDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path, metadata_path,
frame_interval=1, num_frames=81,
dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
data_file_keys=("video",),
image_file_extension=("jpg", "jpeg", "png", "webp"),
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
repeat=1,
):
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
self.base_path = base_path
self.frame_interval = frame_interval
self.num_frames = num_frames
self.dynamic_resolution = dynamic_resolution
self.max_pixels = max_pixels
self.height = height
self.width = width
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.data_file_keys = data_file_keys
self.image_file_extension = image_file_extension
self.video_file_extension = video_file_extension
self.repeat = repeat
if height is not None and width is not None and dynamic_resolution == True:
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
self.dynamic_resolution = False
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.dynamic_resolution:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
reader = imageio.get_reader(file_path)
if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
frames.append(frame)
reader.close()
return frames
def load_image(self, file_path):
image = Image.open(file_path).convert("RGB")
image = self.crop_and_resize(image, *self.get_height_width(image))
return image
def load_video(self, file_path):
frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
return frames
def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
return file_ext_name.lower() in self.image_file_extension
def is_video(self, file_path):
file_ext_name = file_path.split(".")[-1]
return file_ext_name.lower() in self.video_file_extension
def load_data(self, file_path):
if self.is_image(file_path):
return self.load_image(file_path)
elif self.is_video(file_path):
return self.load_video(file_path)
else:
return None
def __getitem__(self, data_id):
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
path = os.path.join(self.base_path, data[key])
data[key] = self.load_data(path)
if data[key] is None:
warnings.warn(f"cannot load file {data[key]}.")
return None
return data
def __len__(self):
return len(self.data) * self.repeat
class DiffusionTrainingModule(torch.nn.Module):
def __init__(self):
super().__init__()
def to(self, *args, **kwargs):
for name, model in self.named_children():
model.to(*args, **kwargs)
return self
def trainable_modules(self):
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
return trainable_modules
def trainable_param_names(self):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
return trainable_param_names
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
if lora_alpha is None:
lora_alpha = lora_rank
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
model = inject_adapter_in_model(lora_config, model)
return model
def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate, num_epochs, output_path, remove_prefix=None):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
accelerator = Accelerator(gradient_accumulation_steps=1)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(data)
accelerator.backward(loss)
optimizer.step()
scheduler.step()
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
trainable_param_names = model.trainable_param_names()
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
if remove_prefix is not None:
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith(remove_prefix):
name = name[len(remove_prefix):]
state_dict_[name] = param
path = os.path.join(output_path, f"epoch-{epoch}")
accelerator.save(state_dict_, path, safe_serialization=True)

View File

@@ -8,8 +8,32 @@ def cast_to(weight, dtype, device):
return r
class AutoWrappedModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class AutoTorchModule(torch.nn.Module):
def __init__(self):
super().__init__()
def check_free_vram(self):
used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
return used_memory < self.vram_limit
def offload(self):
if self.state != 0:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state != 1:
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def keep(self):
if self.state != 2:
self.to(dtype=self.computation_dtype, device=self.computation_device)
self.state = 2
class AutoWrappedModule(AutoTorchModule):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
super().__init__()
self.module = module.to(dtype=offload_dtype, device=offload_device)
self.offload_dtype = offload_dtype
@@ -18,28 +42,25 @@ class AutoWrappedModule(torch.nn.Module):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
module = self.module
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
return module(*args, **kwargs)
class WanAutoCastLayerNorm(torch.nn.LayerNorm):
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
@@ -50,31 +71,28 @@ class WanAutoCastLayerNorm(torch.nn.LayerNorm):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
weight, bias = self.weight, self.bias
else:
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
weight, bias = self.weight, self.bias
else:
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
with torch.amp.autocast(device_type=x.device.type):
x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
return x
class AutoWrappedLinear(torch.nn.Linear):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
@@ -85,28 +103,25 @@ class AutoWrappedLinear(torch.nn.Linear):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
return torch.nn.functional.linear(x, weight, bias)
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None):
for name, module in model.named_children():
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
@@ -115,16 +130,16 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict,
module_config_ = overflow_module_config
else:
module_config_ = module_config
module_ = target_module(module, **module_config_)
module_ = target_module(module, **module_config_, vram_limit=vram_limit)
setattr(model, name, module_)
total_num_param += num_param
break
else:
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit)
return total_num_param
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit)
model.vram_management_enabled = True

View File

@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"),
],
)
pipe.enable_vram_management()
# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=1, tiled=True,
motion_bucket_id=0
)
save_video(video, "video_slow.mp4", fps=15, quality=5)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=1, tiled=True,
motion_bucket_id=100
)
save_video(video, "video_fast.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video1.mp4", fps=15, quality=5)
# Video-to-video
video = VideoData("video1.mp4", height=480, width=832)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_video=video, denoising_strength=0.7,
seed=1, tiled=True
)
save_video(video, "video2.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,52 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
)
# Depth video -> Video
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
vace_video=control_video,
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)
# Reference image -> Video
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
seed=1, tiled=True
)
save_video(video, "video2.mp4", fps=15, quality=5)
# Depth video + Reference image -> Video
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
vace_video=control_video,
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
seed=1, tiled=True
)
save_video(video, "video3.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"]
)
# First and last frame to video
video = pipe(
prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)),
end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)),
seed=0, tiled=True,
height=960, width=960, num_frames=33,
sigma_shift=16,
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# Image-to-video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=image,
seed=0, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,35 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# Image-to-video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=image,
seed=0, tiled=True,
height=720, width=1280,
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,24 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
# Text-to-video
video = pipe(
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# First and last frame to video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=image,
seed=0, tiled=True
# You can input `end_image=xxx` to control the last frame of the video.
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/control_video.mp4"
)
# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
prompt="扁平风格动漫一位长发少女优雅起舞。她五官精致大眼睛明亮有神黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
control_video=control_video, height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# First and last frame to video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=image,
seed=0, tiled=True
# You can input `end_image=xxx` to control the last frame of the video.
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,34 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/control_video.mp4"
)
# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
prompt="扁平风格动漫一位长发少女优雅起舞。她五官精致大眼睛明亮有神黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
control_video=control_video, height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
)
# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
video = pipe(
prompt="扁平风格动漫一位长发少女优雅起舞。她五官精致大眼睛明亮有神黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
control_video=control_video, reference_image=reference_image,
height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,36 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
)
# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
video = pipe(
prompt="扁平风格动漫一位长发少女优雅起舞。她五官精致大眼睛明亮有神黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
control_video=control_video, reference_image=reference_image,
height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -1,7 +1,7 @@
torch>=2.0.0
torchvision
cupy-cuda12x
transformers==4.46.2
transformers
controlnet-aux==0.0.7
imageio
imageio[ffmpeg]
@@ -11,3 +11,4 @@ sentencepiece
protobuf
modelscope
ftfy
pynvml

46
test.py
View File

@@ -1,46 +0,0 @@
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)
from diffsynth import ModelManager, save_video, VideoData, save_frames, save_video, download_models
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig, model_fn_wan_video
from diffsynth.controlnets.processors import Annotator
from diffsynth.data.video import crop_and_resize
from modelscope import snapshot_download
from tqdm import tqdm
from PIL import Image
# Load models
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
# ModelConfig("D:\projects\VideoX-Fun\models\Wan2.1-Fun-V1.1-1.3B-Control\diffusion_pytorch_model.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9)
video = VideoData(rf"D:\pr_projects\20250503_dance\data\双马尾竖屏暴击!你的微笑就是彩虹的微笑♥ - 1.双马尾竖屏暴击!你的微笑就是彩虹的微笑♥(Av114086629088385,P1).mp4", height=832, width=480)
annotator = Annotator("openpose")
video = [video[i] for i in tqdm(range(450, 450+1*81, 1))]
save_video(video, "video_input.mp4", fps=60, quality=5)
control_video = [annotator(f) for f in tqdm(video)]
save_video(control_video, "video_control.mp4", fps=60, quality=5)
reference_image = crop_and_resize(Image.open(rf"D:\pr_projects\20250503_dance\data\marmot4.png"), 832, 480)
with torch.amp.autocast("cuda", torch.bfloat16):
video = pipe(
prompt="微距摄影风格特写画面,一只憨态可掬的土拨鼠正用后腿站立在碎石堆上,它在挥舞着双臂。金棕色的绒毛在阳光下泛着丝绸般的光泽,腹部毛发呈现浅杏色渐变,每根毛尖都闪烁着细密的光晕。两只黑曜石般的眼睛透出机警而温顺的光芒,鼻梁两侧的白色触须微微颤动,捕捉着空气中的气息。背景是虚化的灰绿色渐变,几簇嫩绿苔藓从画面右下角探出头来,与前景散落的鹅卵石形成微妙的景深对比。土拨鼠圆润的身形在逆光中勾勒出柔和的轮廓,耳朵紧贴头部的姿态流露出戒备中的天真,整个画面洋溢着自然界生灵特有的灵动与纯真。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=43, tiled=True,
height=832, width=480, num_frames=len(control_video),
control_video=control_video, reference_image=reference_image,
# sliding_window_size=5, sliding_window_stride=2,
# num_inference_steps=100,
# cfg_merge=True,
sigma_shift=16,
)
save_video(video, "video1.mp4", fps=60, quality=5)

75
train.py Normal file
View File

@@ -0,0 +1,75 @@
import torch, os
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class WanTrainingModule(DiffusionTrainingModule):
def __init__(self, model_paths):
super().__init__()
self.pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cpu",
model_configs=[ModelConfig(path=path) for path in model_paths],
)
self.pipe.freeze_except([])
self.pipe.dit = self.add_lora_to_model(self.pipe.dit, target_modules="q,k,v,o,ffn.0,ffn.2".split(","), lora_alpha=16)
def forward_preprocess(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {}
inputs_shared = {
"input_video": data["video"],
"height": data["video"][0].size[1],
"width": data["video"][0].size[0],
"num_frames": len(data["video"]),
# Please do not modify the following parameters.
"cfg_scale": 1,
"tiled": False,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": True,
"cfg_merge": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
return {**inputs_shared, **inputs_posi}
def forward(self, data):
inputs = self.forward_preprocess(data)
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
loss = self.pipe.training_loss(**models, **inputs)
return loss
def add_general_parsers(parser):
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the Dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the Dataset.")
parser.add_argument("--height", type=int, default=None, help="Image or video height. Leave `height` and `width` None to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Image or video width. Leave `height` and `width` None to enable dynamic resolution.")
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.")
parser.add_argument("--model_paths", type=str, default="", help="Model paths to be loaded. Separated by commas.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
return parser
if __name__ == "__main__":
dataset = VideoDataset(
base_path="data/pixabay100/train",
metadata_path="data/pixabay100/metadata_example.csv",
height=480, width=832,
data_file_keys=["video"],
repeat=400,
)
model = WanTrainingModule([
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
])
launch_training_task(model, dataset)