Update train_wan_t2v.py

This commit is contained in:
Zhongjie Duan
2025-04-17 15:37:30 +08:00
committed by GitHub
parent b36cad6929
commit bf81de0c88

View File

@@ -56,7 +56,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
frame = Image.fromarray(frame) frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame) frame = self.crop_and_resize(frame)
if first_frame is None: if first_frame is None:
first_frame = np.array(frame) first_frame = frame
frame = frame_process(frame) frame = frame_process(frame)
frames.append(frame) frames.append(frame)
reader.close() reader.close()
@@ -64,6 +64,9 @@ class TextVideoDataset(torch.utils.data.Dataset):
frames = torch.stack(frames, dim=0) frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W") frames = rearrange(frames, "T C H W -> C T H W")
first_frame = v2.functional.center_crop(first_frame, output_size=(self.height, self.width))
first_frame = np.array(first_frame)
if self.is_i2v: if self.is_i2v:
return frames, first_frame return frames, first_frame
else: else: