mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
skip bad videos
This commit is contained in:
@@ -95,17 +95,20 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
def __getitem__(self, data_id):
|
def __getitem__(self, data_id):
|
||||||
text = self.text[data_id]
|
text = self.text[data_id]
|
||||||
path = self.path[data_id]
|
path = self.path[data_id]
|
||||||
if self.is_image(path):
|
try:
|
||||||
|
if self.is_image(path):
|
||||||
|
if self.is_i2v:
|
||||||
|
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||||
|
video = self.load_image(path)
|
||||||
|
else:
|
||||||
|
video = self.load_video(path)
|
||||||
if self.is_i2v:
|
if self.is_i2v:
|
||||||
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
video, first_frame = video
|
||||||
video = self.load_image(path)
|
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||||
else:
|
else:
|
||||||
video = self.load_video(path)
|
data = {"text": text, "video": video, "path": path}
|
||||||
if self.is_i2v:
|
except:
|
||||||
video, first_frame = video
|
data = None
|
||||||
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
|
||||||
else:
|
|
||||||
data = {"text": text, "video": video, "path": path}
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -127,7 +130,10 @@ class LightningModelForDataProcess(pl.LightningModule):
|
|||||||
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
text, video, path = batch["text"][0], batch["video"], batch["path"][0]
|
data = batch[0]
|
||||||
|
if data is None or data["video"] is None:
|
||||||
|
return
|
||||||
|
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
|
||||||
|
|
||||||
self.pipe.device = self.device
|
self.pipe.device = self.device
|
||||||
if video is not None:
|
if video is not None:
|
||||||
@@ -512,7 +518,8 @@ def data_process(args):
|
|||||||
dataset,
|
dataset,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
num_workers=args.dataloader_num_workers
|
num_workers=args.dataloader_num_workers,
|
||||||
|
collate_fn=lambda x: x,
|
||||||
)
|
)
|
||||||
model = LightningModelForDataProcess(
|
model = LightningModelForDataProcess(
|
||||||
text_encoder_path=args.text_encoder_path,
|
text_encoder_path=args.text_encoder_path,
|
||||||
|
|||||||
Reference in New Issue
Block a user