From 675eefa07e28f3ab16121d59b2dcead1c67b879e Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 12 May 2025 17:48:28 +0800 Subject: [PATCH] training framework --- diffsynth/pipelines/wan_video_new.py | 222 ++++++++++++------ diffsynth/schedulers/flow_match.py | 3 + diffsynth/trainers/utils.py | 190 +++++++++++++++ diffsynth/vram_management/layers.py | 113 +++++---- .../model_inference/wan_1.3b_speed_control.py | 34 +++ .../model_inference/wan_1.3b_text_to_video.py | 34 +++ .../wanvideo/model_inference/wan_1.3b_vace.py | 52 ++++ .../wanvideo/model_inference/wan_14b_flf2v.py | 36 +++ .../wan_14b_image_to_video_480p.py | 34 +++ .../wan_14b_image_to_video_720p.py | 35 +++ .../model_inference/wan_14b_text_to_video.py | 24 ++ .../model_inference/wan_fun_1.3b_InP.py | 36 +++ .../model_inference/wan_fun_1.3b_control.py | 34 +++ .../model_inference/wan_fun_14b_InP.py | 36 +++ .../model_inference/wan_fun_14b_control.py | 34 +++ .../wan_fun_v1.1_1.3b_reference_control.py | 36 +++ .../wan_fun_v1.1_14b_reference_control.py | 36 +++ requirements.txt | 3 +- test.py | 46 ---- train.py | 75 ++++++ 20 files changed, 939 insertions(+), 174 deletions(-) create mode 100644 diffsynth/trainers/utils.py create mode 100644 examples/wanvideo/model_inference/wan_1.3b_speed_control.py create mode 100644 examples/wanvideo/model_inference/wan_1.3b_text_to_video.py create mode 100644 examples/wanvideo/model_inference/wan_1.3b_vace.py create mode 100644 examples/wanvideo/model_inference/wan_14b_flf2v.py create mode 100644 examples/wanvideo/model_inference/wan_14b_image_to_video_480p.py create mode 100644 examples/wanvideo/model_inference/wan_14b_image_to_video_720p.py create mode 100644 examples/wanvideo/model_inference/wan_14b_text_to_video.py create mode 100644 examples/wanvideo/model_inference/wan_fun_1.3b_InP.py create mode 100644 examples/wanvideo/model_inference/wan_fun_1.3b_control.py create mode 100644 examples/wanvideo/model_inference/wan_fun_14b_InP.py create mode 100644 examples/wanvideo/model_inference/wan_fun_14b_control.py create mode 100644 examples/wanvideo/model_inference/wan_fun_v1.1_1.3b_reference_control.py create mode 100644 examples/wanvideo/model_inference/wan_fun_v1.1_14b_reference_control.py delete mode 100644 test.py create mode 100644 train.py diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index de05e50..6110cc2 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1,34 +1,26 @@ -import torch, warnings, glob +import torch, warnings, glob, os import numpy as np from PIL import Image from einops import repeat, reduce from typing import Optional, Union from dataclasses import dataclass from modelscope import snapshot_download - - -import types -from ..models import ModelManager -from ..models.wan_video_dit import WanModel -from ..models.wan_video_text_encoder import WanTextEncoder -from ..models.wan_video_vae import WanVideoVAE -from ..models.wan_video_image_encoder import WanImageEncoder -from ..models.wan_video_vace import VaceWanModel -from ..schedulers.flow_match import FlowMatchScheduler -from .base import BasePipeline -from ..prompters import WanPrompter -import torch, os from einops import rearrange import numpy as np from PIL import Image from tqdm import tqdm from typing import Optional -from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm -from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm -from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d -from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample +from ..models import ModelManager +from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d +from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm +from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample +from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vace import VaceWanModel from ..models.wan_video_motion_controller import WanMotionControllerModel +from ..schedulers.flow_match import FlowMatchScheduler +from ..prompters import WanPrompter +from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm @@ -50,6 +42,16 @@ class BasePipeline(torch.nn.Module): self.time_division_factor = time_division_factor self.time_division_remainder = time_division_remainder self.vram_management_enabled = False + + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self.device = device + if dtype is not None: + self.torch_dtype = dtype + super().to(*args, **kwargs) + return self def check_resize_height_width(self, height, width, num_frames=None): @@ -135,8 +137,20 @@ class BasePipeline(torch.nn.Module): def enable_cpu_offload(self): - warnings.warn("enable_cpu_offload is deprecated. This feature is automatically enabled if offload_device != device") - + warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.") + + + def get_free_vram(self): + total_memory = torch.cuda.get_device_properties(self.device).total_memory + allocated_memory = torch.cuda.device_memory_used(self.device) + return (total_memory - allocated_memory) / (1024 ** 3) + + + def freeze_except(self, model_names): + for name, model in self.named_children(): + if name not in model_names: + model.eval() + model.requires_grad_(False) @dataclass @@ -146,17 +160,19 @@ class ModelConfig: origin_file_pattern: Union[str, list[str]] = None download_resource: str = "ModelScope" offload_device: Optional[Union[str, torch.device]] = None - quantization_dtype: Optional[torch.dtype] = None + offload_dtype: Optional[torch.dtype] = None def download_if_necessary(self, local_model_path="./models", skip_download=False): if self.path is None: if self.model_id is None or self.origin_file_pattern is None: raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") if not skip_download: + downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) snapshot_download( self.model_id, local_dir=os.path.join(local_model_path, self.model_id), allow_file_pattern=self.origin_file_pattern, + ignore_file_pattern=downloaded_files, local_files_only=False ) self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) @@ -195,10 +211,36 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_TeaCache(), WanVideoUnit_CfgMerger(), ] + self.model_fn = model_fn_wan_video + + + def train(self): + super().train() + self.scheduler.set_timesteps(1000, training=True) + + + def training_loss(self, **inputs): + timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) + timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) + + inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) + training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep) + + noise_pred = self.model_fn(**inputs, timestep=timestep) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.scheduler.training_weight(timestep) + return loss - def enable_vram_management(self, num_persistent_param_in_dit=None): + def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): self.vram_management_enabled = True + if num_persistent_param_in_dit is not None: + vram_limit = None + else: + if vram_limit is None: + vram_limit = self.get_free_vram() + vram_limit = vram_limit - vram_buffer if self.text_encoder is not None: dtype = next(iter(self.text_encoder.parameters())).dtype enable_vram_management( @@ -217,9 +259,11 @@ class WanVideoPipeline(BasePipeline): computation_dtype=self.torch_dtype, computation_device=self.device, ), + vram_limit=vram_limit, ) if self.dit is not None: dtype = next(iter(self.dit.parameters())).dtype + device = "cpu" if vram_limit is not None else self.device enable_vram_management( self.dit, module_map = { @@ -233,7 +277,7 @@ class WanVideoPipeline(BasePipeline): offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, - onload_device=self.device, + onload_device=device, computation_dtype=self.torch_dtype, computation_device=self.device, ), @@ -246,6 +290,7 @@ class WanVideoPipeline(BasePipeline): computation_dtype=self.torch_dtype, computation_device=self.device, ), + vram_limit=vram_limit, ) if self.vae is not None: dtype = next(iter(self.vae.parameters())).dtype @@ -304,6 +349,7 @@ class WanVideoPipeline(BasePipeline): ), ) if self.vace is not None: + device = "cpu" if vram_limit is not None else self.device enable_vram_management( self.vace, module_map = { @@ -316,10 +362,11 @@ class WanVideoPipeline(BasePipeline): offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, - onload_device=self.device, + onload_device=device, computation_dtype=self.torch_dtype, computation_device=self.device, ), + vram_limit=vram_limit, ) @@ -330,8 +377,23 @@ class WanVideoPipeline(BasePipeline): model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), local_model_path: str = "./models", - skip_download: bool = False + skip_download: bool = False, + redirect_common_files: bool = True, ): + # Redirect model path + if redirect_common_files: + redirect_dict = { + "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B", + "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B", + "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P", + } + for model_config in model_configs: + if model_config.origin_file_pattern is None or model_config.model_id is None: + continue + if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]: + print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern] + # Download and load models model_manager = ModelManager() for model_config in model_configs: @@ -339,7 +401,7 @@ class WanVideoPipeline(BasePipeline): model_manager.load_model( model_config.path, device=model_config.offload_device or device, - torch_dtype=model_config.quantization_dtype or torch_dtype + torch_dtype=model_config.offload_dtype or torch_dtype ) # Initialize pipeline @@ -356,63 +418,54 @@ class WanVideoPipeline(BasePipeline): pipe.prompter.fetch_models(pipe.text_encoder) pipe.prompter.fetch_tokenizer(tokenizer_config.path) return pipe - - - def denoising_model(self): - return self.dit - - - def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): - latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - return latents @torch.no_grad() def __call__( self, # Prompt - prompt, - negative_prompt="", + prompt: str, + negative_prompt: Optional[str] = "", # Image-to-video - input_image=None, + input_image: Optional[Image.Image] = None, # First-last-frame-to-video - end_image=None, + end_image: Optional[Image.Image] = None, # Video-to-video - input_video=None, - denoising_strength=1.0, + input_video: Optional[list[Image.Image]] = None, + denoising_strength: Optional[float] = 1.0, # ControlNet - control_video=None, - reference_image=None, + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, # VACE - vace_video=None, - vace_video_mask=None, - vace_reference_image=None, - vace_scale=1.0, + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + vace_scale: Optional[float] = 1.0, # Randomness - seed=None, - rand_device="cpu", + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", # Shape - height=480, - width=832, + height: Optional[int] = 480, + width: Optional[int] = 832, num_frames=81, # Classifier-free guidance - cfg_scale=5.0, - cfg_merge=False, + cfg_scale: Optional[float] = 5.0, + cfg_merge: Optional[bool] = False, # Scheduler - num_inference_steps=50, - sigma_shift=5.0, + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, # Speed control - motion_bucket_id=None, + motion_bucket_id: Optional[int] = None, # VAE tiling - tiled=True, - tile_size=(30, 52), - tile_stride=(15, 26), + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), # Sliding window sliding_window_size: Optional[int] = None, sliding_window_stride: Optional[int] = None, # Teacache - tea_cache_l1_thresh=None, - tea_cache_model_id="", + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", # progress_bar progress_bar_cmd=tqdm, ): @@ -452,12 +505,12 @@ class WanVideoPipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) # Inference - noise_pred_posi = model_fn_wan_video(**models, **inputs_shared, **inputs_posi, timestep=timestep) + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) if cfg_scale != 1.0: if cfg_merge: noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) else: - noise_pred_nega = model_fn_wan_video(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: noise_pred = noise_pred_posi @@ -467,7 +520,7 @@ class WanVideoPipeline(BasePipeline): # VACE (TODO: remove it) if vace_reference_image is not None: - latents = latents[:, :, 1:] + inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] # Decode self.load_models_to_device(['vae']) @@ -558,18 +611,21 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit): class WanVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride"), + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"), onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride): + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength): if input_video is None: return {"latents": noise} pipe.load_models_to_device(["vae"]) input_video = pipe.preprocess_video(input_video) - latents = pipe.encode_video(input_video, tiled, tile_size, tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - latents = pipe.scheduler.add_noise(latents, noise, timestep=pipe.scheduler.timesteps[0]) - return {"latents": latents} + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} @@ -639,7 +695,7 @@ class WanVideoUnit_FunControl(PipelineUnit): return {} pipe.load_models_to_device(self.onload_model_names) control_video = pipe.preprocess_video(control_video) - control_latents = pipe.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) if clip_feature is None or y is None: clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) @@ -678,7 +734,7 @@ class WanVideoUnit_SpeedControl(PipelineUnit): def process(self, pipe: WanVideoPipeline, motion_bucket_id): if motion_bucket_id is None: return {} - motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device) + motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) return {"motion_bucket_id": motion_bucket_id} @@ -703,18 +759,16 @@ class WanVideoUnit_VACE(PipelineUnit): vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) else: vace_video = pipe.preprocess_video(vace_video) - vace_video = torch.stack(vace_video, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device) if vace_mask is None: vace_mask = torch.ones_like(vace_video) else: vace_mask = pipe.preprocess_video(vace_mask) - vace_mask = torch.stack(vace_mask, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device) inactive = vace_video * (1 - vace_mask) + 0 * vace_mask reactive = vace_video * vace_mask + 0 * (1 - vace_mask) - inactive = pipe.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - reactive = pipe.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) vace_video_latents = torch.concat((inactive, reactive), dim=1) vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) @@ -724,8 +778,7 @@ class WanVideoUnit_VACE(PipelineUnit): pass else: vace_reference_image = pipe.preprocess_video([vace_reference_image]) - vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device) - vace_reference_latents = pipe.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) @@ -894,6 +947,7 @@ def model_fn_wan_video( sliding_window_size: Optional[int] = None, sliding_window_stride: Optional[int] = None, cfg_merge: bool = False, + use_gradient_checkpointing: bool = False, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is not None: @@ -978,8 +1032,20 @@ def model_fn_wan_video( if tea_cache_update: x = tea_cache.update(x) else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + for block_id, block in enumerate(dit.blocks): - x = block(x, context, t_mod, freqs) + if use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) if vace_context is not None and block_id in vace.vace_layers_mapping: x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale if tea_cache is not None: diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index d6d0219..9754b98 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -35,6 +35,9 @@ class FlowMatchScheduler(): y_shifted = y - y.min() bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) self.linear_timesteps_weights = bsmntw_weighing + self.training = True + else: + self.training = False def step(self, model_output, timestep, sample, to_final=False, **kwargs): diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py new file mode 100644 index 0000000..7bc0f15 --- /dev/null +++ b/diffsynth/trainers/utils.py @@ -0,0 +1,190 @@ +import imageio, os, torch, warnings, torchvision +from peft import LoraConfig, inject_adapter_in_model +from PIL import Image +import pandas as pd +from tqdm import tqdm +from accelerate import Accelerator + + + +class VideoDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path, metadata_path, + frame_interval=1, num_frames=81, + dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + data_file_keys=("video",), + image_file_extension=("jpg", "jpeg", "png", "webp"), + video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), + repeat=1, + ): + metadata = pd.read_csv(metadata_path) + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + + self.base_path = base_path + self.frame_interval = frame_interval + self.num_frames = num_frames + self.dynamic_resolution = dynamic_resolution + self.max_pixels = max_pixels + self.height = height + self.width = width + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.data_file_keys = data_file_keys + self.image_file_extension = image_file_extension + self.video_file_extension = video_file_extension + self.repeat = repeat + + if height is not None and width is not None and dynamic_resolution == True: + print("Height and width are fixed. Setting `dynamic_resolution` to False.") + self.dynamic_resolution = False + + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + + def get_height_width(self, image): + if self.dynamic_resolution: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + + def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames): + reader = imageio.get_reader(file_path) + if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + frames = [] + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = Image.fromarray(frame) + frame = self.crop_and_resize(frame, *self.get_height_width(frame)) + frames.append(frame) + reader.close() + return frames + + + def load_image(self, file_path): + image = Image.open(file_path).convert("RGB") + image = self.crop_and_resize(image, *self.get_height_width(image)) + return image + + + def load_video(self, file_path): + frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames) + return frames + + + def is_image(self, file_path): + file_ext_name = file_path.split(".")[-1] + return file_ext_name.lower() in self.image_file_extension + + + def is_video(self, file_path): + file_ext_name = file_path.split(".")[-1] + return file_ext_name.lower() in self.video_file_extension + + + def load_data(self, file_path): + if self.is_image(file_path): + return self.load_image(file_path) + elif self.is_video(file_path): + return self.load_video(file_path) + else: + return None + + + def __getitem__(self, data_id): + data = self.data[data_id % len(self.data)].copy() + for key in self.data_file_keys: + if key in data: + path = os.path.join(self.base_path, data[key]) + data[key] = self.load_data(path) + if data[key] is None: + warnings.warn(f"cannot load file {data[key]}.") + return None + return data + + + def __len__(self): + return len(self.data) * self.repeat + + + +class DiffusionTrainingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + + def to(self, *args, **kwargs): + for name, model in self.named_children(): + model.to(*args, **kwargs) + return self + + + def trainable_modules(self): + trainable_modules = filter(lambda p: p.requires_grad, self.parameters()) + return trainable_modules + + + def trainable_param_names(self): + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + return trainable_param_names + + + def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None): + if lora_alpha is None: + lora_alpha = lora_rank + lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) + model = inject_adapter_in_model(lora_config, model) + return model + + + +def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate, num_epochs, output_path, remove_prefix=None): + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0]) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) + + accelerator = Accelerator(gradient_accumulation_steps=1) + model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + + for epoch in range(num_epochs): + for data in tqdm(dataloader): + with accelerator.accumulate(model): + optimizer.zero_grad() + loss = model(data) + accelerator.backward(loss) + optimizer.step() + scheduler.step() + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + trainable_param_names = model.trainable_param_names() + state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names} + if remove_prefix is not None: + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith(remove_prefix): + name = name[len(remove_prefix):] + state_dict_[name] = param + path = os.path.join(output_path, f"epoch-{epoch}") + accelerator.save(state_dict_, path, safe_serialization=True) diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index aa2bda2..45e7433 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -8,8 +8,32 @@ def cast_to(weight, dtype, device): return r -class AutoWrappedModule(torch.nn.Module): - def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): +class AutoTorchModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def check_free_vram(self): + used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3) + return used_memory < self.vram_limit + + def offload(self): + if self.state != 0: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state != 1: + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def keep(self): + if self.state != 2: + self.to(dtype=self.computation_dtype, device=self.computation_device) + self.state = 2 + + +class AutoWrappedModule(AutoTorchModule): + def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit): super().__init__() self.module = module.to(dtype=offload_dtype, device=offload_device) self.offload_dtype = offload_dtype @@ -18,28 +42,25 @@ class AutoWrappedModule(torch.nn.Module): self.onload_device = onload_device self.computation_dtype = computation_dtype self.computation_device = computation_device + self.vram_limit = vram_limit self.state = 0 - def offload(self): - if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.module.to(dtype=self.offload_dtype, device=self.offload_device) - self.state = 0 - - def onload(self): - if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.module.to(dtype=self.onload_dtype, device=self.onload_device) - self.state = 1 - def forward(self, *args, **kwargs): - if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + if self.state == 2: module = self.module else: - module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + module = self.module + elif self.vram_limit is not None and self.check_free_vram(): + self.keep() + module = self.module + else: + module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) return module(*args, **kwargs) -class WanAutoCastLayerNorm(torch.nn.LayerNorm): - def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): +class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule): + def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit): with init_weights_on_device(device=torch.device("meta")): super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) self.weight = module.weight @@ -50,31 +71,28 @@ class WanAutoCastLayerNorm(torch.nn.LayerNorm): self.onload_device = onload_device self.computation_dtype = computation_dtype self.computation_device = computation_device + self.vram_limit = vram_limit self.state = 0 - def offload(self): - if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.to(dtype=self.offload_dtype, device=self.offload_device) - self.state = 0 - - def onload(self): - if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.to(dtype=self.onload_dtype, device=self.onload_device) - self.state = 1 - def forward(self, x, *args, **kwargs): - if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + if self.state == 2: weight, bias = self.weight, self.bias else: - weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device) - bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + weight, bias = self.weight, self.bias + elif self.vram_limit is not None and self.check_free_vram(): + self.keep() + weight, bias = self.weight, self.bias + else: + weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) with torch.amp.autocast(device_type=x.device.type): x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x) return x -class AutoWrappedLinear(torch.nn.Linear): - def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): +class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): + def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit): with init_weights_on_device(device=torch.device("meta")): super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) self.weight = module.weight @@ -85,28 +103,25 @@ class AutoWrappedLinear(torch.nn.Linear): self.onload_device = onload_device self.computation_dtype = computation_dtype self.computation_device = computation_device + self.vram_limit = vram_limit self.state = 0 - def offload(self): - if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.to(dtype=self.offload_dtype, device=self.offload_device) - self.state = 0 - - def onload(self): - if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): - self.to(dtype=self.onload_dtype, device=self.onload_device) - self.state = 1 - def forward(self, x, *args, **kwargs): - if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + if self.state == 2: weight, bias = self.weight, self.bias else: - weight = cast_to(self.weight, self.computation_dtype, self.computation_device) - bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + weight, bias = self.weight, self.bias + elif self.vram_limit is not None and self.check_free_vram(): + self.keep() + weight, bias = self.weight, self.bias + else: + weight = cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) return torch.nn.functional.linear(x, weight, bias) -def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): +def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None): for name, module in model.named_children(): for source_module, target_module in module_map.items(): if isinstance(module, source_module): @@ -115,16 +130,16 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config_ = overflow_module_config else: module_config_ = module_config - module_ = target_module(module, **module_config_) + module_ = target_module(module, **module_config_, vram_limit=vram_limit) setattr(model, name, module_) total_num_param += num_param break else: - total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param) + total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit) return total_num_param -def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None): - enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0) +def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None): + enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit) model.vram_management_enabled = True diff --git a/examples/wanvideo/model_inference/wan_1.3b_speed_control.py b/examples/wanvideo/model_inference/wan_1.3b_speed_control.py new file mode 100644 index 0000000..6efdc65 --- /dev/null +++ b/examples/wanvideo/model_inference/wan_1.3b_speed_control.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=0 +) +save_video(video, "video_slow.mp4", fps=15, quality=5) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=100 +) +save_video(video, "video_fast.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_1.3b_text_to_video.py b/examples/wanvideo/model_inference/wan_1.3b_text_to_video.py new file mode 100644 index 0000000..83e300b --- /dev/null +++ b/examples/wanvideo/model_inference/wan_1.3b_text_to_video.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5) + +# Video-to-video +video = VideoData("video1.mp4", height=480, width=832) +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_video=video, denoising_strength=0.7, + seed=1, tiled=True +) +save_video(video, "video2.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_1.3b_vace.py b/examples/wanvideo/model_inference/wan_1.3b_vace.py new file mode 100644 index 0000000..99c0242 --- /dev/null +++ b/examples/wanvideo/model_inference/wan_1.3b_vace.py @@ -0,0 +1,52 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video1.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video2.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video3.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_14b_flf2v.py b/examples/wanvideo/model_inference/wan_14b_flf2v.py new file mode 100644 index 0000000..3061398 --- /dev/null +++ b/examples/wanvideo/model_inference/wan_14b_flf2v.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] +) + +# First and last frame to video +video = pipe( + prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), + end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), + seed=0, tiled=True, + height=960, width=960, num_frames=33, + sigma_shift=16, +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_14b_image_to_video_480p.py b/examples/wanvideo/model_inference/wan_14b_image_to_video_480p.py new file mode 100644 index 0000000..eb2e5b0 --- /dev/null +++ b/examples/wanvideo/model_inference/wan_14b_image_to_video_480p.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_14b_image_to_video_720p.py b/examples/wanvideo/model_inference/wan_14b_image_to_video_720p.py new file mode 100644 index 0000000..fb14d24 --- /dev/null +++ b/examples/wanvideo/model_inference/wan_14b_image_to_video_720p.py @@ -0,0 +1,35 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True, + height=720, width=1280, +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_14b_text_to_video.py b/examples/wanvideo/model_inference/wan_14b_text_to_video.py new file mode 100644 index 0000000..40cb02d --- /dev/null +++ b/examples/wanvideo/model_inference/wan_14b_text_to_video.py @@ -0,0 +1,24 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_fun_1.3b_InP.py b/examples/wanvideo/model_inference/wan_fun_1.3b_InP.py new file mode 100644 index 0000000..d921c0c --- /dev/null +++ b/examples/wanvideo/model_inference/wan_fun_1.3b_InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_fun_1.3b_control.py b/examples/wanvideo/model_inference/wan_fun_1.3b_control.py new file mode 100644 index 0000000..43374d2 --- /dev/null +++ b/examples/wanvideo/model_inference/wan_fun_1.3b_control.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_fun_14b_InP.py b/examples/wanvideo/model_inference/wan_fun_14b_InP.py new file mode 100644 index 0000000..af227cb --- /dev/null +++ b/examples/wanvideo/model_inference/wan_fun_14b_InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_fun_14b_control.py b/examples/wanvideo/model_inference/wan_fun_14b_control.py new file mode 100644 index 0000000..db9e5c8 --- /dev/null +++ b/examples/wanvideo/model_inference/wan_fun_14b_control.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_fun_v1.1_1.3b_reference_control.py b/examples/wanvideo/model_inference/wan_fun_v1.1_1.3b_reference_control.py new file mode 100644 index 0000000..0f7e4c8 --- /dev/null +++ b/examples/wanvideo/model_inference/wan_fun_v1.1_1.3b_reference_control.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/wan_fun_v1.1_14b_reference_control.py b/examples/wanvideo/model_inference/wan_fun_v1.1_14b_reference_control.py new file mode 100644 index 0000000..78635ff --- /dev/null +++ b/examples/wanvideo/model_inference/wan_fun_v1.1_14b_reference_control.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/requirements.txt b/requirements.txt index 63a871b..92d8b48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch>=2.0.0 torchvision cupy-cuda12x -transformers==4.46.2 +transformers controlnet-aux==0.0.7 imageio imageio[ffmpeg] @@ -11,3 +11,4 @@ sentencepiece protobuf modelscope ftfy +pynvml diff --git a/test.py b/test.py deleted file mode 100644 index f7959ee..0000000 --- a/test.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch -torch.cuda.set_per_process_memory_fraction(0.999, 0) -from diffsynth import ModelManager, save_video, VideoData, save_frames, save_video, download_models -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig, model_fn_wan_video -from diffsynth.controlnets.processors import Annotator -from diffsynth.data.video import crop_and_resize -from modelscope import snapshot_download -from tqdm import tqdm -from PIL import Image - - -# Load models -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - # ModelConfig("D:\projects\VideoX-Fun\models\Wan2.1-Fun-V1.1-1.3B-Control\diffusion_pytorch_model.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), - ], -) -pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) - -video = VideoData(rf"D:\pr_projects\20250503_dance\data\双马尾竖屏暴击!你的微笑就是彩虹的微笑♥ - 1.双马尾竖屏暴击!你的微笑就是彩虹的微笑♥(Av114086629088385,P1).mp4", height=832, width=480) -annotator = Annotator("openpose") -video = [video[i] for i in tqdm(range(450, 450+1*81, 1))] -save_video(video, "video_input.mp4", fps=60, quality=5) -control_video = [annotator(f) for f in tqdm(video)] -save_video(control_video, "video_control.mp4", fps=60, quality=5) -reference_image = crop_and_resize(Image.open(rf"D:\pr_projects\20250503_dance\data\marmot4.png"), 832, 480) - -with torch.amp.autocast("cuda", torch.bfloat16): - video = pipe( - prompt="微距摄影风格特写画面,一只憨态可掬的土拨鼠正用后腿站立在碎石堆上,它在挥舞着双臂。金棕色的绒毛在阳光下泛着丝绸般的光泽,腹部毛发呈现浅杏色渐变,每根毛尖都闪烁着细密的光晕。两只黑曜石般的眼睛透出机警而温顺的光芒,鼻梁两侧的白色触须微微颤动,捕捉着空气中的气息。背景是虚化的灰绿色渐变,几簇嫩绿苔藓从画面右下角探出头来,与前景散落的鹅卵石形成微妙的景深对比。土拨鼠圆润的身形在逆光中勾勒出柔和的轮廓,耳朵紧贴头部的姿态流露出戒备中的天真,整个画面洋溢着自然界生灵特有的灵动与纯真。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=43, tiled=True, - height=832, width=480, num_frames=len(control_video), - control_video=control_video, reference_image=reference_image, - # sliding_window_size=5, sliding_window_stride=2, - # num_inference_steps=100, - # cfg_merge=True, - sigma_shift=16, - ) - save_video(video, "video1.mp4", fps=60, quality=5) diff --git a/train.py b/train.py new file mode 100644 index 0000000..de8afd1 --- /dev/null +++ b/train.py @@ -0,0 +1,75 @@ +import torch, os +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + + +class WanTrainingModule(DiffusionTrainingModule): + def __init__(self, model_paths): + super().__init__() + self.pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cpu", + model_configs=[ModelConfig(path=path) for path in model_paths], + ) + self.pipe.freeze_except([]) + self.pipe.dit = self.add_lora_to_model(self.pipe.dit, target_modules="q,k,v,o,ffn.0,ffn.2".split(","), lora_alpha=16) + + + def forward_preprocess(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {} + inputs_shared = { + "input_video": data["video"], + "height": data["video"][0].size[1], + "width": data["video"][0].size[0], + "num_frames": len(data["video"]), + # Please do not modify the following parameters. + "cfg_scale": 1, + "tiled": False, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": True, + "cfg_merge": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + return {**inputs_shared, **inputs_posi} + + + def forward(self, data): + inputs = self.forward_preprocess(data) + models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} + loss = self.pipe.training_loss(**models, **inputs) + return loss + + + +def add_general_parsers(parser): + parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the Dataset.") + parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the Dataset.") + parser.add_argument("--height", type=int, default=None, help="Image or video height. Leave `height` and `width` None to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Image or video width. Leave `height` and `width` None to enable dynamic resolution.") + parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.") + parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.") + parser.add_argument("--model_paths", type=str, default="", help="Model paths to be loaded. Separated by commas.") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") + return parser + + +if __name__ == "__main__": + dataset = VideoDataset( + base_path="data/pixabay100/train", + metadata_path="data/pixabay100/metadata_example.csv", + height=480, width=832, + data_file_keys=["video"], + repeat=400, + ) + model = WanTrainingModule([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + launch_training_task(model, dataset) +