skip bad videos

This commit is contained in:
Artiprocher
2025-03-18 17:24:39 +08:00
parent e28c246bcc
commit b1fabbc6b0

View File

@@ -95,17 +95,20 @@ class TextVideoDataset(torch.utils.data.Dataset):
def __getitem__(self, data_id):
text = self.text[data_id]
path = self.path[data_id]
if self.is_image(path):
try:
if self.is_image(path):
if self.is_i2v:
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
video = self.load_image(path)
else:
video = self.load_video(path)
if self.is_i2v:
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
video = self.load_image(path)
else:
video = self.load_video(path)
if self.is_i2v:
video, first_frame = video
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else:
data = {"text": text, "video": video, "path": path}
video, first_frame = video
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else:
data = {"text": text, "video": video, "path": path}
except:
data = None
return data
@@ -127,7 +130,10 @@ class LightningModelForDataProcess(pl.LightningModule):
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
def test_step(self, batch, batch_idx):
text, video, path = batch["text"][0], batch["video"], batch["path"][0]
data = batch[0]
if data is None or data["video"] is None:
return
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
self.pipe.device = self.device
if video is not None:
@@ -512,7 +518,8 @@ def data_process(args):
dataset,
shuffle=False,
batch_size=1,
num_workers=args.dataloader_num_workers
num_workers=args.dataloader_num_workers,
collate_fn=lambda x: x,
)
model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path,