mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
v1.2
This commit is contained in:
4
diffsynth/extensions/FastBlend/runners/__init__.py
Normal file
4
diffsynth/extensions/FastBlend/runners/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .accurate import AccurateModeRunner
|
||||
from .fast import FastModeRunner
|
||||
from .balanced import BalancedModeRunner
|
||||
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
||||
35
diffsynth/extensions/FastBlend/runners/accurate.py
Normal file
35
diffsynth/extensions/FastBlend/runners/accurate.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from ..patch_match import PyramidPatchMatcher
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class AccurateModeRunner:
    """Blend each frame with every frame of its window, one target at a time."""

    def __init__(self):
        pass

    def run(
        self,
        frames_guide,
        frames_style,
        batch_size,
        window_size,
        ebsynth_config,
        desc="Accurate Mode",
        save_path=None,
    ):
        """Render every target frame as the mean of its remapped window.

        Args:
            frames_guide: Indexable sequence of guide frames; entries are
                stacked with ``np.stack``.
            frames_style: Indexable sequence of style frames. The first
                frame's ``shape[0]`` / ``shape[1]`` are passed to the matcher
                as image height / width.
            batch_size: Maximum number of source frames sent to the matcher
                in one ``estimate_nnf`` call.
            window_size: Number of frames taken on each side of the target
                (clamped to the sequence bounds); the target itself is part
                of the window.
            ebsynth_config: Extra keyword arguments forwarded to
                ``PyramidPatchMatcher``.
            desc: Label of the progress bar.
            save_path: Directory receiving ``%05d.png`` files. When ``None``
                nothing is written (the frames are still computed).

        Returns:
            None.
        """
        first_style = frames_style[0]
        patch_match_engine = PyramidPatchMatcher(
            image_height=first_style.shape[0],
            image_width=first_style.shape[1],
            channel=3,
            use_mean_target_style=True,
            **ebsynth_config
        )
        frame_count = len(frames_style)
        for target in tqdm(range(frame_count), desc=desc):
            window_start = max(target - window_size, 0)
            window_stop = min(target + window_size + 1, frame_count)
            remapped_batches = []
            # Walk the window in chunks of at most ``batch_size`` sources.
            for chunk_start in range(window_start, window_stop, batch_size):
                sources = range(chunk_start, min(chunk_start + batch_size, window_stop))
                source_guide = np.stack([frames_guide[source] for source in sources])
                # The same target guide is paired with every source of the chunk.
                target_guide = np.stack([frames_guide[target]] * len(sources))
                source_style = np.stack([frames_style[source] for source in sources])
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                remapped_batches.append(target_style)
            # Average over all sources of the window (axis 0 after concatenation).
            blended = np.concatenate(remapped_batches, axis=0).mean(axis=0)
            blended = blended.clip(0, 255).astype("uint8")
            if save_path is not None:
                Image.fromarray(blended).save(os.path.join(save_path, "%05d.png" % target))
|
||||
46
diffsynth/extensions/FastBlend/runners/balanced.py
Normal file
46
diffsynth/extensions/FastBlend/runners/balanced.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from ..patch_match import PyramidPatchMatcher
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class BalancedModeRunner:
    """Blend each frame with its window using batched (source, target) pairs."""

    def __init__(self):
        pass

    def run(
        self,
        frames_guide,
        frames_style,
        batch_size,
        window_size,
        ebsynth_config,
        desc="Balanced Mode",
        save_path=None,
    ):
        """Render every target frame as a running mean over its window.

        Every (source, target) pair with ``source != target`` inside the
        window is remapped by the matcher; results are folded into a running
        average that starts from the target's own style frame. A target is
        written as soon as its last contribution has been blended in.

        Args:
            frames_guide: Indexable sequence of guide frames; entries are
                stacked with ``np.stack``.
            frames_style: Indexable sequence of style frames. The first
                frame's ``shape[0]`` / ``shape[1]`` are passed to the matcher
                as image height / width.
            batch_size: Maximum number of pairs per ``estimate_nnf`` call.
            window_size: Number of frames taken on each side of the target
                (clamped to the sequence bounds).
            ebsynth_config: Extra keyword arguments forwarded to
                ``PyramidPatchMatcher``.
            desc: Label of the progress bar.
            save_path: Directory receiving ``%05d.png`` files. When ``None``
                nothing is written.

        Returns:
            None.
        """
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        n = len(frames_style)

        def save_frame(frame, target):
            # Clip and quantise only at write time; accumulation stays in float.
            if save_path is None:
                return
            frame = frame.clip(0, 255).astype("uint8")
            Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))

        # tasks: every (source, target) pair inside the window, target-major order
        tasks = []
        for target in range(n):
            for source in range(target - window_size, target + window_size + 1):
                if 0 <= source < n and source != target:
                    tasks.append((source, target))

        # A target without any neighbour (e.g. window_size == 0 or a single
        # frame) yields no task, so the loop below would never write it.
        # Emit its style frame unchanged instead of silently dropping it.
        targets_with_sources = {target for _, target in tasks}
        for target in range(n):
            if target not in targets_with_sources:
                save_frame(frames_style[target], target)

        # run
        frames = [(None, 1) for _ in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: batch_id + batch_size]
            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for (source, target), result in zip(tasks_batch, target_style):
                frame, weight = frames[target]
                if frame is None:
                    # First contribution: start from the target's own style frame.
                    frame = frames_style[target]
                # Incremental mean: ``weight`` frames are already folded in.
                frame = frame * (weight / (weight + 1)) + result / (weight + 1)
                weight += 1
                window_length = min(n, target + window_size + 1) - max(0, target - window_size)
                if weight == window_length:
                    # All sources of the window are in: write the blend that
                    # INCLUDES this last result, then release the buffer.
                    save_frame(frame, target)
                    frames[target] = (None, 1)
                else:
                    frames[target] = (frame, weight)
|
||||
141
diffsynth/extensions/FastBlend/runners/fast.py
Normal file
141
diffsynth/extensions/FastBlend/runners/fast.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from ..patch_match import PyramidPatchMatcher
|
||||
import functools, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class TableManager:
    """Builds and queries the per-frame, per-level tables used by fast mode.

    A table is a list with one row per frame; row ``i`` is a list of
    ``(frame, weight)`` tuples indexed by level.
    """

    def __init__(self):
        pass

    def task_list(self, n):
        """Return the remapping tasks for ``n`` frames, ordered by level.

        For every source index, the zero bits of the index are set one after
        another (lowest bit first); each intermediate value below ``n`` is a
        target. A task is a dict with keys ``"source"``, ``"target"`` and
        ``"level"`` (1-based position of the bit that was set last).
        """
        level_count = 1
        while (1 << level_count) <= n:
            level_count += 1
        pending = []
        for source in range(n):
            target = source
            for bit in range(level_count):
                mask = 1 << bit
                if source & mask:
                    continue
                target |= mask
                if target >= n:
                    break
                pending.append({"source": source, "target": target, "level": bit + 1})
        # ``sorted`` is stable, so tasks of equal level keep generation order.
        return sorted(pending, key=lambda task: task["level"])

    def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
        """Run all tasks of ``task_list`` and collect the results per target.

        Row ``i`` starts as ``[(frames_style[i], 1)]``. Each task result is
        stored at ``row[level]`` of its target: appended when that level does
        not exist yet, otherwise folded into the running mean kept there
        (the weight counts the folded results).

        Returns:
            The remapping table (list of rows of ``(frame, weight)``).
        """
        frame_count = len(frames_guide)
        schedule = self.task_list(frame_count)
        table = [[(frames_style[index], 1)] for index in range(frame_count)]
        for start in tqdm(range(0, len(schedule), batch_size), desc=desc):
            chunk = schedule[start: start + batch_size]
            source_guide = np.stack([frames_guide[task["source"]] for task in chunk])
            target_guide = np.stack([frames_guide[task["target"]] for task in chunk])
            source_style = np.stack([frames_style[task["source"]] for task in chunk])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, result in zip(chunk, target_style):
                row = table[task["target"]]
                level = task["level"]
                if len(row) == level:
                    row.append((result, 1))
                    continue
                frame, weight = row[level]
                row[level] = (
                    frame * (weight / (weight + 1)) + result / (weight + 1),
                    weight + 1
                )
        return table

    def remapping_table_to_blending_table(self, table):
        """Turn each row into cumulative blends, in place, and return ``table``.

        For every level ``j >= 1`` the entry becomes the plain average of the
        (already updated) entry ``j - 1`` and the entry ``j``; the weights
        are added.
        """
        for row in table:
            for level in range(1, len(row)):
                lower_frame, lower_weight = row[level - 1]
                upper_frame, upper_weight = row[level]
                row[level] = ((lower_frame + upper_frame) / 2, lower_weight + upper_weight)
        return table

    def tree_query(self, leftbound, rightbound):
        """Split ``[leftbound, rightbound]`` into ``(index, level)`` nodes.

        Starting at ``rightbound`` and moving left, each node takes the
        highest level such that the low ``level`` bits of ``index`` are all
        set and the ``2 ** level`` indices ending at ``index`` stay at or
        above ``leftbound``. Returns an empty list when
        ``rightbound < leftbound``.
        """
        nodes = []
        cursor = rightbound
        while cursor >= leftbound:
            level = 0
            while True:
                if not cursor & (1 << level):
                    break
                if cursor - (1 << (level + 1)) + 1 < leftbound:
                    break
                level += 1
            nodes.append((cursor, level))
            cursor -= 1 << level
        return nodes

    def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
        """Blend, for each target, the nodes covering ``[target - window_size, target]``.

        The node located at the target itself seeds the result with the
        corresponding blending-table entry; every other node is remapped
        onto the target by the matcher and merged with a weighted average.

        Returns:
            A list with one ``(frame, weight)`` tuple per target.
        """
        frame_count = len(blending_table)
        pending, frames_result = [], []
        for target in range(frame_count):
            for source, level in self.tree_query(max(target - window_size, 0), target):
                if source == target:
                    # First node of every query: no remapping needed.
                    frames_result.append(blending_table[target][level])
                else:
                    pending.append({"source": source, "target": target, "level": level})
        for start in tqdm(range(0, len(pending), batch_size), desc=desc):
            chunk = pending[start: start + batch_size]
            source_guide = np.stack([frames_guide[task["source"]] for task in chunk])
            target_guide = np.stack([frames_guide[task["target"]] for task in chunk])
            source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in chunk])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, frame_2 in zip(chunk, target_style):
                source, target, level = task["source"], task["target"], task["level"]
                frame_1, weight_1 = frames_result[target]
                weight_2 = blending_table[source][level][1]
                weight = weight_1 + weight_2
                frames_result[target] = (
                    frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight),
                    weight
                )
        return frames_result
|
||||
|
||||
|
||||
class FastModeRunner:
    """Blend frames with a forward and a backward sweep over table lookups."""

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
        """Render all frames by merging a left-window and a right-window blend.

        Args:
            frames_guide: Object exposing ``raw_data()``, which yields the
                guide frames as a sliceable sequence.
            frames_style: Object exposing ``raw_data()``, which yields the
                style frames as a sliceable sequence. The first frame's
                ``shape[0]`` / ``shape[1]`` are passed to the matcher as
                image height / width.
            batch_size: Maximum number of tasks per ``estimate_nnf`` call.
            window_size: Number of frames considered on each side of a target.
            ebsynth_config: Extra keyword arguments forwarded to
                ``PyramidPatchMatcher``.
            save_path: Directory receiving ``%05d.png`` files. When ``None``
                nothing is written.

        Returns:
            None.
        """
        frames_guide = frames_guide.raw_data()
        frames_style = frames_style.raw_data()
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )

        def sweep(guide, style, build_desc, sum_desc):
            # One directional pass: remapping table -> blending table -> window sums.
            table = table_manager.build_remapping_table(guide, style, patch_match_engine, batch_size, desc=build_desc)
            table = table_manager.remapping_table_to_blending_table(table)
            return table_manager.process_window_sum(guide, table, patch_match_engine, window_size, batch_size, desc=sum_desc)

        # left part
        table_l = sweep(frames_guide, frames_style, "Fast Mode Step 1/4", "Fast Mode Step 2/4")
        # right part: same pass on the reversed sequence, flipped back afterwards
        table_r = sweep(frames_guide[::-1], frames_style[::-1], "Fast Mode Step 3/4", "Fast Mode Step 4/4")[::-1]

        # merge
        # Both sweeps are seeded with the target's own table entry, so the
        # style frame enters with weight -1 to cancel one of the two copies.
        weight_m = -1
        merged = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            total = weight_l + weight_m + weight_r
            mixed = frame_l * (weight_l / total) + frame_m * (weight_m / total) + frame_r * (weight_r / total)
            merged.append(mixed)
        merged = [mixed.clip(0, 255).astype("uint8") for mixed in merged]
        if save_path is not None:
            for target, image in enumerate(merged):
                Image.fromarray(image).save(os.path.join(save_path, "%05d.png" % target))
