mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
wanx tiled encode
This commit is contained in:
@@ -608,21 +608,6 @@ class WanXVideoVAE(nn.Module):
|
|||||||
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
||||||
self.upsampling_factor = 8
|
self.upsampling_factor = 8
|
||||||
|
|
||||||
def encode(self, videos):
|
|
||||||
"""
|
|
||||||
videos: A list of videos each with shape [C, T, H, W].
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
|
||||||
for u in videos
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def single_decode(self, hidden_state, device):
|
|
||||||
hidden_state = hidden_state.to(device)
|
|
||||||
x = self.model.decode(hidden_state.unsqueeze(0), self.scale)
|
|
||||||
return x.float().clamp_(-1, 1).squeeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||||
x = torch.ones((length,))
|
x = torch.ones((length,))
|
||||||
@@ -647,7 +632,6 @@ class WanXVideoVAE(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
||||||
hidden_states = hidden_states.unsqueeze(0)
|
|
||||||
_, _, T, H, W = hidden_states.shape
|
_, _, T, H, W = hidden_states.shape
|
||||||
size_h, size_w = tile_size
|
size_h, size_w = tile_size
|
||||||
stride_h, stride_w = tile_stride
|
stride_h, stride_w = tile_stride
|
||||||
@@ -664,7 +648,7 @@ class WanXVideoVAE(nn.Module):
|
|||||||
data_device = "cpu"
|
data_device = "cpu"
|
||||||
computation_device = device
|
computation_device = device
|
||||||
|
|
||||||
out_T = T*4-3
|
out_T = T * 4 - 3
|
||||||
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
||||||
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
||||||
|
|
||||||
@@ -695,18 +679,100 @@ class WanXVideoVAE(nn.Module):
|
|||||||
target_w: target_w + hidden_states_batch.shape[4],
|
target_w: target_w + hidden_states_batch.shape[4],
|
||||||
] += mask
|
] += mask
|
||||||
values = values / weight
|
values = values / weight
|
||||||
values = values.float().clamp_(-1, 1).squeeze(0)
|
values = values.float().clamp_(-1, 1)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def tiled_encode(self, video, device, tile_size, tile_stride):
|
||||||
|
_, _, T, H, W = video.shape
|
||||||
|
size_h, size_w = tile_size
|
||||||
|
stride_h, stride_w = tile_stride
|
||||||
|
|
||||||
|
# Split tasks
|
||||||
|
tasks = []
|
||||||
|
for h in range(0, H, stride_h):
|
||||||
|
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||||
|
for w in range(0, W, stride_w):
|
||||||
|
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||||
|
h_, w_ = h + size_h, w + size_w
|
||||||
|
tasks.append((h, h_, w, w_))
|
||||||
|
|
||||||
|
data_device = "cpu"
|
||||||
|
computation_device = device
|
||||||
|
|
||||||
|
out_T = (T + 3) // 4
|
||||||
|
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||||
|
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||||
|
|
||||||
|
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||||
|
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
||||||
|
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
|
||||||
|
|
||||||
|
mask = self.build_mask(
|
||||||
|
hidden_states_batch,
|
||||||
|
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
||||||
|
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
|
||||||
|
).to(dtype=video.dtype, device=data_device)
|
||||||
|
|
||||||
|
target_h = h // self.upsampling_factor
|
||||||
|
target_w = w // self.upsampling_factor
|
||||||
|
values[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
target_h:target_h + hidden_states_batch.shape[3],
|
||||||
|
target_w:target_w + hidden_states_batch.shape[4],
|
||||||
|
] += hidden_states_batch * mask
|
||||||
|
weight[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
target_h: target_h + hidden_states_batch.shape[3],
|
||||||
|
target_w: target_w + hidden_states_batch.shape[4],
|
||||||
|
] += mask
|
||||||
|
values = values / weight
|
||||||
|
values = values.float()
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def single_encode(self, video, device):
|
||||||
|
video = video.to(device)
|
||||||
|
x = self.model.encode(video, self.scale)
|
||||||
|
return x.float()
|
||||||
|
|
||||||
|
|
||||||
|
def single_decode(self, hidden_state, device):
|
||||||
|
hidden_state = hidden_state.to(device)
|
||||||
|
video = self.model.decode(hidden_state, self.scale)
|
||||||
|
return video.float().clamp_(-1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def encode(self, videos, device, tiled=False, tile_size=(272, 272), tile_stride=(144, 128)):
|
||||||
|
|
||||||
|
videos = [video.to("cpu") for video in videos]
|
||||||
|
hidden_states = []
|
||||||
|
for video in videos:
|
||||||
|
video = video.unsqueeze(0)
|
||||||
|
if tiled:
|
||||||
|
assert tile_size[0] % self.upsampling_factor == 0 and tile_size[1] % self.upsampling_factor == 0, f"tile_size must be devisible by {self.upsampling_factor}"
|
||||||
|
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
||||||
|
else:
|
||||||
|
hidden_state = self.single_encode(video, device)
|
||||||
|
hidden_state = hidden_state.squeeze(0)
|
||||||
|
hidden_states.append(hidden_state)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||||
videos = []
|
videos = []
|
||||||
for hidden_state in hidden_states:
|
for hidden_state in hidden_states:
|
||||||
|
hidden_state = hidden_state.unsqueeze(0)
|
||||||
if tiled:
|
if tiled:
|
||||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||||
else:
|
else:
|
||||||
video = self.single_decode(hidden_state, device)
|
video = self.single_decode(hidden_state, device)
|
||||||
|
video = video.squeeze(0)
|
||||||
videos.append(video)
|
videos.append(video)
|
||||||
return videos
|
return videos
|
||||||
|
|
||||||
|
|||||||
@@ -37,11 +37,10 @@ vae = model_manager.fetch_model('wanxvideo_vae')
|
|||||||
|
|
||||||
latents = [torch.load('sample.pt')]
|
latents = [torch.load('sample.pt')]
|
||||||
videos = vae.decode(latents, device=latents[0].device, tiled=True)
|
videos = vae.decode(latents, device=latents[0].device, tiled=True)
|
||||||
# back_encode = vae.encode(videos)
|
back_encode = vae.encode(videos, device=latents[0].device, tiled=True)
|
||||||
|
|
||||||
|
videos_back_encode = vae.decode(back_encode, device=latents[0].device, tiled=False)
|
||||||
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
|
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
|
||||||
|
|
||||||
save_video(videos[0][None], save_file='example3.mp4', fps=16, nrow=8)
|
save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
|
||||||
print(latents)
|
save_video(videos_back_encode[0][None], save_file='example_backencode.mp4', fps=16, nrow=1)
|
||||||
print(videos)
|
|
||||||
# print(back_encode)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user