mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
wanx vae tile decode
This commit is contained in:
@@ -16,7 +16,7 @@ def save_video(tensor,
|
||||
u, nrow=nrow, normalize=normalize, value_range=value_range)
|
||||
for u in tensor.unbind(2)
|
||||
],
|
||||
dim=1).permute(1, 2, 3, 0)
|
||||
dim=1).permute(1, 2, 3, 0) #frame, h, w, 3
|
||||
tensor = (tensor * 255).type(torch.uint8).cpu()
|
||||
|
||||
# write video
|
||||
@@ -26,6 +26,8 @@ def save_video(tensor,
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
|
||||
torch.cuda.memory._record_memory_history()
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.float, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/WanX/vae.pth",
|
||||
@@ -34,9 +36,12 @@ model_manager.load_models([
|
||||
vae = model_manager.fetch_model('wanxvideo_vae')
|
||||
|
||||
latents = [torch.load('sample.pt')]
|
||||
videos = vae.decode(latents)
|
||||
back_encode = vae.encode(videos)
|
||||
save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
|
||||
videos = vae.decode(latents, device=latents[0].device, tiled=True)
|
||||
# back_encode = vae.encode(videos)
|
||||
|
||||
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
|
||||
|
||||
save_video(videos[0][None], save_file='example3.mp4', fps=16, nrow=8)
|
||||
print(latents)
|
||||
print(videos)
|
||||
print(back_encode)
|
||||
# print(back_encode)
|
||||
|
||||
Reference in New Issue
Block a user