|
||||
121
diffsynth/extensions/FastBlend/runners/interpolation.py
Normal file
121
diffsynth/extensions/FastBlend/runners/interpolation.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from ..patch_match import PyramidPatchMatcher
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class InterpolationModeRunner:
    """Render in-between frames from the styled keyframes surrounding them."""

    def __init__(self):
        pass

    def get_index_dict(self, index_style):
        """Map each keyframe index to its position in ``index_style``.

        When an index occurs more than once the last position wins.
        """
        return {index: position for position, index in enumerate(index_style)}

    def get_weight(self, l, m, r):
        """Return the blend weights ``(weight_l, weight_r)`` for frame ``m``.

        Each side is weighted by the distance of ``m`` to the *other*
        keyframe, normalised to sum to 1, so the nearer keyframe dominates.
        When both distances are zero the weights are ``(0.5, 0.5)``.
        """
        weight_l, weight_r = abs(m - r), abs(m - l)
        total = weight_l + weight_r
        if total == 0:
            return 0.5, 0.5
        return weight_l / total, weight_r / total

    def get_task_group(self, index_style, n):
        """Group the frames ``0 .. n-1`` by the keyframes enclosing them.

        Each task is a tuple ``(l, m, r)``: target frame ``m`` rendered from
        the keyframes ``l`` and ``r``. Frames before the first keyframe and
        from the last keyframe onwards use that keyframe on both sides.

        Returns:
            A list of task lists, in frame order. A group can be empty
            (for example between two identical keyframe indices).
        """
        index_style = sorted(index_style)
        first, last = index_style[0], index_style[-1]
        task_group = []
        # first frame: everything before the first keyframe
        if first > 0:
            task_group.append([(first, m, first) for m in range(first)])
        # middle frames: one group per pair of consecutive keyframes
        for l, r in zip(index_style[:-1], index_style[1:]):
            task_group.append([(l, m, r) for m in range(l, r)])
        # last frame: the last keyframe and everything after it
        task_group.append([(last, m, last) for m in range(last, n)])
        return task_group

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        """Render every frame from its left and right keyframes.

        Args:
            frames_guide: Indexable sequence of guide frames, one per output
                frame.
            frames_style: Indexable sequence of styled keyframes;
                ``frames_style[k]`` belongs to frame ``index_style[k]``.
            index_style: Frame indices of the styled keyframes.
            batch_size: Maximum number of target frames per batch (each
                target contributes two matcher inputs, one per keyframe).
            ebsynth_config: Extra keyword arguments forwarded to
                ``PyramidPatchMatcher``.
            save_path: Directory receiving ``%05d.png`` files. When ``None``
                nothing is written.

        Returns:
            None.
        """
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=False,
            use_pairwise_patch_error=True,
            **ebsynth_config
        )
        # task
        index_dict = self.get_index_dict(index_style)
        task_group = self.get_task_group(index_style, len(frames_guide))
        # run
        for tasks in task_group:
            if not tasks:
                # Nothing to render; min()/max() below would raise ValueError
                # on an empty group.
                continue
            targets = [m for _, m, _ in tasks]
            index_start, index_end = min(targets), max(targets)
            for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
                tasks_batch = tasks[batch_id: batch_id + batch_size]
                source_guide, target_guide, source_style = [], [], []
                for l, m, r in tasks_batch:
                    # Interleave the two directions: even slots l -> m, odd slots r -> m.
                    for key in (l, r):
                        source_guide.append(frames_guide[key])
                        target_guide.append(frames_guide[m])
                        source_style.append(frames_style[index_dict[key]])
                source_guide = np.stack(source_guide)
                target_guide = np.stack(target_guide)
                source_style = np.stack(source_style)
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                if save_path is not None:
                    # De-interleave the results back into (left, right) pairs.
                    for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
                        weight_l, weight_r = self.get_weight(l, m, r)
                        frame = frame_l * weight_l + frame_r * weight_r
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
|
||||
|
||||
|
||||
class InterpolationModeSingleFrameRunner:
    """Render every frame from one styled keyframe, in overlapping batches."""

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        """Propagate the single styled keyframe to all frames.

        Args:
            frames_guide: Indexable sequence of guide frames, one per output
                frame.
            frames_style: Sequence whose first element is the styled keyframe.
            index_style: Sequence whose first element is the frame index of
                that keyframe.
            batch_size: Number of target frames per ``estimate_nnf`` call.
                Must exceed ``2 * ebsynth_config["tracking_window_size"]``.
            ebsynth_config: Keyword arguments forwarded to
                ``PyramidPatchMatcher``; must contain
                ``"tracking_window_size"``.
            save_path: Directory receiving ``%05d.png`` files. When ``None``
                nothing is written.

        Raises:
            ValueError: If ``batch_size`` is not larger than twice the
                tracking window size.

        Returns:
            None.
        """
        # check input
        tracking_window_size = ebsynth_config["tracking_window_size"]
        if tracking_window_size * 2 >= batch_size:
            raise ValueError("batch_size should be larger than track_window_size * 2")
        frame_style = frames_style[0]
        frame_guide = frames_guide[index_style[0]]
        patch_match_engine = PyramidPatchMatcher(
            image_height=frame_style.shape[0],
            image_width=frame_style.shape[1],
            channel=3,
            **ebsynth_config
        )
        # run
        # ``frame_id`` is the next frame to emit; consecutive batches overlap
        # by ``2 * tracking_window_size`` frames.
        frame_id, n = 0, len(frames_guide)
        stride = batch_size - tracking_window_size * 2
        for start in tqdm(range(0, n, stride), desc=f"Rendering frames 0...{n}"):
            if frame_id >= n:
                # Every frame is already emitted; a further batch would only
                # repeat the last window without producing anything.
                break
            if start + batch_size > n:
                # Tail: align the batch with the end of the sequence.
                l, r = max(n - batch_size, 0), n
            else:
                l, r = start, start + batch_size
            source_guide = np.stack([frame_guide] * (r - l))
            target_guide = np.stack([frames_guide[index] for index in range(l, r)])
            source_style = np.stack([frame_style] * (r - l))
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for index, frame in zip(range(l, r), target_style):
                if index == frame_id:
                    # ``save_path`` defaults to None: skip writing instead of
                    # failing inside os.path.join, but keep advancing.
                    if save_path is not None:
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
                    frame_id += 1
                if r < n and r - frame_id <= tracking_window_size:
                    # Leave the last ``tracking_window_size`` frames of a
                    # non-final batch to the next batch.
                    break
|
||||
Reference in New Issue
Block a user