From a8ce9fef331c24792d91a7997c65be4cc9b45cd7 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 18 Mar 2025 19:24:27 +0800 Subject: [PATCH] support redirected tensor path --- examples/wanvideo/train_wan_t2v.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index f672eae..5f7fdf1 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -126,7 +126,7 @@ class TextVideoDataset(torch.utils.data.Dataset): class LightningModelForDataProcess(pl.LightningModule): - def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None): super().__init__() model_path = [text_encoder_path, vae_path] if image_encoder_path is not None: @@ -136,6 +136,7 @@ class LightningModelForDataProcess(pl.LightningModule): self.pipe = WanVideoPipeline.from_model_manager(model_manager) self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + self.redirected_tensor_path = redirected_tensor_path def test_step(self, batch, batch_idx): data = batch[0] @@ -158,20 +159,33 @@ class LightningModelForDataProcess(pl.LightningModule): else: image_emb = {} data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} + if self.redirected_tensor_path is not None: + path = path.replace("/", "_").replace("\\", "_") + path = os.path.join(self.redirected_tensor_path, path) torch.save(data, path + ".tensors.pth") class TensorDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, steps_per_epoch): + def __init__(self, base_path, metadata_path, steps_per_epoch, redirected_tensor_path=None): metadata = pd.read_csv(metadata_path) self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] print(len(self.path), "videos in metadata.") - self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] + if redirected_tensor_path is None: + self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] + else: + cached_path = [] + for path in self.path: + path = path.replace("/", "_").replace("\\", "_") + path = os.path.join(redirected_tensor_path, path) + if os.path.exists(path + ".tensors.pth"): + cached_path.append(path + ".tensors.pth") + self.path = cached_path print(len(self.path), "tensors cached in metadata.") assert len(self.path) > 0 self.steps_per_epoch = steps_per_epoch + self.redirected_tensor_path = redirected_tensor_path def __getitem__(self, index): @@ -337,6 +351,12 @@ def parse_args(): default="./", help="Path to save the model.", ) + parser.add_argument( + "--redirected_tensor_path", + type=str, + default=None, + help="Path to save cached tensors.", + ) parser.add_argument( "--text_encoder_path", type=str, @@ -542,6 +562,7 @@ def data_process(args): tiled=args.tiled, tile_size=(args.tile_size_height, args.tile_size_width), tile_stride=(args.tile_stride_height, args.tile_stride_width), + redirected_tensor_path=args.redirected_tensor_path, ) trainer = pl.Trainer( accelerator="gpu", @@ -556,6 +577,7 @@ def train(args): args.dataset_path, os.path.join(args.dataset_path, "metadata.csv"), steps_per_epoch=args.steps_per_epoch, + redirected_tensor_path=args.redirected_tensor_path, ) dataloader = torch.utils.data.DataLoader( dataset,