diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py
index a96c0aa..817fd5c 100644
--- a/examples/wanvideo/train_wan_t2v.py
+++ b/examples/wanvideo/train_wan_t2v.py
@@ -23,8 +23,8 @@ class TextVideoDataset(torch.utils.data.Dataset):
         self.width = width
 
         self.frame_process = v2.Compose([
-            v2.Resize(size=(height, width), antialias=True),
             v2.CenterCrop(size=(height, width)),
+            v2.Resize(size=(height, width), antialias=True),
             v2.ToTensor(),
             v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         ])
@@ -68,17 +68,28 @@ class TextVideoDataset(torch.utils.data.Dataset):
         return frames
 
 
-    def load_text_video_raw_data(self, data_id):
-        text = self.path[data_id]
-        video = self.load_video(self.path[data_id])
-        data = {"text": text, "video": video}
-        return data
+    def is_image(self, file_path):
+        file_ext_name = file_path.split(".")[-1]
+        if file_ext_name.lower() in ["jpg", "png", "webp"]:
+            return True
+        return False
+
+
+    def load_image(self, file_path):
+        frame = Image.open(file_path).convert("RGB")
+        frame = self.crop_and_resize(frame)
+        frame = self.frame_process(frame)
+        frame = rearrange(frame, "C H W -> C 1 H W")
+        return frame
 
 
     def __getitem__(self, data_id):
         text = self.text[data_id]
         path = self.path[data_id]
-        video = self.load_video(path)
+        if self.is_image(path):
+            video = self.load_image(path)
+        else:
+            video = self.load_video(path)
         data = {"text": text, "video": video, "path": path}
         return data
 