controlnet

This commit is contained in:
Artiprocher
2025-03-21 11:09:56 +08:00
parent 6cd032e846
commit 105eaf0f49
4 changed files with 915 additions and 12 deletions

View File

@@ -14,7 +14,10 @@ import numpy as np
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
if os.path.exists(os.path.join(base_path, "train")):
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
else:
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames
@@ -359,6 +362,12 @@ def parse_args():
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--metadata_path",
type=str,
default=None,
help="Path to metadata.csv.",
)
parser.add_argument(
"--redirected_tensor_path",
type=str,
@@ -548,7 +557,7 @@ def parse_args():
def data_process(args):
dataset = TextVideoDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv"),
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
max_num_frames=args.num_frames,
frame_interval=1,
num_frames=args.num_frames,
@@ -584,7 +593,7 @@ def data_process(args):
def train(args):
dataset = TensorDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv"),
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
steps_per_epoch=args.steps_per_epoch,
redirected_tensor_path=args.redirected_tensor_path,
)