mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:23:43 +00:00
258 lines
12 KiB
Python
258 lines
12 KiB
Python
import imageio, os, torch, warnings, torchvision, argparse
|
|
from peft import LoraConfig, inject_adapter_in_model
|
|
from PIL import Image
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from accelerate import Accelerator
|
|
|
|
|
|
|
|
class VideoDataset(torch.utils.data.Dataset):
    """Map-style dataset of image/video samples described by a CSV metadata file.

    Every CSV row becomes one sample dict. For each column named in
    ``data_file_keys`` that is present in a row, the cell value is joined onto
    ``base_path`` and replaced by the loaded content: a ``PIL.Image`` for image
    extensions, or a list of ``PIL.Image`` frames for video extensions.
    ``__getitem__`` returns ``None`` (after emitting a warning) when such a file
    cannot be loaded, so consumers must be prepared to receive ``None``.
    """

    def __init__(
        self,
        base_path=None, metadata_path=None,
        frame_interval=1, num_frames=81,
        dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        data_file_keys=("video",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
        repeat=1,
        args=None,
    ):
        """Load the metadata CSV and store sampling / resizing settings.

        Args:
            base_path: Directory that relative file paths in the metadata are
                joined onto. ``None`` is treated as ``""`` (paths used as-is).
            metadata_path: Path to the CSV, read with ``pandas.read_csv``.
            frame_interval: Step between sampled video frame indices.
            num_frames: Number of frames sampled from the start of each video.
            dynamic_resolution: If true, derive the output size from each
                image's own size (see ``get_height_width``). Forced to ``False``
                when both ``height`` and ``width`` are given.
            max_pixels: Area cap used in dynamic-resolution mode.
            height, width: Fixed output size used when dynamic resolution is off.
            height_division_factor, width_division_factor: In dynamic mode the
                output sides are floored to multiples of these values.
            data_file_keys: Metadata columns whose values are file paths to load.
            image_file_extension, video_file_extension: Lower-case extensions
                (without the dot) recognised as images / videos.
            repeat: How many times the metadata is repeated per epoch.
            args: Optional parsed namespace (e.g. from ``wan_parser``). When
                given, its ``dataset_base_path``, ``dataset_metadata_path``,
                ``height``, ``width``, ``num_frames``, ``data_file_keys``
                (comma-separated) and ``dataset_repeat`` override the
                corresponding arguments.
        """
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            num_frames = args.num_frames
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat

        metadata = pd.read_csv(metadata_path)
        # One plain dict per CSV row; __getitem__ copies it before filling in loaded media.
        self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

        # os.path.join(None, ...) raises TypeError; treat a missing base path as "use paths as-is".
        self.base_path = base_path if base_path is not None else ""
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.dynamic_resolution = dynamic_resolution
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.video_file_extension = video_file_extension
        self.repeat = repeat

        if height is not None and width is not None and dynamic_resolution:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` so it covers the target size, then center-crop to it.

        The scale is the larger of the two per-axis ratios, so the resized image
        is at least ``target_height`` x ``target_width`` before cropping and the
        aspect ratio is preserved (no letterboxing).
        """
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return torchvision.transforms.functional.center_crop(image, (target_height, target_width))

    def get_height_width(self, image):
        """Return the ``(height, width)`` that ``image`` should be resized to.

        In dynamic mode the image's own size is downscaled (aspect preserved)
        so its area does not exceed ``max_pixels``, then each side is floored
        to a multiple of its division factor. Otherwise the fixed
        ``self.height`` / ``self.width`` are returned.
        """
        if not self.dynamic_resolution:
            return self.height, self.width
        width, height = image.size
        if width * height > self.max_pixels:
            scale = (width * height / self.max_pixels) ** 0.5
            height, width = int(height / scale), int(width / scale)
        height = height // self.height_division_factor * self.height_division_factor
        width = width // self.width_division_factor * self.width_division_factor
        return height, width

    def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
        """Read ``num_frames`` frames spaced by ``interval`` starting at ``start_frame_id``.

        Each frame is converted to a ``PIL.Image`` and cropped/resized. Returns
        ``None`` when the video does not contain the last requested frame.
        The reader is always closed, even if decoding or resizing raises.
        """
        reader = imageio.get_reader(file_path)
        try:
            last_frame_id = start_frame_id + (num_frames - 1) * interval
            if reader.count_frames() - 1 < last_frame_id:
                return None
            frames = []
            for frame_id in range(num_frames):
                frame = Image.fromarray(reader.get_data(start_frame_id + frame_id * interval))
                frames.append(self.crop_and_resize(frame, *self.get_height_width(frame)))
            return frames
        finally:
            reader.close()

    def load_image(self, file_path):
        """Open an image as RGB and crop/resize it to the configured resolution."""
        image = Image.open(file_path).convert("RGB")
        return self.crop_and_resize(image, *self.get_height_width(image))

    def load_video(self, file_path):
        """Load ``self.num_frames`` frames from the start of a video, or ``None`` if too short."""
        return self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)

    def is_image(self, file_path):
        """Return True if the (case-insensitive) extension is a known image extension."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.image_file_extension

    def is_video(self, file_path):
        """Return True if the (case-insensitive) extension is a known video extension."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.video_file_extension

    def load_data(self, file_path):
        """Load an image or video by extension; return ``None`` for unknown extensions."""
        if self.is_image(file_path):
            return self.load_image(file_path)
        if self.is_video(file_path):
            return self.load_video(file_path)
        return None

    def __getitem__(self, data_id):
        """Return the sample dict for ``data_id`` with file columns replaced by loaded media.

        Indices wrap modulo the number of metadata rows, which implements
        ``repeat``. Returns ``None`` and emits a warning naming the offending
        path if any file column fails to load.
        """
        # Shallow copy so loaded media never overwrites the cached metadata row.
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Report the path, not data[key] (which is now None).
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data

    def __len__(self):
        """Number of metadata rows times ``repeat``."""
        return len(self.data) * self.repeat
|
|
|
|
|
|
|
|
class DiffusionTrainingModule(torch.nn.Module):
    """Base ``nn.Module`` for training wrappers.

    Provides helpers to enumerate trainable parameters, inject LoRA adapters
    through PEFT, and export only the trainable part of a state dict.
    """

    def __init__(self):
        super().__init__()

    def to(self, *args, **kwargs):
        """Forward ``.to(...)`` to each direct child module and return ``self``.

        Only registered child modules are moved; parameters or buffers
        registered directly on this module are not touched by this override.
        """
        for child in self.children():
            child.to(*args, **kwargs)
        return self

    def trainable_modules(self):
        """Return a one-shot iterator over parameters with ``requires_grad`` set.

        Despite the name, this yields parameters (suitable for an optimizer),
        not modules.
        """
        return (param for param in self.parameters() if param.requires_grad)

    def trainable_param_names(self):
        """Return the set of qualified names of parameters that require gradients."""
        return {name for name, param in self.named_parameters() if param.requires_grad}

    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
        """Inject LoRA adapters into ``model`` via PEFT and return the result.

        Args:
            model: Module to receive the adapters.
            target_modules: Module names (or patterns) passed to ``LoraConfig``.
            lora_rank: LoRA rank ``r``.
            lora_alpha: LoRA alpha; defaults to ``lora_rank`` when ``None``.
        """
        config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank if lora_alpha is None else lora_alpha,
            target_modules=target_modules,
        )
        return inject_adapter_in_model(config, model)

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        """Keep only trainable entries of ``state_dict``, optionally stripping a key prefix.

        Args:
            state_dict: Mapping of qualified parameter names to tensors, keyed
                the same way as this module's ``named_parameters``.
            remove_prefix: If given, this prefix is removed from every kept key
                that starts with it; other keys are kept unchanged.

        Returns:
            A new dict, in the input's iteration order, with the filtered
            (and possibly renamed) entries.
        """
        keep = self.trainable_param_names()
        exported = {}
        for name, param in state_dict.items():
            if name not in keep:
                continue
            if remove_prefix is not None and name.startswith(remove_prefix):
                name = name[len(remove_prefix):]
            exported[name] = param
        return exported
|
|
|
|
|
|
|
|
def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=1e-4, num_epochs=1, output_path="./models", remove_prefix_in_ckpt=None, args=None):
    """Train ``model`` on ``dataset`` with AdamW under Hugging Face Accelerate.

    Each dataset item is passed unchanged to ``model(...)``, which is expected
    to return the loss tensor given to ``accelerator.backward``. After every
    epoch the main process writes the trainable parameters to
    ``<output_path>/epoch-<epoch>.safetensors``.

    Args:
        model: Training wrapper; ``trainable_modules()`` supplies the optimised
            parameters and ``export_trainable_state_dict`` selects what is saved.
        dataset: Map-style dataset; one item per optimisation step.
        learning_rate: AdamW learning rate, held constant for the whole run.
        num_epochs: Number of passes over ``dataset``.
        output_path: Directory for checkpoints (created if missing).
        remove_prefix_in_ckpt: Optional key prefix stripped from saved names.
        args: Optional parsed namespace whose ``learning_rate``, ``num_epochs``,
            ``output_path`` and ``remove_prefix_in_ckpt`` override the
            corresponding arguments.
    """
    if args is not None:
        learning_rate = args.learning_rate
        num_epochs = args.num_epochs
        output_path = args.output_path
        remove_prefix_in_ckpt = args.remove_prefix_in_ckpt

    # Default batch_size is 1; the collate_fn unwraps the single item so the model
    # receives the dataset's raw sample rather than a list.
    # NOTE(review): a dataset that returns None on load failure passes None to the model.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
    # ConstantLR's defaults (factor=1/3, total_iters=5) would run the first five steps
    # at a third of learning_rate; factor=1.0 keeps the LR truly constant.
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

    accelerator = Accelerator(gradient_accumulation_steps=1)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    for epoch in range(num_epochs):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
        # Synchronise all processes before the main process reads weights for saving.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.get_state_dict(model)
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt)
            os.makedirs(output_path, exist_ok=True)
            path = os.path.join(output_path, f"epoch-{epoch}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)
|
|
|
|
|
|
|
|
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
    """Run ``forward_preprocess`` over every sample and cache the results to disk.

    For each dataset item (in order, unshuffled), the model's
    ``forward_preprocess`` output is restricted to the keys listed in
    ``model.model_input_keys`` and saved with ``torch.save`` to
    ``<output_path>/data_cache/<index>.pth``.

    Args:
        model: Wrapper exposing ``forward_preprocess(data)`` and
            ``model_input_keys``.
        dataset: Map-style dataset; one item is processed at a time.
        output_path: Root directory; the ``data_cache`` subfolder is created.
    """
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
    accelerator = Accelerator()
    model, dataloader = accelerator.prepare(model, dataloader)
    # prepare() may wrap the model (e.g. DistributedDataParallel), and the wrapper does not
    # expose custom attributes; use the underlying module for forward_preprocess/model_input_keys.
    core_model = accelerator.unwrap_model(model)
    cache_dir = os.path.join(output_path, "data_cache")
    os.makedirs(cache_dir, exist_ok=True)
    for local_id, data in enumerate(tqdm(dataloader)):
        # With several processes the prepared dataloader is sharded; under Accelerate's default
        # (non-split) sharding process p receives batches p, p + n, p + 2n, ...  Map back to a
        # global index so ranks do not overwrite each other's files. For a single process this
        # equals local_id. NOTE(review): with even_batches padding, extra indices past the dataset
        # end may be written as duplicates of early samples.
        data_id = local_id * accelerator.num_processes + accelerator.process_index
        with torch.no_grad():
            inputs = core_model.forward_preprocess(data)
        inputs = {key: inputs[key] for key in core_model.model_input_keys if key in inputs}
        torch.save(inputs, os.path.join(cache_dir, f"{data_id}.pth"))
|
|
|
|
|
|
|
|
def wan_parser():
    """Build the command-line parser for the Wan training scripts.

    Returns:
        argparse.ArgumentParser covering dataset, model-loading, optimisation,
        LoRA and input-composition options. Only ``--dataset_metadata_path``
        is required.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    add = parser.add_argument

    # Dataset location and sampling.
    add("--dataset_base_path", type=str, default="", help="Base path of the Dataset.")
    add("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the Dataset.")
    resolution_note = "Leave `height` and `width` None to enable dynamic resolution."
    add("--height", type=int, default=None, help=f"Image or video height. {resolution_note}")
    add("--width", type=int, default=None, help=f"Image or video width. {resolution_note}")
    add("--num_frames", type=int, default=81, help="Number of frames in each video. The frames are sampled from the prefix.")
    add("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.")
    add("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.")

    # Model weights to load.
    add("--model_paths", type=str, default=None, help="Model paths to be loaded. JSON format.")
    add(
        "--model_id_with_origin_paths",
        type=str,
        default=None,
        help="Model ID with origin path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separated by commas.",
    )

    # Optimisation and checkpointing.
    add("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    add("--num_epochs", type=int, default=1, help="Number of epochs.")
    add("--output_path", type=str, default="./models", help="Save path.")
    add("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
    add("--trainable_models", type=str, default=None, help="Trainable models, e.g., dit, vae, text_encoder.")

    # LoRA configuration.
    add("--lora_base_model", type=str, default=None, help="Add LoRA on which model.")
    add("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Add LoRA on which layer.")
    add("--lora_rank", type=int, default=32, help="LoRA rank.")

    # Boolean switches declaring which optional inputs the model receives.
    optional_inputs = (
        "input_image",
        "end_image",
        "control_video",
        "reference_image",
        "vace_video",
        "vace_reference_image",
        "motion_bucket_id",
    )
    for input_name in optional_inputs:
        add(
            f"--input_contains_{input_name}",
            default=False,
            action="store_true",
            help=f"Model input contains '{input_name}'.",
        )

    add("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Offload gradient checkpointing to RAM.")
    return parser
|
|
|