wanx tiled encode

This commit is contained in:
mi804
2025-02-21 12:58:45 +08:00
parent 02a4c8df9f
commit 2cefc20ed6
2 changed files with 88 additions and 23 deletions

View File

@@ -608,21 +608,6 @@ class WanXVideoVAE(nn.Module):
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
self.upsampling_factor = 8
def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
x = self.model.decode(hidden_state.unsqueeze(0), self.scale)
return x.float().clamp_(-1, 1).squeeze(0)
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
@@ -647,7 +632,6 @@ class WanXVideoVAE(nn.Module):
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
hidden_states = hidden_states.unsqueeze(0)
_, _, T, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
@@ -664,7 +648,7 @@ class WanXVideoVAE(nn.Module):
data_device = "cpu"
computation_device = device
out_T = T*4-3
out_T = T * 4 - 3
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
@@ -695,18 +679,100 @@ class WanXVideoVAE(nn.Module):
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float().clamp_(-1, 1).squeeze(0)
values = values.float().clamp_(-1, 1)
return values
def tiled_encode(self, video, device, tile_size, tile_stride):
_, _, T, H, W = video.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = (T + 3) // 4
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
).to(dtype=video.dtype, device=data_device)
target_h = h // self.upsampling_factor
target_w = w // self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float()
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x.float()
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.float().clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(272, 272), tile_stride=(144, 128)):
videos = [video.to("cpu") for video in videos]
hidden_states = []
for video in videos:
video = video.unsqueeze(0)
if tiled:
assert tile_size[0] % self.upsampling_factor == 0 and tile_size[1] % self.upsampling_factor == 0, f"tile_size must be devisible by {self.upsampling_factor}"
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
return hidden_states
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
return videos

View File

@@ -37,11 +37,10 @@ vae = model_manager.fetch_model('wanxvideo_vae')
latents = [torch.load('sample.pt')]
videos = vae.decode(latents, device=latents[0].device, tiled=True)
# back_encode = vae.encode(videos)
back_encode = vae.encode(videos, device=latents[0].device, tiled=True)
videos_back_encode = vae.decode(back_encode, device=latents[0].device, tiled=False)
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
save_video(videos[0][None], save_file='example3.mp4', fps=16, nrow=8)
print(latents)
print(videos)
# print(back_encode)
save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
save_video(videos_back_encode[0][None], save_file='example_backencode.mp4', fps=16, nrow=1)