Merge pull request #360 from modelscope/wan-train-dev

support wan image training
This commit is contained in:
Zhongjie Duan
2025-02-27 12:58:32 +08:00
committed by GitHub

View File

@@ -23,8 +23,8 @@ class TextVideoDataset(torch.utils.data.Dataset):
self.width = width self.width = width
self.frame_process = v2.Compose([ self.frame_process = v2.Compose([
v2.Resize(size=(height, width), antialias=True),
v2.CenterCrop(size=(height, width)), v2.CenterCrop(size=(height, width)),
v2.Resize(size=(height, width), antialias=True),
v2.ToTensor(), v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]) ])
@@ -68,17 +68,28 @@ class TextVideoDataset(torch.utils.data.Dataset):
return frames return frames
def load_text_video_raw_data(self, data_id): def is_image(self, file_path):
text = self.path[data_id] file_ext_name = file_path.split(".")[-1]
video = self.load_video(self.path[data_id]) if file_ext_name.lower() in ["jpg", "png", "webp"]:
data = {"text": text, "video": video} return True
return data return False
def load_image(self, file_path):
    """Load a still image and return it as a single-frame tensor.

    The file is decoded as RGB, passed through ``self.crop_and_resize`` and
    then ``self.frame_process``, and finally given a length-1 axis right
    after the channel axis (``C H W -> C 1 H W``).

    Args:
        file_path: Path of the image file to open.

    Returns:
        The processed frame with layout ``C 1 H W``.
    """
    # Image.open is lazy and keeps the underlying file open. convert()
    # forces the decode and returns a copy, so the handle can be released
    # deterministically as soon as the RGB image exists.
    with Image.open(file_path) as image:
        frame = image.convert("RGB")
    frame = self.crop_and_resize(frame)
    frame = self.frame_process(frame)
    # Add a singleton axis after the channels so a still image is shaped
    # like a one-frame clip.
    return rearrange(frame, "C H W -> C 1 H W")
def __getitem__(self, data_id): def __getitem__(self, data_id):
text = self.text[data_id] text = self.text[data_id]
path = self.path[data_id] path = self.path[data_id]
video = self.load_video(path) if self.is_image(path):
video = self.load_image(path)
else:
video = self.load_video(path)
data = {"text": text, "video": video, "path": path} data = {"text": text, "video": video, "path": path}
return data return data