mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
support redirected tensor path
This commit is contained in:
@@ -126,7 +126,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class LightningModelForDataProcess(pl.LightningModule):
|
class LightningModelForDataProcess(pl.LightningModule):
|
||||||
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
model_path = [text_encoder_path, vae_path]
|
model_path = [text_encoder_path, vae_path]
|
||||||
if image_encoder_path is not None:
|
if image_encoder_path is not None:
|
||||||
@@ -136,6 +136,7 @@ class LightningModelForDataProcess(pl.LightningModule):
|
|||||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
self.redirected_tensor_path = redirected_tensor_path
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
data = batch[0]
|
data = batch[0]
|
||||||
@@ -158,20 +159,33 @@ class LightningModelForDataProcess(pl.LightningModule):
|
|||||||
else:
|
else:
|
||||||
image_emb = {}
|
image_emb = {}
|
||||||
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
||||||
|
if self.redirected_tensor_path is not None:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(self.redirected_tensor_path, path)
|
||||||
torch.save(data, path + ".tensors.pth")
|
torch.save(data, path + ".tensors.pth")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TensorDataset(torch.utils.data.Dataset):
|
class TensorDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, base_path, metadata_path, steps_per_epoch):
|
def __init__(self, base_path, metadata_path, steps_per_epoch, redirected_tensor_path=None):
|
||||||
metadata = pd.read_csv(metadata_path)
|
metadata = pd.read_csv(metadata_path)
|
||||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
print(len(self.path), "videos in metadata.")
|
print(len(self.path), "videos in metadata.")
|
||||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
if redirected_tensor_path is None:
|
||||||
|
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||||
|
else:
|
||||||
|
cached_path = []
|
||||||
|
for path in self.path:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(redirected_tensor_path, path)
|
||||||
|
if os.path.exists(path + ".tensors.pth"):
|
||||||
|
cached_path.append(path + ".tensors.pth")
|
||||||
|
self.path = cached_path
|
||||||
print(len(self.path), "tensors cached in metadata.")
|
print(len(self.path), "tensors cached in metadata.")
|
||||||
assert len(self.path) > 0
|
assert len(self.path) > 0
|
||||||
|
|
||||||
self.steps_per_epoch = steps_per_epoch
|
self.steps_per_epoch = steps_per_epoch
|
||||||
|
self.redirected_tensor_path = redirected_tensor_path
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
@@ -337,6 +351,12 @@ def parse_args():
|
|||||||
default="./",
|
default="./",
|
||||||
help="Path to save the model.",
|
help="Path to save the model.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--redirected_tensor_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to save cached tensors.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--text_encoder_path",
|
"--text_encoder_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -542,6 +562,7 @@ def data_process(args):
|
|||||||
tiled=args.tiled,
|
tiled=args.tiled,
|
||||||
tile_size=(args.tile_size_height, args.tile_size_width),
|
tile_size=(args.tile_size_height, args.tile_size_width),
|
||||||
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
||||||
|
redirected_tensor_path=args.redirected_tensor_path,
|
||||||
)
|
)
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
accelerator="gpu",
|
accelerator="gpu",
|
||||||
@@ -556,6 +577,7 @@ def train(args):
|
|||||||
args.dataset_path,
|
args.dataset_path,
|
||||||
os.path.join(args.dataset_path, "metadata.csv"),
|
os.path.join(args.dataset_path, "metadata.csv"),
|
||||||
steps_per_epoch=args.steps_per_epoch,
|
steps_per_epoch=args.steps_per_epoch,
|
||||||
|
redirected_tensor_path=args.redirected_tensor_path,
|
||||||
)
|
)
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
|||||||
Reference in New Issue
Block a user