optimize performance

This commit is contained in:
xororz
2024-04-15 04:15:21 -04:00
parent c2c5cb5aaf
commit e8bc3dc016
6 changed files with 101 additions and 189 deletions

View File

@@ -208,6 +208,7 @@ export default {
worker: new Worker(new URL("./worker.js", import.meta.url), {
type: "module",
}),
wasmModule: null,
};
},
watch: {
@@ -271,19 +272,19 @@ export default {
this.drawLine = true;
let wasmModule = await Module();
this.wasmModule = wasmModule;
const imgCanvas = this.$refs.imgCanvas;
imgCanvas.width = this.img.width;
imgCanvas.height = this.img.height;
const imgCtx = imgCanvas.getContext("2d");
imgCtx.drawImage(this.img, 0, 0);
this.input = new Img(this.img.width, this.img.height);
let data = imgCtx.getImageData(
0,
0,
this.img.width,
this.img.height
).data;
this.input.data = new Uint8Array(data);
this.input = new Img(this.img.width, this.img.height, data);
const numPixels = this.input.width * this.input.height;
const bytesPerImage = numPixels * 4;
let sourcePtr = wasmModule._malloc(bytesPerImage);
@@ -292,7 +293,6 @@ export default {
this.hasAlpha = wasmModule._check_alpha(sourcePtr, numPixels);
if (this.hasAlpha) {
this.inputAlpha = new Img(this.img.width, this.img.height);
this.inputAlpha.data = new Uint8Array(bytesPerImage);
wasmModule._copy_alpha_to_rgb(sourcePtr, targetPtr, numPixels);
this.inputAlpha.data.set(
wasmModule.HEAPU8.subarray(targetPtr, targetPtr + bytesPerImage)
@@ -610,28 +610,46 @@ export default {
this.progress = progress;
if (done) {
if (!this.hasAlpha || (this.hasAlpha && this.inputAlpha)) {
this.output = output;
let factor = this.modelzoo[this.model].factor;
this.output = new Img(
factor * this.input.width,
factor * this.input.height,
new Uint8ClampedArray(output)
);
}
this.info = "Processing Image...";
if (this.inputAlpha) {
worker.postMessage({
input: this.inputAlpha.data,
fixed: this.modelzoo[this.model].fixed,
factor: this.modelzoo[this.model].factor,
width: this.inputAlpha.width,
height: this.inputAlpha.height,
model: this.model,
backend: this.backend,
hasAlpha: true,
});
worker.postMessage(
{
input: this.inputAlpha.data.buffer,
fixed: this.modelzoo[this.model].fixed,
factor: this.modelzoo[this.model].factor,
width: this.inputAlpha.width,
height: this.inputAlpha.height,
model: this.model,
backend: this.backend,
hasAlpha: true,
},
[this.inputAlpha.data.buffer]
);
this.inputAlpha = null;
return;
}
if (this.hasAlpha) {
for (let i = 0; i < output.data.length; i += 4) {
if (output.data[i] < 128) this.output.data[i + 3] = 0;
else this.output.data[i + 3] = 255;
}
let outputArray = new Uint8Array(output);
let wasmModule = this.wasmModule;
let sourcePtr = wasmModule._malloc(outputArray.length);
let targetPtr = wasmModule._malloc(outputArray.length);
let numPixels = outputArray.length / 4;
wasmModule.HEAPU8.set(outputArray, sourcePtr);
wasmModule.HEAPU8.set(this.output.data, targetPtr);
wasmModule._copy_alpha_channel(sourcePtr, targetPtr, numPixels);
this.output.data.set(
wasmModule.HEAPU8.subarray(
targetPtr,
targetPtr + outputArray.length
)
);
}
const imgCanvas = this.$refs.imgCanvas;
@@ -639,13 +657,22 @@ export default {
imgCtx.clearRect(0, 0, imgCanvas.width, imgCanvas.height);
imgCanvas.width = this.output.width;
imgCanvas.height = this.output.height;
let outImg = imgCtx.createImageData(output.width, output.height);
let outImg = imgCtx.createImageData(
this.output.width,
this.output.height
);
outImg.data.set(this.output.data);
imgCtx.putImageData(outImg, 0, 0);
let type = "image/jpeg";
let quality = 0.92;
if (this.hasAlpha) type = "image/png";
this.processedImg.src = imgCanvas.toDataURL(type, quality);
this.processedImg.src = imgCanvas.toBlob(
(blob) => {
this.processedImg.src = URL.createObjectURL(blob);
},
type,
quality
);
this.processedImg.onload = () => {
this.linePosition = this.$refs.canvas.width * 0.5;
this.$refs.dragLine.style.left =
@@ -658,16 +685,19 @@ export default {
worker.terminate();
}
});
worker.postMessage({
input: this.input.data,
fixed: this.modelzoo[this.model].fixed,
factor: this.modelzoo[this.model].factor,
width: this.input.width,
height: this.input.height,
model: this.model,
backend: this.backend,
hasAlpha: false,
});
worker.postMessage(
{
input: this.input.data.buffer,
fixed: this.modelzoo[this.model].fixed,
factor: this.modelzoo[this.model].factor,
width: this.input.width,
height: this.input.height,
model: this.model,
backend: this.backend,
hasAlpha: false,
},
[this.input.data.buffer]
);
},
saveImage() {
const a = document.createElement("a");

View File

@@ -2,10 +2,14 @@ export default class Image {
width: number;
height: number;
data: Uint8Array;
constructor(width: number, height: number) {
constructor(
width: number,
height: number,
data = new Uint8Array(width * height * 4)
) {
this.width = width;
this.height = height;
this.data = new Uint8Array(width * height * 4);
this.data = data;
}
getImageCrop(
x: number,

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -3,11 +3,15 @@ import Image from "./image";
export default async function upscale(
image: Image,
model: any
model: tf.GraphModel,
alpha = false
): Promise<Image> {
const result = tf.tidy(() => {
const tensor = img2tensor(image);
const result = model.predict(tensor) as tf.Tensor;
let result = model.predict(tensor) as tf.Tensor;
if (alpha) {
result = tf.greater(result, 0.5);
}
return result;
});
const resultImage = await tensor2img(result);
@@ -16,32 +20,26 @@ export default async function upscale(
}
function img2tensor(image: Image): tf.Tensor {
let arr = new Float32Array(image.width * image.height * 3);
for (let i = 0; i < image.width * image.height; i++) {
arr[i * 3] = image.data[i * 4] / 255;
arr[i * 3 + 1] = image.data[i * 4 + 1] / 255;
arr[i * 3 + 2] = image.data[i * 4 + 2] / 255;
}
let tensor = tf.tensor4d(arr, [1, image.height, image.width, 3]);
let imgdata = new ImageData(image.width, image.height);
imgdata.data.set(image.data);
let tensor = tf.browser.fromPixels(imgdata).div(255).toFloat().expandDims();
return tensor;
}
async function tensor2img(tensor: tf.Tensor): Promise<Image> {
let [_, height, width, __] = tensor.shape;
let arr = await tensor.data();
tensor.dispose();
let clipped = new Uint8Array(
arr.map((x) => {
x = Math.min(1, Math.max(0, x));
return Math.floor(x * 255);
})
let clipped = tf.tidy(() =>
tensor
.reshape([height, width, 3])
.mul(255)
.cast("int32")
.clipByValue(0, 255)
);
let image = new Image(width, height);
for (let i = 0; i < width * height; i++) {
image.data[i * 4] = clipped[i * 3];
image.data[i * 4 + 1] = clipped[i * 3 + 1];
image.data[i * 4 + 2] = clipped[i * 3 + 2];
image.data[i * 4 + 3] = 255;
}
tensor.dispose();
let data = await tf.browser.toPixels(clipped as tf.Tensor3D);
clipped.dispose();
let image = new Image(width, height, data as unknown as Uint8Array);
return image;
}

View File

@@ -41,8 +41,7 @@ self.addEventListener("message", async (e) => {
if (!model) {
return;
}
const input = new Img(data.width, data.height);
input.data = data.input;
const input = new Img(data.width, data.height, new Uint8Array(data.input));
let hasAlpha = data.hasAlpha;
function sendprogress(progress) {
if (hasAlpha) {
@@ -57,126 +56,6 @@ self.addEventListener("message", async (e) => {
info: `Processing ${progress.toFixed(2)}%`,
});
}
async function enlargeImage(
model,
inputImg,
factor = 4,
tilesize = 32,
padsize = 8
) {
if (hasAlpha) {
tilesize = 16;
padsize = 4;
}
const width = inputImg.width;
const height = inputImg.height;
const output = new Img(width * factor, height * factor);
const total = Math.ceil(width / tilesize) * Math.ceil(height / tilesize);
let current = 0;
let useModel = new Array(total).fill(false);
if (hasAlpha) {
for (let i = 0; i < width; i += tilesize) {
for (let j = 0; j < height; j += tilesize) {
const x1 = Math.max(i, 0);
const y1 = Math.max(j, 0);
const x2 = Math.min(i + tilesize, width);
const y2 = Math.min(j + tilesize, height);
const tile = new Img(x2 - x1, y2 - y1);
tile.getImageCrop(0, 0, input, x1, y1, x2, y2);
for (let k = 4; k < tile.data.length; k += 4) {
if (tile.data[k + 3] !== tile.data[3]) {
useModel[current] = true;
break;
}
}
if (useModel[current]) {
current++;
continue;
}
let scaled = new Img(tile.width * factor, tile.height * factor);
for (let k = 0; k < scaled.data.length; k += 4) {
scaled.data[k] = tile.data[3];
scaled.data[k + 1] = tile.data[3];
scaled.data[k + 2] = tile.data[3];
}
output.getImageCrop(
i * factor,
j * factor,
scaled,
0,
0,
scaled.width,
scaled.height
);
current++;
}
}
current = 0;
for (let i = 0; i < width; i += tilesize) {
for (let j = 0; j < height; j += tilesize) {
if (!useModel[current]) {
current++;
let progress = (current / total) * 100;
sendprogress(progress);
continue;
}
const x1 = Math.max(i - padsize, 0);
const y1 = Math.max(j - padsize, 0);
const x2 = Math.min(i + tilesize + padsize, width);
const y2 = Math.min(j + tilesize + padsize, height);
const pad_left = i - x1;
const pad_top = j - y1;
const pad_right = Math.max(0, x2 - (i + tilesize));
const pad_bottom = Math.max(0, y2 - (j + tilesize));
const tile = new Img(x2 - x1, y2 - y1);
tile.getImageCrop(0, 0, input, x1, y1, x2, y2);
let scaled = await upscale(tile, model);
output.getImageCrop(
i * factor,
j * factor,
scaled,
pad_left * factor,
pad_top * factor,
scaled.width - pad_right * factor,
scaled.height - pad_bottom * factor
);
// console.log(i, j, x2 - x1, y2 - y1);
current++;
let progress = (current / total) * 100;
sendprogress(progress);
}
}
} else {
for (let i = 0; i < width; i += tilesize) {
for (let j = 0; j < height; j += tilesize) {
const x1 = Math.max(i - padsize, 0);
const y1 = Math.max(j - padsize, 0);
const x2 = Math.min(i + tilesize + padsize, width);
const y2 = Math.min(j + tilesize + padsize, height);
const pad_left = i - x1;
const pad_top = j - y1;
const pad_right = Math.max(0, x2 - (i + tilesize));
const pad_bottom = Math.max(0, y2 - (j + tilesize));
const tile = new Img(x2 - x1, y2 - y1);
tile.getImageCrop(0, 0, input, x1, y1, x2, y2);
let scaled = await upscale(tile, model);
output.getImageCrop(
i * factor,
j * factor,
scaled,
pad_left * factor,
pad_top * factor,
scaled.width - pad_right * factor,
scaled.height - pad_bottom * factor
);
current++;
let progress = (current / total) * 100;
sendprogress(progress);
}
}
}
return output;
}
async function enlargeImageWithFixedInput(
model,
inputImg,
@@ -273,7 +152,6 @@ self.addEventListener("message", async (e) => {
scaled.width - pad_right[i] * factor,
scaled.height - pad_bottom[j] * factor
);
// console.log(i, j, x2 - x1, y2 - y1);
current++;
}
}
@@ -292,7 +170,7 @@ self.addEventListener("message", async (e) => {
const y2 = locs_y[j] + input_size;
const tile = new Img(input_size, input_size);
tile.getImageCrop(0, 0, inputImg, x1, y1, x2, y2);
let scaled = await upscale(tile, model);
let scaled = await upscale(tile, model, true);
output.getImageCrop(
(x1 + pad_left[i]) * factor,
(y1 + pad_top[j]) * factor,
@@ -302,7 +180,6 @@ self.addEventListener("message", async (e) => {
scaled.width - pad_right[i] * factor,
scaled.height - pad_bottom[j] * factor
);
// console.log(i, j, x2 - x1, y2 - y1);
current++;
let progress = (current / total) * 100;
sendprogress(progress);
@@ -327,7 +204,6 @@ self.addEventListener("message", async (e) => {
scaled.width - pad_right[i] * factor,
scaled.height - pad_bottom[j] * factor
);
// console.log(i, j, x2 - x1, y2 - y1);
current++;
let progress = (current / total) * 100;
sendprogress(progress);
@@ -340,21 +216,25 @@ self.addEventListener("message", async (e) => {
const start = Date.now();
let output;
try {
// if (data?.fixed) {
// output = await enlargeImageWithFixedInput(model, input, factor);
// } else {
// output = await enlargeImage(model, input, factor);
// }
output = await enlargeImageWithFixedInput(model, input, factor);
} catch (e) {
postMessage({ alertmsg: e.toString() });
}
const end = Date.now();
console.log("Time:", end - start);
await new Promise((resolve) => setTimeout(resolve, 10));
postMessage({
progress: 100,
done: true,
output: output,
info: `Processing image...`,
});
postMessage(
{
progress: 100,
done: true,
output: output.data.buffer,
info: `Processing image...`,
},
[output.data.buffer]
);
console.log("output");
});