From b1fabbc6b0de1414b8a2b483b46b9a2d97c3aa43 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 18 Mar 2025 17:24:39 +0800 Subject: [PATCH] skip bad videos --- examples/wanvideo/train_wan_t2v.py | 31 ++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index 8e8b370..45a026f 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -95,17 +95,20 @@ class TextVideoDataset(torch.utils.data.Dataset): def __getitem__(self, data_id): text = self.text[data_id] path = self.path[data_id] - if self.is_image(path): + try: + if self.is_image(path): + if self.is_i2v: + raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.") + video = self.load_image(path) + else: + video = self.load_video(path) if self.is_i2v: - raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.") - video = self.load_image(path) - else: - video = self.load_video(path) - if self.is_i2v: - video, first_frame = video - data = {"text": text, "video": video, "path": path, "first_frame": first_frame} - else: - data = {"text": text, "video": video, "path": path} + video, first_frame = video + data = {"text": text, "video": video, "path": path, "first_frame": first_frame} + else: + data = {"text": text, "video": video, "path": path} + except: + data = None return data @@ -127,7 +130,10 @@ class LightningModelForDataProcess(pl.LightningModule): self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} def test_step(self, batch, batch_idx): - text, video, path = batch["text"][0], batch["video"], batch["path"][0] + data = batch[0] + if data is None or data["video"] is None: + return + text, video, path = data["text"], data["video"].unsqueeze(0), data["path"] self.pipe.device = self.device if video is not None: @@ -512,7 +518,8 @@ def data_process(args): dataset, shuffle=False, batch_size=1, - num_workers=args.dataloader_num_workers + num_workers=args.dataloader_num_workers, + collate_fn=lambda x: x, ) model = LightningModelForDataProcess( text_encoder_path=args.text_encoder_path,