# Files
# web-realesrgan/esrgan-exporter.py
# 2024-10-17 21:21:57 +08:00
#
# 155 lines
# 5.1 KiB
# Python

import os
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import torch
# Maps each Real-ESRGAN checkpoint name to the short name used as the file
# stem of the exported ONNX / SavedModel / TF.js artifacts. Only the models
# listed here are exported by the driver loop, in this order.
model2name = {
    "realesr-animevideov3": "anime_fast",
    "RealESRGAN_x4plus_anime_6B": "anime_plus",
    "realesr-general-x4v3": "general_fast",
    "RealESRGAN_x4plus": "general_plus",
}
def _model_spec(model_name):
    """Look up the architecture settings and checkpoint URLs for a model.

    Args:
        model_name: Real-ESRGAN checkpoint name, e.g. ``"RealESRGAN_x4plus"``.

    Returns:
        A tuple ``(arch, arch_kwargs, netscale, file_url)``. ``arch`` is
        ``"rrdb"`` (build with ``RRDBNet``) or ``"srvgg"`` (build with
        ``SRVGGNetCompact``); ``arch_kwargs`` are the constructor keyword
        arguments; ``netscale`` is the scale handed to ``RealESRGANer``;
        ``file_url`` is the list of checkpoint URLs to fetch, in order.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """

    def rrdb(num_block, scale):
        return "rrdb", {
            "num_in_ch": 3,
            "num_out_ch": 3,
            "num_feat": 64,
            "num_block": num_block,
            "num_grow_ch": 32,
            "scale": scale,
        }

    def srvgg(num_conv):
        return "srvgg", {
            "num_in_ch": 3,
            "num_out_ch": 3,
            "num_feat": 64,
            "num_conv": num_conv,
            "upscale": 4,
            "act_type": "prelu",
        }

    specs = {
        # x4 RRDBNet model
        "RealESRGAN_x4plus": (
            *rrdb(num_block=23, scale=4),
            4,
            [
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
            ],
        ),
        # x4 RRDBNet model
        "RealESRNet_x4plus": (
            *rrdb(num_block=23, scale=4),
            4,
            [
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
            ],
        ),
        # x4 RRDBNet model with 6 blocks
        "RealESRGAN_x4plus_anime_6B": (
            *rrdb(num_block=6, scale=4),
            4,
            [
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
            ],
        ),
        # x2 RRDBNet model
        "RealESRGAN_x2plus": (
            *rrdb(num_block=23, scale=2),
            2,
            [
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
            ],
        ),
        # x4 VGG-style model (XS size)
        "realesr-animevideov3": (
            *srvgg(num_conv=16),
            4,
            [
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
            ],
        ),
        # x4 VGG-style model (S size). The main checkpoint is listed last so
        # that it is the path left in ``model_path`` after the download loop;
        # the wdn checkpoint path is derived from it afterwards.
        "realesr-general-x4v3": (
            *srvgg(num_conv=32),
            4,
            [
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            ],
        ),
    }
    try:
        return specs[model_name]
    except KeyError:
        supported = ", ".join(sorted(specs))
        raise ValueError(
            f"unknown model name {model_name!r}; supported models: {supported}"
        ) from None


def _run_step(command, *, required=True):
    """Run ``command`` through the shell and report a non-zero status.

    Args:
        command: Shell command line, passed unchanged to ``os.system``.
        required: When true a failing command raises; otherwise the failure
            is only printed and execution continues.

    Raises:
        RuntimeError: If ``required`` is true and the command's status is
            non-zero.
    """
    status = os.system(command)
    if status == 0:
        return
    message = f"command failed with status {status}: {command}"
    if required:
        raise RuntimeError(message)
    print(f"warning: {message}")


def export(model_name, tile_size):
    """Export one Real-ESRGAN model to ONNX, TF SavedModel and TF.js.

    The checkpoint is loaded through ``RealESRGANer`` (downloading it next to
    this script when ``weights/<model_name>.pth`` is not found relative to the
    working directory), traced with a random ``(1, 3, tile_size, tile_size)``
    input, written to ``<name>.onnx``, simplified in place with ``onnxsim``,
    converted with ``onnx2tf`` into ``<name>-<tile_size>-tf`` and finally
    converted into ``tfjs/<name>-<tile_size>`` by the command named in the
    ``tfjs_converter`` environment variable. ``<name>`` is the short name from
    ``model2name``, or ``model_name`` itself when no short name is registered.

    Args:
        model_name: Real-ESRGAN checkpoint name (see ``_model_spec``).
        tile_size: Height and width of the fixed-size input used for export.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
        KeyError: If the ``tfjs_converter`` environment variable is unset or
            empty.
        RuntimeError: If the ``onnx2tf`` or TF.js conversion command fails.
    """
    arch, arch_kwargs, netscale, file_url = _model_spec(model_name)

    # Check the converter command up front so a missing setting is reported
    # before the download and the conversion steps are run.
    # export tfjs_converter = xxx
    tfjs_converter = os.environ.get("tfjs_converter")
    if not tfjs_converter:
        raise KeyError(
            "the 'tfjs_converter' environment variable must be set to the "
            "TF.js converter command"
        )

    # Models without a registered short name are exported under their own name.
    name = model2name.get(model_name, model_name)

    model_cls = RRDBNet if arch == "rrdb" else SRVGGNetCompact
    model = model_cls(**arch_kwargs)

    model_path = os.path.join("weights", model_name + ".pth")
    if not os.path.isfile(model_path):
        root_dir = os.path.dirname(os.path.abspath(__file__))
        weights_dir = os.path.join(root_dir, "weights")
        for url in file_url:
            # model_path will be updated
            model_path = load_file_from_url(
                url=url,
                model_dir=weights_dir,
                progress=True,
                file_name=None,
            )

    # use dni to control the denoise strength
    dni_weight = None
    denoise_strength = 0.5
    if model_name == "realesr-general-x4v3" and denoise_strength != 1:
        # Rewrite only the file name so a directory that happens to contain
        # the model name is left untouched.
        checkpoint_dir, checkpoint_file = os.path.split(model_path)
        wdn_model_path = os.path.join(
            checkpoint_dir,
            checkpoint_file.replace(
                "realesr-general-x4v3", "realesr-general-wdn-x4v3"
            ),
        )
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=0,
        pre_pad=0,
        half=False,
        gpu_id=None,
    )

    torch.onnx.export(
        upsampler.model,
        torch.randn(1, 3, tile_size, tile_size),
        f"{name}.onnx",
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
    )

    # Simplification is treated as optional: a failure is reported, but the
    # export continues with whatever ONNX file is on disk, as the script did
    # before return codes were checked.
    _run_step(f"onnxsim {name}.onnx {name}.onnx", required=False)
    _run_step(
        f"onnx2tf -i {name}.onnx -osd -ois input:1,3,{tile_size},{tile_size} -o {name}-{tile_size}-tf"
    )

    os.makedirs("tfjs", exist_ok=True)
    _run_step(
        f"{tfjs_converter} --control_flow_v2=True --quantize_float16 --input_format tf_saved_model --output_format tfjs_graph_model {name}-{tile_size}-tf tfjs/{name}-{tile_size}"
    )
if __name__ == "__main__":
    # Export every model listed in model2name at each tile size below. The
    # guard keeps the batch from starting when the file is imported rather
    # than run as a script.
    for model_name in model2name:
        for tile_size in [32, 48, 64, 96, 128, 192, 256]:
            export(model_name, tile_size)