diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index ca199a2..cd10096 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -56,13 +56,16 @@ class TextVideoDataset(torch.utils.data.Dataset): frame = Image.fromarray(frame) frame = self.crop_and_resize(frame) if first_frame is None: - first_frame = np.array(frame) + first_frame = frame frame = frame_process(frame) frames.append(frame) reader.close() frames = torch.stack(frames, dim=0) frames = rearrange(frames, "T C H W -> C T H W") + + first_frame = v2.functional.center_crop(first_frame, output_size=(self.height, self.width)) + first_frame = np.array(first_frame) if self.is_i2v: return frames, first_frame