From 9894e27af82a03ac10e932c9e05c79cb939016a0 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Fri, 21 Jun 2024 11:29:17 +0800
Subject: [PATCH] ExVideo training

---
 README.md                                   |  40 +-
 examples/ExVideo/ExVideo_ema.py             |  64 ++++
 .../{ExVideo_svd.py => ExVideo_svd_test.py} |  26 +-
 examples/ExVideo/ExVideo_svd_train.py       | 362 ++++++++++++++++++
 examples/ExVideo/README.md                  |  69 +++-
 5 files changed, 545 insertions(+), 16 deletions(-)
 create mode 100644 examples/ExVideo/ExVideo_ema.py
 rename examples/ExVideo/{ExVideo_svd.py => ExVideo_svd_test.py} (52%)
 create mode 100644 examples/ExVideo/ExVideo_svd_train.py

diff --git a/README.md b/README.md
index de0d6d7..12365a5 100644
--- a/README.md
+++ b/README.md
@@ -6,18 +6,18 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu

## Roadmap

-* Aug 29, 2023. I propose DiffSynth, a video synthesis framework.
+* Aug 29, 2023. We propose DiffSynth, a video synthesis framework.
    * [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
    * The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
    * The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
-* Oct 1, 2023. I release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
+* Oct 1, 2023. We release an early version of this project, namely FastSDXL. An attempt at building a diffusion engine.
    * The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
    * FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
        * The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
        * The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
        * A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
        * Since OLSS requires additional training, we don't implement it in this project.
-* Nov 15, 2023. I propose FastBlend, a powerful video deflickering algorithm.
+* Nov 15, 2023. We propose FastBlend, a powerful video deflickering algorithm.
    * The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
    * Demo videos are shown on Bilibili, including three tasks.
        * [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
@@ -25,11 +25,17 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
        * [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
    * The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
    * An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
-* Dec 8, 2023. I decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis.
-* Jan 29, 2024. I propose Diffutoon, a fantastic solution for toon shading.
+* Dec 8, 2023. We decide to develop a new project, aiming to unlock the potential of diffusion models, especially in video synthesis. The development of this project starts here.
+* Jan 29, 2024. We propose Diffutoon, a fantastic solution for toon shading.
    * [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
    * The source codes are released in this project.
    * The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
+* June 13, 2024. DiffSynth Studio is transferred to ModelScope. The project now speaks as "we" instead of "I", and the original author will continue to participate in development and maintenance.
+* June 21, 2024. We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to generate long videos of up to 128 frames.
+    * [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/).
+    * Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
+    * Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
+    * The technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
* Until now, DiffSynth Studio has supported the following models:
    * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
    * [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -39,6 +45,8 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
    * [ESRGAN](https://github.com/xinntao/ESRGAN)
    * [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
    * [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
+    * [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
+    * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)

## Installation

@@ -56,18 +64,16 @@ Enter the Python environment:
conda activate DiffSynthStudio
```

-## Usage (in WebUI)
-
-```
-python -m streamlit run DiffSynth_Studio.py
-```
-
-https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
-
## Usage (in Python code)

The Python examples are in [`examples`](./examples/). We provide an overview here.

+### Long Video Synthesis
+
+We trained an extended video generation model that can generate up to 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
+
+https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
+
### Image Synthesis

Generate high-resolution images, by breaking the limitation of diffusion models!
[`examples/image_synthesis`](./examples/image_synthesis/) @@ -109,3 +115,11 @@ Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山 |Without LoRA|With LoRA| |-|-| |![image_without_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1aa21de5-a992-4b66-b14f-caa44e08876e)|![image_with_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/83a0a41a-691f-4610-8e7b-d8e17c50a282)| + +## Usage (in WebUI) + +``` +python -m streamlit run DiffSynth_Studio.py +``` + +https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954 diff --git a/examples/ExVideo/ExVideo_ema.py b/examples/ExVideo/ExVideo_ema.py new file mode 100644 index 0000000..81d4c39 --- /dev/null +++ b/examples/ExVideo/ExVideo_ema.py @@ -0,0 +1,64 @@ +import torch, os, argparse +from safetensors.torch import save_file + + +def load_pl_state_dict(file_path): + print(f"loading {file_path}") + state_dict = torch.load(file_path, map_location="cpu") + trainable_param_names = set(state_dict["trainable_param_names"]) + if "module" in state_dict: + state_dict = state_dict["module"] + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith("_forward_module."): + name = name[len("_forward_module."):] + if name.startswith("unet."): + name = name[len("unet."):] + if name in trainable_param_names: + state_dict_[name] = param + return state_dict_ + + +def ckpt_to_epochs(ckpt_name): + return int(ckpt_name.split("=")[1].split("-")[0]) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--output_path", + type=str, + default="./", + help="Path to save the model.", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.9, + help="Gamma in EMA.", + ) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + # args + args = parse_args() + folder = args.output_path + gamma = args.gamma + + # EMA + ckpt_list = sorted([(ckpt_to_epochs(ckpt_name), ckpt_name) for ckpt_name in os.listdir(folder) if os.path.isdir(f"{folder}/{ckpt_name}")]) + state_dict_ema = None + for epochs, ckpt_name in ckpt_list: + state_dict = load_pl_state_dict(f"{folder}/{ckpt_name}/checkpoint/mp_rank_00_model_states.pt") + if state_dict_ema is None: + state_dict_ema = {name: param.float() for name, param in state_dict.items()} + else: + for name, param in state_dict.items(): + state_dict_ema[name] = state_dict_ema[name] * gamma + param.float() * (1 - gamma) + save_path = ckpt_name.replace(".ckpt", "-ema.safetensors") + print(f"save to {folder}/{save_path}") + save_file(state_dict_ema, f"{folder}/{save_path}") diff --git a/examples/ExVideo/ExVideo_svd.py b/examples/ExVideo/ExVideo_svd_test.py similarity index 52% rename from examples/ExVideo/ExVideo_svd.py rename to examples/ExVideo/ExVideo_svd_test.py index cfe854a..25afd5a 100644 --- a/examples/ExVideo/ExVideo_svd.py +++ b/examples/ExVideo/ExVideo_svd_test.py @@ -3,6 +3,30 @@ from diffsynth import ModelManager import torch, os +# Download models (from Huggingface) +# Text-to-image model: +# `models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/clip_text_encoder/pytorch_model.bin) +# `models/HunyuanDiT/t2i/mt5/pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/mt5/pytorch_model.bin) +# `models/HunyuanDiT/t2i/model/pytorch_model_ema.pt`: 
[link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/model/pytorch_model_ema.pt) +# `models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin) +# Stable Video Diffusion model: +# `models/stable_video_diffusion/svd_xt.safetensors`: [link](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/svd_xt.safetensors) +# ExVideo extension blocks: +# `models/stable_video_diffusion/model.fp16.safetensors`: [link](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1/resolve/main/model.fp16.safetensors) + + +# Download models (from Modelscope) +# Text-to-image model: +# `models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin`: [link](https://www.modelscope.cn/api/v1/models/modelscope/HunyuanDiT/repo?Revision=master&FilePath=t2i%2Fclip_text_encoder%2Fpytorch_model.bin) +# `models/HunyuanDiT/t2i/mt5/pytorch_model.bin`: [link](https://www.modelscope.cn/api/v1/models/modelscope/HunyuanDiT/repo?Revision=master&FilePath=t2i%2Fmt5%2Fpytorch_model.bin) +# `models/HunyuanDiT/t2i/model/pytorch_model_ema.pt`: [link](https://www.modelscope.cn/api/v1/models/modelscope/HunyuanDiT/repo?Revision=master&FilePath=t2i%2Fmodel%2Fpytorch_model_ema.pt) +# `models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin`: [link](https://www.modelscope.cn/api/v1/models/modelscope/HunyuanDiT/repo?Revision=master&FilePath=t2i%2Fsdxl-vae-fp16-fix%2Fdiffusion_pytorch_model.bin) +# Stable Video Diffusion model: +# `models/stable_video_diffusion/svd_xt.safetensors`: [link](https://www.modelscope.cn/api/v1/models/AI-ModelScope/stable-video-diffusion-img2vid-xt/repo?Revision=master&FilePath=svd_xt.safetensors) +# ExVideo extension blocks: +# `models/stable_video_diffusion/model.fp16.safetensors`: [link](https://modelscope.cn/api/v1/models/ECNU-CILab/ExVideo-SVD-128f-v1/repo?Revision=master&FilePath=model.fp16.safetensors) + + def generate_image(): # Load models os.environ["TOKENIZERS_PARALLELISM"] = "True" @@ -51,7 +75,7 @@ def upscale_video(image, video): model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") model_manager.load_models([ "models/stable_video_diffusion/svd_xt.safetensors", - "models/stable_video_diffusion/model.fp16.safetensors" + "models/stable_video_diffusion/model.fp16.safetensors", ]) pipe = SVDVideoPipeline.from_model_manager(model_manager) diff --git a/examples/ExVideo/ExVideo_svd_train.py b/examples/ExVideo/ExVideo_svd_train.py new file mode 100644 index 0000000..b3cb72c --- /dev/null +++ b/examples/ExVideo/ExVideo_svd_train.py @@ -0,0 +1,362 @@ +import torch, json, os, imageio, argparse +from torchvision.transforms import v2 +import numpy as np +from einops import rearrange, repeat +import lightning as pl +from diffsynth import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, ContinuousODEScheduler, load_state_dict +from diffsynth.pipelines.stable_video_diffusion import SVDCLIPImageProcessor +from diffsynth.models.svd_unet import TemporalAttentionBlock + + + +class TextVideoDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, steps_per_epoch=10000, training_shapes=[(128, 1, 128, 512, 512)]): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.path = [os.path.join(base_path, i["path"]) for i in metadata] + self.steps_per_epoch = steps_per_epoch + self.training_shapes = training_shapes + + self.frame_process = [] + for max_num_frames, interval, num_frames, 
height, width in training_shapes: + self.frame_process.append(v2.Compose([ + v2.Resize(size=max(height, width), antialias=True), + v2.CenterCrop(size=(height, width)), + v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]), + ])) + + + def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): + reader = imageio.get_reader(file_path) + if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + + frames = [] + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = torch.tensor(frame, dtype=torch.float32) + frame = rearrange(frame, "H W C -> 1 C H W") + frame = frame_process(frame) + frames.append(frame) + reader.close() + + frames = torch.concat(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + + return frames + + + def load_video(self, file_path, training_shape_id): + data = {} + max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id] + frame_process = self.frame_process[training_shape_id] + start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0] + frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process) + if frames is None: + return None + else: + data[f"frames_{training_shape_id}"] = frames + return data + + + def __getitem__(self, index): + video_data = {} + for training_shape_id in range(len(self.training_shapes)): + while True: + data_id = torch.randint(0, len(self.path), (1,))[0] + data_id = (data_id + index) % len(self.path) # For fixed seed. + video_file = self.path[data_id] + try: + data = self.load_video(video_file, training_shape_id) + except: + data = None + if data is not None: + break + video_data.update(data) + return video_data + + + def __len__(self): + return self.steps_per_epoch + + + +class MotionBucketManager: + def __init__(self): + self.thresholds = [ + 0.000000000, 0.012205946, 0.015117834, 0.018080613, 0.020614484, 0.021959992, 0.024088068, 0.026323952, + 0.028277775, 0.029968588, 0.031836554, 0.033596724, 0.035121530, 0.037200287, 0.038914755, 0.040696491, + 0.042368013, 0.044265781, 0.046311017, 0.048243891, 0.050294187, 0.052142400, 0.053634230, 0.055612389, + 0.057594258, 0.059410289, 0.061283995, 0.063603796, 0.065192916, 0.067146860, 0.069066539, 0.070390493, + 0.072588451, 0.073959745, 0.075889029, 0.077695683, 0.079783581, 0.082162730, 0.084092639, 0.085958421, + 0.087700523, 0.089684933, 0.091688842, 0.093335517, 0.094987206, 0.096664011, 0.098314710, 0.100262381, + 0.101984538, 0.103404313, 0.105280340, 0.106974818, 0.109028399, 0.111164779, 0.113065213, 0.114362158, + 0.116407216, 0.118063427, 0.119524263, 0.121835820, 0.124242283, 0.126202747, 0.128989249, 0.131672353, + 0.133417681, 0.135567948, 0.137313649, 0.139189199, 0.140912935, 0.143525436, 0.145718485, 0.148315132, + 0.151039496, 0.153218940, 0.155252382, 0.157651082, 0.159966752, 0.162195817, 0.164811596, 0.167341709, + 0.170251891, 0.172651157, 0.175550997, 0.178372145, 0.181039348, 0.183565900, 0.186599866, 0.190071866, + 0.192574754, 0.195026234, 0.198099136, 0.200210452, 0.202522039, 0.205410406, 0.208610669, 0.211623028, + 0.214723110, 0.218520239, 0.222194016, 0.225363150, 0.229384825, 0.233422622, 0.237012610, 0.240735114, + 0.243622541, 0.247465774, 0.252190471, 0.257356376, 0.261856794, 0.266556412, 0.271076709, 
0.277361482, + 0.281250387, 0.286582440, 0.291158527, 0.296712339, 0.303008437, 0.311793238, 0.318485111, 0.326999635, + 0.332138240, 0.341770738, 0.354188830, 0.365194678, 0.379234344, 0.401538879, 0.416078776, 0.440871328, + ] + + def get_motion_score(self, frames): + score = frames.std(dim=2).mean(dim=[1, 2, 3]).tolist() + return score + + def get_bucket_id(self, motion_score): + for bucket_id in range(len(self.thresholds) - 1): + if self.thresholds[bucket_id + 1] > motion_score: + return bucket_id + return len(self.thresholds) - 1 + + def __call__(self, frames): + scores = self.get_motion_score(frames) + bucket_ids = [self.get_bucket_id(score) for score in scores] + return bucket_ids + + + +class LightningModel(pl.LightningModule): + def __init__(self, learning_rate=1e-5, svd_ckpt_path=None, add_positional_conv=128, contrast_enhance_scale=1.01): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.float16, device=self.device) + model_manager.load_stable_video_diffusion(state_dict=load_state_dict(svd_ckpt_path), add_positional_conv=add_positional_conv) + + self.image_encoder: SVDImageEncoder = model_manager.image_encoder + self.image_encoder.eval() + self.image_encoder.requires_grad_(False) + + self.unet: SVDUNet = model_manager.unet + self.unet.train() + self.unet.requires_grad_(False) + for block in self.unet.blocks: + if isinstance(block, TemporalAttentionBlock): + block.requires_grad_(True) + + self.vae_encoder: SVDVAEEncoder = model_manager.vae_encoder + self.vae_encoder.eval() + self.vae_encoder.requires_grad_(False) + + self.noise_scheduler = ContinuousODEScheduler(num_inference_steps=1000) + self.learning_rate = learning_rate + + self.motion_bucket_manager = MotionBucketManager() + self.contrast_enhance_scale = contrast_enhance_scale + + + def encode_image_with_clip(self, image): + image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.dtype) + std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.dtype) + image = (image - mean) / std + image_emb = self.image_encoder(image) + return image_emb + + + def encode_video_with_vae(self, video): + video = video.to(device=self.device, dtype=self.dtype) + video = video.unsqueeze(0) + latents = self.vae_encoder.encode_video(video) + latents = rearrange(latents[0], "C T H W -> T C H W") + return latents + + + def tensor2video(self, frames): + frames = rearrange(frames, "C T H W -> T H W C") + frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) + return frames + + + def calculate_loss(self, frames): + with torch.no_grad(): + # Call video encoder + latents = self.encode_video_with_vae(frames) + image_emb_vae = repeat(latents[0] / self.vae_encoder.scaling_factor, "C H W -> T C H W", T=frames.shape[1]) + image_emb_clip = self.encode_image_with_clip(frames[:,0].unsqueeze(0)) + + # Call scheduler + timestep = torch.randint(0, len(self.noise_scheduler.timesteps), (1,))[0] + timestep = self.noise_scheduler.timesteps[timestep] + noise = torch.randn_like(latents) + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timestep) + + # Prepare positional id + fps = 30 + motion_bucket_id = self.motion_bucket_manager(frames.unsqueeze(0))[0] + noise_aug_strength = 0 + add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device) + + # Calculate 
loss + latents_input = torch.cat([noisy_latents, image_emb_vae], dim=1) + model_pred = self.unet(latents_input, timestep, image_emb_clip, add_time_id, use_gradient_checkpointing=True) + latents_output = self.noise_scheduler.step(model_pred.float(), timestep, noisy_latents.float(), to_final=True) + loss = torch.nn.functional.mse_loss(latents_output, latents.float() * self.contrast_enhance_scale, reduction="mean") + + # Re-weighting + reweighted_loss = loss * self.noise_scheduler.training_weight(timestep) + return loss, reweighted_loss + + + def training_step(self, batch, batch_idx): + # Loss + frames = batch["frames_0"][0] + loss, reweighted_loss = self.calculate_loss(frames) + + # Record log + self.log("train_loss", loss, prog_bar=True) + self.log("reweighted_train_loss", reweighted_loss, prog_bar=True) + return reweighted_loss + + + def configure_optimizers(self): + trainable_modules = [] + for block in self.unet.blocks: + if isinstance(block, TemporalAttentionBlock): + trainable_modules += block.parameters() + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.unet.named_parameters())) + trainable_param_names = [named_param[0] for named_param in trainable_param_names] + checkpoint["trainable_param_names"] = trainable_param_names + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_path", + type=str, + default=None, + required=True, + help="Path to pretrained model. For example, `models/stable_video_diffusion/svd_xt.safetensors`.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + required=False, + help="Path to checkpoint, in case your training program is stopped unexpectedly and you want to resume.", + ) + parser.add_argument( + "--dataset_path", + type=str, + default=None, + required=True, + help="The path of the Dataset.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./", + help="Path to save the model.", + ) + parser.add_argument( + "--steps_per_epoch", + type=int, + default=500, + help="Number of steps per epoch.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=128, + help="Number of frames.", + ) + parser.add_argument( + "--height", + type=int, + default=512, + help="Image height.", + ) + parser.add_argument( + "--width", + type=int, + default=512, + help="Image width.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=2, + help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-5,
+        help="Learning rate.",
+    )
+    parser.add_argument(
+        "--accumulate_grad_batches",
+        type=int,
+        default=1,
+        help="The number of batches in gradient accumulation.",
+    )
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        default=1,
+        help="Number of epochs.",
+    )
+    parser.add_argument(
+        "--contrast_enhance_scale",
+        type=float,
+        default=1.01,
+        help="Avoid generating gray videos.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    # args
+    args = parse_args()
+
+    # dataset and data loader
+    dataset = TextVideoDataset(
+        args.dataset_path,
+        os.path.join(args.dataset_path, "metadata.json"),
+        training_shapes=[(args.num_frames, 1, args.num_frames, args.height, args.width)],
+        steps_per_epoch=args.steps_per_epoch,
+    )
+    train_loader = torch.utils.data.DataLoader(
+        dataset,
+        shuffle=True,
+        # We don't support batch_size > 1,
+        # because sometimes our GPU cannot process even one video.
+        batch_size=1,
+        num_workers=args.dataloader_num_workers
+    )
+
+    # model
+    model = LightningModel(
+        learning_rate=args.learning_rate,
+        svd_ckpt_path=args.pretrained_path,
+        add_positional_conv=args.num_frames,
+        contrast_enhance_scale=args.contrast_enhance_scale
+    )
+
+    # train
+    trainer = pl.Trainer(
+        max_epochs=args.max_epochs,
+        accelerator="gpu",
+        devices="auto",
+        strategy="deepspeed_stage_2",
+        precision="16-mixed",
+        default_root_dir=args.output_path,
+        accumulate_grad_batches=args.accumulate_grad_batches,
+        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
+    )
+    trainer.fit(
+        model=model,
+        train_dataloaders=train_loader,
+        ckpt_path=args.resume_from_checkpoint
+    )
diff --git a/examples/ExVideo/README.md b/examples/ExVideo/README.md
index e457c75..6b2ed05 100644
--- a/examples/ExVideo/README.md
+++ b/examples/ExVideo/README.md
@@ -3,8 +3,8 @@
ExVideo is a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.

* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
-* [Source Code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ExVideo)
-* Technical report
+* [Source Code](https://github.com/modelscope/DiffSynth-Studio)
+* [Technical report](https://arxiv.org/abs/2406.14130)
* Extended models
  * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
  * [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)
@@ -14,3 +14,68 @@ ExVideo is a post-tuning technique aimed at enhancing the capability of video ge
-Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd.py](./ExVideo_svd.py).
+Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd_test.py](./ExVideo_svd_test.py).

https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
+
+## Train
+
+* Step 1: Install additional packages
+
+```
+pip install lightning deepspeed
+```
+
+* Step 2: Download base model (from [HuggingFace](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/svd_xt.safetensors) or [ModelScope](https://www.modelscope.cn/api/v1/models/AI-ModelScope/stable-video-diffusion-img2vid-xt/repo?Revision=master&FilePath=svd_xt.safetensors)) to `models/stable_video_diffusion/svd_xt.safetensors`.
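+
+If you prefer to script the download, a minimal sketch using the `huggingface_hub` package is shown below (this is an extra convenience dependency, not required by the training code); the repo id and filename follow the HuggingFace link above, and `local_dir` is the folder where the training script expects the file.
+
+```python
+# Sketch: fetch svd_xt.safetensors into the expected local folder.
+# Assumes `pip install huggingface_hub`; adjust local_dir if your layout differs.
+from huggingface_hub import hf_hub_download
+
+hf_hub_download(
+    repo_id="stabilityai/stable-video-diffusion-img2vid-xt",
+    filename="svd_xt.safetensors",
+    local_dir="models/stable_video_diffusion",
+)
+```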
+
+* Step 3: Prepare datasets
+
+```
+path/to/your/dataset
+├── metadata.json
+└── videos
+    ├── video_1.mp4
+    ├── video_2.mp4
+    └── video_3.mp4
+```
+
+where `metadata.json` is
+
+```
+[
+    {
+        "path": "videos/video_1.mp4"
+    },
+    {
+        "path": "videos/video_2.mp4"
+    },
+    {
+        "path": "videos/video_3.mp4"
+    }
+]
+```
+
+A minimal helper for generating this file automatically is sketched at the end of this README.
+
+* Step 4: Run
+
+```
+CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python -u ExVideo_svd_train.py \
+  --pretrained_path "models/stable_video_diffusion/svd_xt.safetensors" \
+  --dataset_path "path/to/your/dataset" \
+  --output_path "path/to/save/models" \
+  --steps_per_epoch 8000 \
+  --num_frames 128 \
+  --height 512 \
+  --width 512 \
+  --dataloader_num_workers 2 \
+  --learning_rate 1e-5 \
+  --max_epochs 100
+```
+
+* Step 5: Post-process checkpoints
+
+Calculate an Exponential Moving Average (EMA) of the saved checkpoints and package it using `safetensors`. Point `--output_path` at the trainer's `checkpoints` directory, which contains the `epoch=...-step=....ckpt` checkpoint folders scanned by `ExVideo_ema.py`.
+
+```
+python ExVideo_ema.py --output_path "path/to/save/models/lightning_logs/version_xx/checkpoints" --gamma 0.9
+```
+
+* Step 6: Enjoy your model
+
+The EMA model is saved at `path/to/save/models/lightning_logs/version_xx/checkpoints/epoch=0-step=25-ema.safetensors`. Load it in [ExVideo_svd_test.py](./ExVideo_svd_test.py) (a minimal loading sketch is shown below) and enjoy your model.
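+
+The loading sketch below mirrors the `upscale_video` function in [ExVideo_svd_test.py](./ExVideo_svd_test.py); the EMA file name is only an example, and it is assumed here that the trained extension blocks can be loaded in the same way as the released `model.fp16.safetensors` weights. Reuse the generation arguments from the test script.
+
+```python
+import torch
+from diffsynth import ModelManager, SVDVideoPipeline
+
+# Load the base SVD weights together with your trained ExVideo (EMA) blocks.
+model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+model_manager.load_models([
+    "models/stable_video_diffusion/svd_xt.safetensors",
+    "path/to/save/models/lightning_logs/version_xx/checkpoints/epoch=0-step=25-ema.safetensors",
+])
+pipe = SVDVideoPipeline.from_model_manager(model_manager)
+```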
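+
+The helper mentioned in Step 3 is sketched below: a small script (not part of the repository) that lists the `videos/` folder and writes `metadata.json` in the format read by `ExVideo_svd_train.py`. The dataset root path is a placeholder and the `.mp4` extension filter is an assumption; adjust both to your data.
+
+```python
+# Sketch: build metadata.json for the dataset layout shown in Step 3.
+import json, os
+
+dataset_path = "path/to/your/dataset"  # root folder containing "videos/"
+video_names = sorted(
+    name for name in os.listdir(os.path.join(dataset_path, "videos"))
+    if name.endswith(".mp4")
+)
+metadata = [{"path": f"videos/{name}"} for name in video_names]
+with open(os.path.join(dataset_path, "metadata.json"), "w") as f:
+    json.dump(metadata, f, indent=4)
+```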