mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
Compare commits
11 Commits
refactor
...
wan-refact
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8dd24169cc | ||
|
|
7e6a3c7897 | ||
|
|
eb0060517a | ||
|
|
f157bec8f4 | ||
|
|
27fb9d5ce4 | ||
|
|
89e0eb8796 | ||
|
|
7dc273c020 | ||
|
|
b25c66b303 | ||
|
|
6a833c7134 | ||
|
|
7d29ee1fbb | ||
|
|
98016b2a76 |
@@ -143,8 +143,10 @@ class BasePipeline(torch.nn.Module):
|
|||||||
self.vram_management_enabled = True
|
self.vram_management_enabled = True
|
||||||
|
|
||||||
|
|
||||||
def get_vram(self):
|
def get_free_vram(self):
|
||||||
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
|
total_memory = torch.cuda.get_device_properties(self.device).total_memory
|
||||||
|
allocated_memory = torch.cuda.device_memory_used(self.device)
|
||||||
|
return (total_memory - allocated_memory) / (1024 ** 3)
|
||||||
|
|
||||||
|
|
||||||
def freeze_except(self, model_names):
|
def freeze_except(self, model_names):
|
||||||
@@ -245,7 +247,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
vram_limit = None
|
vram_limit = None
|
||||||
else:
|
else:
|
||||||
if vram_limit is None:
|
if vram_limit is None:
|
||||||
vram_limit = self.get_vram()
|
vram_limit = self.get_free_vram()
|
||||||
vram_limit = vram_limit - vram_buffer
|
vram_limit = vram_limit - vram_buffer
|
||||||
if self.text_encoder is not None:
|
if self.text_encoder is not None:
|
||||||
dtype = next(iter(self.text_encoder.parameters())).dtype
|
dtype = next(iter(self.text_encoder.parameters())).dtype
|
||||||
|
|||||||
@@ -11,9 +11,8 @@ class VideoDataset(torch.utils.data.Dataset):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_path=None, metadata_path=None,
|
base_path=None, metadata_path=None,
|
||||||
num_frames=81,
|
frame_interval=1, num_frames=81,
|
||||||
time_division_factor=4, time_division_remainder=1,
|
dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None,
|
||||||
max_pixels=1920*1080, height=None, width=None,
|
|
||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
data_file_keys=("video",),
|
data_file_keys=("video",),
|
||||||
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||||
@@ -26,15 +25,17 @@ class VideoDataset(torch.utils.data.Dataset):
|
|||||||
metadata_path = args.dataset_metadata_path
|
metadata_path = args.dataset_metadata_path
|
||||||
height = args.height
|
height = args.height
|
||||||
width = args.width
|
width = args.width
|
||||||
max_pixels = args.max_pixels
|
|
||||||
num_frames = args.num_frames
|
num_frames = args.num_frames
|
||||||
data_file_keys = args.data_file_keys.split(",")
|
data_file_keys = args.data_file_keys.split(",")
|
||||||
repeat = args.dataset_repeat
|
repeat = args.dataset_repeat
|
||||||
|
|
||||||
|
metadata = pd.read_csv(metadata_path)
|
||||||
|
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||||
|
|
||||||
self.base_path = base_path
|
self.base_path = base_path
|
||||||
|
self.frame_interval = frame_interval
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
self.time_division_factor = time_division_factor
|
self.dynamic_resolution = dynamic_resolution
|
||||||
self.time_division_remainder = time_division_remainder
|
|
||||||
self.max_pixels = max_pixels
|
self.max_pixels = max_pixels
|
||||||
self.height = height
|
self.height = height
|
||||||
self.width = width
|
self.width = width
|
||||||
@@ -45,43 +46,9 @@ class VideoDataset(torch.utils.data.Dataset):
|
|||||||
self.video_file_extension = video_file_extension
|
self.video_file_extension = video_file_extension
|
||||||
self.repeat = repeat
|
self.repeat = repeat
|
||||||
|
|
||||||
if height is not None and width is not None:
|
if height is not None and width is not None and dynamic_resolution == True:
|
||||||
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
|
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
|
||||||
self.dynamic_resolution = False
|
self.dynamic_resolution = False
|
||||||
elif height is None and width is None:
|
|
||||||
print("Height and width are none. Setting `dynamic_resolution` to True.")
|
|
||||||
self.dynamic_resolution = True
|
|
||||||
|
|
||||||
if metadata_path is None:
|
|
||||||
print("No metadata. Trying to generate it.")
|
|
||||||
metadata = self.generate_metadata(base_path)
|
|
||||||
print(f"{len(metadata)} lines in metadata.")
|
|
||||||
else:
|
|
||||||
metadata = pd.read_csv(metadata_path)
|
|
||||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
|
||||||
|
|
||||||
|
|
||||||
def generate_metadata(self, folder):
|
|
||||||
video_list, prompt_list = [], []
|
|
||||||
file_set = set(os.listdir(folder))
|
|
||||||
for file_name in file_set:
|
|
||||||
if "." not in file_name:
|
|
||||||
continue
|
|
||||||
file_ext_name = file_name.split(".")[-1].lower()
|
|
||||||
file_base_name = file_name[:-len(file_ext_name)-1]
|
|
||||||
if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension:
|
|
||||||
continue
|
|
||||||
prompt_file_name = file_base_name + ".txt"
|
|
||||||
if prompt_file_name not in file_set:
|
|
||||||
continue
|
|
||||||
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
|
|
||||||
prompt = f.read().strip()
|
|
||||||
video_list.append(file_name)
|
|
||||||
prompt_list.append(prompt)
|
|
||||||
metadata = pd.DataFrame()
|
|
||||||
metadata["video"] = video_list
|
|
||||||
metadata["prompt"] = prompt_list
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
|
|
||||||
def crop_and_resize(self, image, target_height, target_width):
|
def crop_and_resize(self, image, target_height, target_width):
|
||||||
@@ -109,21 +76,14 @@ class VideoDataset(torch.utils.data.Dataset):
|
|||||||
return height, width
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
def get_num_frames(self, reader):
|
def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
|
||||||
num_frames = self.num_frames
|
|
||||||
if int(reader.count_frames()) < num_frames:
|
|
||||||
num_frames = int(reader.count_frames())
|
|
||||||
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
|
||||||
num_frames -= 1
|
|
||||||
return num_frames
|
|
||||||
|
|
||||||
|
|
||||||
def load_video(self, file_path):
|
|
||||||
reader = imageio.get_reader(file_path)
|
reader = imageio.get_reader(file_path)
|
||||||
num_frames = self.get_num_frames(reader)
|
if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
|
||||||
|
reader.close()
|
||||||
|
return None
|
||||||
frames = []
|
frames = []
|
||||||
for frame_id in range(num_frames):
|
for frame_id in range(num_frames):
|
||||||
frame = reader.get_data(frame_id)
|
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||||
frame = Image.fromarray(frame)
|
frame = Image.fromarray(frame)
|
||||||
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
|
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
@@ -137,6 +97,11 @@ class VideoDataset(torch.utils.data.Dataset):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def load_video(self, file_path):
|
||||||
|
frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
def is_image(self, file_path):
|
def is_image(self, file_path):
|
||||||
file_ext_name = file_path.split(".")[-1]
|
file_ext_name = file_path.split(".")[-1]
|
||||||
return file_ext_name.lower() in self.image_file_extension
|
return file_ext_name.lower() in self.image_file_extension
|
||||||
@@ -217,50 +182,34 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelLogger:
|
def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=1e-4, num_epochs=1, output_path="./models", remove_prefix_in_ckpt=None, args=None):
|
||||||
def __init__(self, output_path, remove_prefix_in_ckpt=None):
|
if args is not None:
|
||||||
self.output_path = output_path
|
learning_rate = args.learning_rate
|
||||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
num_epochs = args.num_epochs
|
||||||
|
output_path = args.output_path
|
||||||
|
remove_prefix_in_ckpt = args.remove_prefix_in_ckpt
|
||||||
def on_step_end(self, loss):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def on_epoch_end(self, accelerator, model, epoch_id):
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
state_dict = accelerator.get_state_dict(model)
|
|
||||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
|
||||||
os.makedirs(self.output_path, exist_ok=True)
|
|
||||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
|
||||||
accelerator.save(state_dict, path, safe_serialization=True)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def launch_training_task(
|
|
||||||
dataset: torch.utils.data.Dataset,
|
|
||||||
model: DiffusionTrainingModule,
|
|
||||||
model_logger: ModelLogger,
|
|
||||||
optimizer: torch.optim.Optimizer,
|
|
||||||
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
|
||||||
num_epochs: int = 1,
|
|
||||||
gradient_accumulation_steps: int = 1,
|
|
||||||
):
|
|
||||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
|
||||||
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
|
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||||
|
|
||||||
|
accelerator = Accelerator(gradient_accumulation_steps=1)
|
||||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||||
|
|
||||||
for epoch_id in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for data in tqdm(dataloader):
|
for data in tqdm(dataloader):
|
||||||
with accelerator.accumulate(model):
|
with accelerator.accumulate(model):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss = model(data)
|
loss = model(data)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
model_logger.on_step_end(loss)
|
scheduler.step()
|
||||||
scheduler.step()
|
accelerator.wait_for_everyone()
|
||||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
if accelerator.is_main_process:
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt)
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
path = os.path.join(output_path, f"epoch-{epoch}.safetensors")
|
||||||
|
accelerator.save(state_dict, path, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -279,9 +228,8 @@ def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_pat
|
|||||||
|
|
||||||
def wan_parser():
|
def wan_parser():
|
||||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.")
|
||||||
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Path to the metadata file of the dataset.")
|
||||||
parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution..")
|
|
||||||
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||||
@@ -299,6 +247,5 @@ def wan_parser():
|
|||||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ class AutoTorchModule(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def check_free_vram(self):
|
def check_free_vram(self):
|
||||||
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
|
used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
|
||||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
|
|
||||||
return used_memory < self.vram_limit
|
return used_memory < self.vram_limit
|
||||||
|
|
||||||
def offload(self):
|
def offload(self):
|
||||||
|
|||||||
@@ -4,10 +4,6 @@
|
|||||||
|
|
||||||
Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba.
|
Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba.
|
||||||
|
|
||||||
**DiffSynth-Studio has adopted a new inference and training framework. To use the previous version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
Before using this model, please install DiffSynth-Studio from **source code**.
|
Before using this model, please install DiffSynth-Studio from **source code**.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
@@ -16,8 +12,6 @@ cd DiffSynth-Studio
|
|||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||||
@@ -289,7 +283,7 @@ video2.mp4,"a dog is running"
|
|||||||
We have prepared a sample video dataset to help you test. You can download it using the following command:
|
We have prepared a sample video dataset to help you test. You can download it using the following command:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
|
modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset
|
||||||
```
|
```
|
||||||
|
|
||||||
The dataset supports mixed training of videos and images. Supported video formats include `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`, and supported image formats include `"jpg", "jpeg", "png", "webp"`.
|
The dataset supports mixed training of videos and images. Supported video formats include `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`, and supported image formats include `"jpg", "jpeg", "png", "webp"`.
|
||||||
@@ -393,25 +387,3 @@ Note that full fine-tuning of the 14B model requires 8 GPUs, each with at least
|
|||||||
The default video resolution in the training script is `480*832*81`. Increasing the resolution may cause out-of-memory errors. To reduce VRAM usage, add the parameter `--use_gradient_checkpointing_offload`.
|
The default video resolution in the training script is `480*832*81`. Increasing the resolution may cause out-of-memory errors. To reduce VRAM usage, add the parameter `--use_gradient_checkpointing_offload`.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Gallery
|
|
||||||
|
|
||||||
1.3B text-to-video:
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
|
||||||
|
|
||||||
Put sunglasses on the dog (1.3B video-to-video):
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
|
||||||
|
|
||||||
14B text-to-video:
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
|
||||||
|
|
||||||
14B image-to-video:
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
|
|
||||||
|
|
||||||
LoRA training:
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9
|
|
||||||
|
|||||||
@@ -4,10 +4,6 @@
|
|||||||
|
|
||||||
Wan 2.1 是由阿里巴巴通义实验室开源的一系列视频生成模型。
|
Wan 2.1 是由阿里巴巴通义实验室开源的一系列视频生成模型。
|
||||||
|
|
||||||
**DiffSynth-Studio 启用了新的推理和训练框架,如需使用旧版本,请点击[这里](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c)。**
|
|
||||||
|
|
||||||
## 安装
|
|
||||||
|
|
||||||
在使用本系列模型之前,请通过源码安装 DiffSynth-Studio。
|
在使用本系列模型之前,请通过源码安装 DiffSynth-Studio。
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
@@ -16,8 +12,6 @@ cd DiffSynth-Studio
|
|||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
## 模型总览
|
|
||||||
|
|
||||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||||
@@ -292,7 +286,7 @@ video2.mp4,"a dog is running"
|
|||||||
我们构建了一个样例视频数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
我们构建了一个样例视频数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
|
modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset
|
||||||
```
|
```
|
||||||
|
|
||||||
数据集支持视频和图片混合训练,支持的视频文件格式包括 `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`,支持的图片格式包括 `"jpg", "jpeg", "png", "webp"`。
|
数据集支持视频和图片混合训练,支持的视频文件格式包括 `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`,支持的图片格式包括 `"jpg", "jpeg", "png", "webp"`。
|
||||||
@@ -396,25 +390,3 @@ model_configs=[
|
|||||||
训练脚本的默认视频尺寸为 `480*832*81`,提升分辨率将可能导致显存不足,请添加参数 `--use_gradient_checkpointing_offload` 降低显存占用。
|
训练脚本的默认视频尺寸为 `480*832*81`,提升分辨率将可能导致显存不足,请添加参数 `--use_gradient_checkpointing_offload` 降低显存占用。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## 案例展示
|
|
||||||
|
|
||||||
1.3B 文生视频:
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
|
||||||
|
|
||||||
给狗狗戴上墨镜(1.3B 视频生视频):
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
|
||||||
|
|
||||||
14B 文生视频:
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
|
||||||
|
|
||||||
14B 图生视频:
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
|
|
||||||
|
|
||||||
LoRA 训练:
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
|||||||
--learning_rate 1e-5 \
|
--learning_rate 1e-5 \
|
||||||
--num_epochs 5 \
|
--num_epochs 5 \
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
--output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora" \
|
--output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full" \
|
||||||
--lora_base_model "dit" \
|
--lora_base_model "dit" \
|
||||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||||
--lora_rank 32 \
|
--lora_rank 32 \
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
|||||||
--learning_rate 1e-5 \
|
--learning_rate 1e-5 \
|
||||||
--num_epochs 5 \
|
--num_epochs 5 \
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
--output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_lora" \
|
--output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full" \
|
||||||
--lora_base_model "dit" \
|
--lora_base_model "dit" \
|
||||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||||
--lora_rank 32 \
|
--lora_rank 32 \
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import torch, os, json
|
import torch, os, json
|
||||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
|
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@@ -107,14 +107,4 @@ if __name__ == "__main__":
|
|||||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
extra_inputs=args.extra_inputs,
|
extra_inputs=args.extra_inputs,
|
||||||
)
|
)
|
||||||
model_logger = ModelLogger(
|
launch_training_task(model, dataset, args=args)
|
||||||
args.output_path,
|
|
||||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
|
|
||||||
)
|
|
||||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
|
||||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
|
||||||
launch_training_task(
|
|
||||||
dataset, model, model_logger, optimizer, scheduler,
|
|
||||||
num_epochs=args.num_epochs,
|
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ pipe = WanVideoPipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
state_dict = load_state_dict("models/train/Wan2.1-VACE-1.3B-Preview_full/epoch-1.safetensors")
|
state_dict = load_state_dict("models/train/VACE-Wan2.1-1.3B-Preview_full/epoch-1.safetensors")
|
||||||
pipe.vace.load_state_dict(state_dict)
|
pipe.vace.load_state_dict(state_dict)
|
||||||
pipe.enable_vram_management()
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user