This commit is contained in:
Artiprocher
2024-05-05 22:48:38 +08:00
parent cc37860438
commit 0965477750
15 changed files with 2991 additions and 79 deletions

View File

@@ -167,7 +167,7 @@ class RIFEInterpolater:
@torch.no_grad()
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1):
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
# Preprocess
processed_images = self.process_images(images)
@@ -177,7 +177,7 @@ class RIFEInterpolater:
# Interpolate
output_tensor = []
for batch_id in range(0, input_tensor.shape[0], batch_size):
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)