optimize performance

This commit is contained in:
xororz
2024-04-15 04:15:21 -04:00
parent c2c5cb5aaf
commit e8bc3dc016
6 changed files with 101 additions and 189 deletions

View File

@@ -208,6 +208,7 @@ export default {
worker: new Worker(new URL("./worker.js", import.meta.url), { worker: new Worker(new URL("./worker.js", import.meta.url), {
type: "module", type: "module",
}), }),
wasmModule: null,
}; };
}, },
watch: { watch: {
@@ -271,19 +272,19 @@ export default {
this.drawLine = true; this.drawLine = true;
let wasmModule = await Module(); let wasmModule = await Module();
this.wasmModule = wasmModule;
const imgCanvas = this.$refs.imgCanvas; const imgCanvas = this.$refs.imgCanvas;
imgCanvas.width = this.img.width; imgCanvas.width = this.img.width;
imgCanvas.height = this.img.height; imgCanvas.height = this.img.height;
const imgCtx = imgCanvas.getContext("2d"); const imgCtx = imgCanvas.getContext("2d");
imgCtx.drawImage(this.img, 0, 0); imgCtx.drawImage(this.img, 0, 0);
this.input = new Img(this.img.width, this.img.height);
let data = imgCtx.getImageData( let data = imgCtx.getImageData(
0, 0,
0, 0,
this.img.width, this.img.width,
this.img.height this.img.height
).data; ).data;
this.input.data = new Uint8Array(data); this.input = new Img(this.img.width, this.img.height, data);
const numPixels = this.input.width * this.input.height; const numPixels = this.input.width * this.input.height;
const bytesPerImage = numPixels * 4; const bytesPerImage = numPixels * 4;
let sourcePtr = wasmModule._malloc(bytesPerImage); let sourcePtr = wasmModule._malloc(bytesPerImage);
@@ -292,7 +293,6 @@ export default {
this.hasAlpha = wasmModule._check_alpha(sourcePtr, numPixels); this.hasAlpha = wasmModule._check_alpha(sourcePtr, numPixels);
if (this.hasAlpha) { if (this.hasAlpha) {
this.inputAlpha = new Img(this.img.width, this.img.height); this.inputAlpha = new Img(this.img.width, this.img.height);
this.inputAlpha.data = new Uint8Array(bytesPerImage);
wasmModule._copy_alpha_to_rgb(sourcePtr, targetPtr, numPixels); wasmModule._copy_alpha_to_rgb(sourcePtr, targetPtr, numPixels);
this.inputAlpha.data.set( this.inputAlpha.data.set(
wasmModule.HEAPU8.subarray(targetPtr, targetPtr + bytesPerImage) wasmModule.HEAPU8.subarray(targetPtr, targetPtr + bytesPerImage)
@@ -610,28 +610,46 @@ export default {
this.progress = progress; this.progress = progress;
if (done) { if (done) {
if (!this.hasAlpha || (this.hasAlpha && this.inputAlpha)) { if (!this.hasAlpha || (this.hasAlpha && this.inputAlpha)) {
this.output = output; let factor = this.modelzoo[this.model].factor;
this.output = new Img(
factor * this.input.width,
factor * this.input.height,
new Uint8ClampedArray(output)
);
} }
this.info = "Processing Image..."; this.info = "Processing Image...";
if (this.inputAlpha) { if (this.inputAlpha) {
worker.postMessage({ worker.postMessage(
input: this.inputAlpha.data, {
fixed: this.modelzoo[this.model].fixed, input: this.inputAlpha.data.buffer,
factor: this.modelzoo[this.model].factor, fixed: this.modelzoo[this.model].fixed,
width: this.inputAlpha.width, factor: this.modelzoo[this.model].factor,
height: this.inputAlpha.height, width: this.inputAlpha.width,
model: this.model, height: this.inputAlpha.height,
backend: this.backend, model: this.model,
hasAlpha: true, backend: this.backend,
}); hasAlpha: true,
},
[this.inputAlpha.data.buffer]
);
this.inputAlpha = null; this.inputAlpha = null;
return; return;
} }
if (this.hasAlpha) { if (this.hasAlpha) {
for (let i = 0; i < output.data.length; i += 4) { let outputArray = new Uint8Array(output);
if (output.data[i] < 128) this.output.data[i + 3] = 0; let wasmModule = this.wasmModule;
else this.output.data[i + 3] = 255; let sourcePtr = wasmModule._malloc(outputArray.length);
} let targetPtr = wasmModule._malloc(outputArray.length);
let numPixels = outputArray.length / 4;
wasmModule.HEAPU8.set(outputArray, sourcePtr);
wasmModule.HEAPU8.set(this.output.data, targetPtr);
wasmModule._copy_alpha_channel(sourcePtr, targetPtr, numPixels);
this.output.data.set(
wasmModule.HEAPU8.subarray(
targetPtr,
targetPtr + outputArray.length
)
);
} }
const imgCanvas = this.$refs.imgCanvas; const imgCanvas = this.$refs.imgCanvas;
@@ -639,13 +657,22 @@ export default {
imgCtx.clearRect(0, 0, imgCanvas.width, imgCanvas.height); imgCtx.clearRect(0, 0, imgCanvas.width, imgCanvas.height);
imgCanvas.width = this.output.width; imgCanvas.width = this.output.width;
imgCanvas.height = this.output.height; imgCanvas.height = this.output.height;
let outImg = imgCtx.createImageData(output.width, output.height); let outImg = imgCtx.createImageData(
this.output.width,
this.output.height
);
outImg.data.set(this.output.data); outImg.data.set(this.output.data);
imgCtx.putImageData(outImg, 0, 0); imgCtx.putImageData(outImg, 0, 0);
let type = "image/jpeg"; let type = "image/jpeg";
let quality = 0.92; let quality = 0.92;
if (this.hasAlpha) type = "image/png"; if (this.hasAlpha) type = "image/png";
this.processedImg.src = imgCanvas.toDataURL(type, quality); this.processedImg.src = imgCanvas.toBlob(
(blob) => {
this.processedImg.src = URL.createObjectURL(blob);
},
type,
quality
);
this.processedImg.onload = () => { this.processedImg.onload = () => {
this.linePosition = this.$refs.canvas.width * 0.5; this.linePosition = this.$refs.canvas.width * 0.5;
this.$refs.dragLine.style.left = this.$refs.dragLine.style.left =
@@ -658,16 +685,19 @@ export default {
worker.terminate(); worker.terminate();
} }
}); });
worker.postMessage({ worker.postMessage(
input: this.input.data, {
fixed: this.modelzoo[this.model].fixed, input: this.input.data.buffer,
factor: this.modelzoo[this.model].factor, fixed: this.modelzoo[this.model].fixed,
width: this.input.width, factor: this.modelzoo[this.model].factor,
height: this.input.height, width: this.input.width,
model: this.model, height: this.input.height,
backend: this.backend, model: this.model,
hasAlpha: false, backend: this.backend,
}); hasAlpha: false,
},
[this.input.data.buffer]
);
}, },
saveImage() { saveImage() {
const a = document.createElement("a"); const a = document.createElement("a");

View File

@@ -2,10 +2,14 @@ export default class Image {
width: number; width: number;
height: number; height: number;
data: Uint8Array; data: Uint8Array;
constructor(width: number, height: number) { constructor(
width: number,
height: number,
data = new Uint8Array(width * height * 4)
) {
this.width = width; this.width = width;
this.height = height; this.height = height;
this.data = new Uint8Array(width * height * 4); this.data = data;
} }
getImageCrop( getImageCrop(
x: number, x: number,

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -3,11 +3,15 @@ import Image from "./image";
export default async function upscale( export default async function upscale(
image: Image, image: Image,
model: any model: tf.GraphModel,
alpha = false
): Promise<Image> { ): Promise<Image> {
const result = tf.tidy(() => { const result = tf.tidy(() => {
const tensor = img2tensor(image); const tensor = img2tensor(image);
const result = model.predict(tensor) as tf.Tensor; let result = model.predict(tensor) as tf.Tensor;
if (alpha) {
result = tf.greater(result, 0.5);
}
return result; return result;
}); });
const resultImage = await tensor2img(result); const resultImage = await tensor2img(result);
@@ -16,32 +20,26 @@ export default async function upscale(
} }
function img2tensor(image: Image): tf.Tensor { function img2tensor(image: Image): tf.Tensor {
let arr = new Float32Array(image.width * image.height * 3); let imgdata = new ImageData(image.width, image.height);
for (let i = 0; i < image.width * image.height; i++) { imgdata.data.set(image.data);
arr[i * 3] = image.data[i * 4] / 255; let tensor = tf.browser.fromPixels(imgdata).div(255).toFloat().expandDims();
arr[i * 3 + 1] = image.data[i * 4 + 1] / 255;
arr[i * 3 + 2] = image.data[i * 4 + 2] / 255;
}
let tensor = tf.tensor4d(arr, [1, image.height, image.width, 3]);
return tensor; return tensor;
} }
async function tensor2img(tensor: tf.Tensor): Promise<Image> { async function tensor2img(tensor: tf.Tensor): Promise<Image> {
let [_, height, width, __] = tensor.shape; let [_, height, width, __] = tensor.shape;
let arr = await tensor.data();
tensor.dispose(); let clipped = tf.tidy(() =>
let clipped = new Uint8Array( tensor
arr.map((x) => { .reshape([height, width, 3])
x = Math.min(1, Math.max(0, x)); .mul(255)
return Math.floor(x * 255); .cast("int32")
}) .clipByValue(0, 255)
); );
let image = new Image(width, height); tensor.dispose();
for (let i = 0; i < width * height; i++) { let data = await tf.browser.toPixels(clipped as tf.Tensor3D);
image.data[i * 4] = clipped[i * 3]; clipped.dispose();
image.data[i * 4 + 1] = clipped[i * 3 + 1]; let image = new Image(width, height, data as unknown as Uint8Array);
image.data[i * 4 + 2] = clipped[i * 3 + 2];
image.data[i * 4 + 3] = 255;
}
return image; return image;
} }

View File

@@ -41,8 +41,7 @@ self.addEventListener("message", async (e) => {
if (!model) { if (!model) {
return; return;
} }
const input = new Img(data.width, data.height); const input = new Img(data.width, data.height, new Uint8Array(data.input));
input.data = data.input;
let hasAlpha = data.hasAlpha; let hasAlpha = data.hasAlpha;
function sendprogress(progress) { function sendprogress(progress) {
if (hasAlpha) { if (hasAlpha) {
@@ -57,126 +56,6 @@ self.addEventListener("message", async (e) => {
info: `Processing ${progress.toFixed(2)}%`, info: `Processing ${progress.toFixed(2)}%`,
}); });
} }
async function enlargeImage(
model,
inputImg,
factor = 4,
tilesize = 32,
padsize = 8
) {
if (hasAlpha) {
tilesize = 16;
padsize = 4;
}
const width = inputImg.width;
const height = inputImg.height;
const output = new Img(width * factor, height * factor);
const total = Math.ceil(width / tilesize) * Math.ceil(height / tilesize);
let current = 0;
let useModel = new Array(total).fill(false);
if (hasAlpha) {
for (let i = 0; i < width; i += tilesize) {
for (let j = 0; j < height; j += tilesize) {
const x1 = Math.max(i, 0);
const y1 = Math.max(j, 0);
const x2 = Math.min(i + tilesize, width);
const y2 = Math.min(j + tilesize, height);
const tile = new Img(x2 - x1, y2 - y1);
tile.getImageCrop(0, 0, input, x1, y1, x2, y2);
for (let k = 4; k < tile.data.length; k += 4) {
if (tile.data[k + 3] !== tile.data[3]) {
useModel[current] = true;
break;
}
}
if (useModel[current]) {
current++;
continue;
}
let scaled = new Img(tile.width * factor, tile.height * factor);
for (let k = 0; k < scaled.data.length; k += 4) {
scaled.data[k] = tile.data[3];
scaled.data[k + 1] = tile.data[3];
scaled.data[k + 2] = tile.data[3];
}
output.getImageCrop(
i * factor,
j * factor,
scaled,
0,
0,
scaled.width,
scaled.height
);
current++;
}
}
current = 0;
for (let i = 0; i < width; i += tilesize) {
for (let j = 0; j < height; j += tilesize) {
if (!useModel[current]) {
current++;
let progress = (current / total) * 100;
sendprogress(progress);
continue;
}
const x1 = Math.max(i - padsize, 0);
const y1 = Math.max(j - padsize, 0);
const x2 = Math.min(i + tilesize + padsize, width);
const y2 = Math.min(j + tilesize + padsize, height);
const pad_left = i - x1;
const pad_top = j - y1;
const pad_right = Math.max(0, x2 - (i + tilesize));
const pad_bottom = Math.max(0, y2 - (j + tilesize));
const tile = new Img(x2 - x1, y2 - y1);
tile.getImageCrop(0, 0, input, x1, y1, x2, y2);
let scaled = await upscale(tile, model);
output.getImageCrop(
i * factor,
j * factor,
scaled,
pad_left * factor,
pad_top * factor,
scaled.width - pad_right * factor,
scaled.height - pad_bottom * factor
);
// console.log(i, j, x2 - x1, y2 - y1);
current++;
let progress = (current / total) * 100;
sendprogress(progress);
}
}
} else {
for (let i = 0; i < width; i += tilesize) {
for (let j = 0; j < height; j += tilesize) {
const x1 = Math.max(i - padsize, 0);
const y1 = Math.max(j - padsize, 0);
const x2 = Math.min(i + tilesize + padsize, width);
const y2 = Math.min(j + tilesize + padsize, height);
const pad_left = i - x1;
const pad_top = j - y1;
const pad_right = Math.max(0, x2 - (i + tilesize));
const pad_bottom = Math.max(0, y2 - (j + tilesize));
const tile = new Img(x2 - x1, y2 - y1);
tile.getImageCrop(0, 0, input, x1, y1, x2, y2);
let scaled = await upscale(tile, model);
output.getImageCrop(
i * factor,
j * factor,
scaled,
pad_left * factor,
pad_top * factor,
scaled.width - pad_right * factor,
scaled.height - pad_bottom * factor
);
current++;
let progress = (current / total) * 100;
sendprogress(progress);
}
}
}
return output;
}
async function enlargeImageWithFixedInput( async function enlargeImageWithFixedInput(
model, model,
inputImg, inputImg,
@@ -273,7 +152,6 @@ self.addEventListener("message", async (e) => {
scaled.width - pad_right[i] * factor, scaled.width - pad_right[i] * factor,
scaled.height - pad_bottom[j] * factor scaled.height - pad_bottom[j] * factor
); );
// console.log(i, j, x2 - x1, y2 - y1);
current++; current++;
} }
} }
@@ -292,7 +170,7 @@ self.addEventListener("message", async (e) => {
const y2 = locs_y[j] + input_size; const y2 = locs_y[j] + input_size;
const tile = new Img(input_size, input_size); const tile = new Img(input_size, input_size);
tile.getImageCrop(0, 0, inputImg, x1, y1, x2, y2); tile.getImageCrop(0, 0, inputImg, x1, y1, x2, y2);
let scaled = await upscale(tile, model); let scaled = await upscale(tile, model, true);
output.getImageCrop( output.getImageCrop(
(x1 + pad_left[i]) * factor, (x1 + pad_left[i]) * factor,
(y1 + pad_top[j]) * factor, (y1 + pad_top[j]) * factor,
@@ -302,7 +180,6 @@ self.addEventListener("message", async (e) => {
scaled.width - pad_right[i] * factor, scaled.width - pad_right[i] * factor,
scaled.height - pad_bottom[j] * factor scaled.height - pad_bottom[j] * factor
); );
// console.log(i, j, x2 - x1, y2 - y1);
current++; current++;
let progress = (current / total) * 100; let progress = (current / total) * 100;
sendprogress(progress); sendprogress(progress);
@@ -327,7 +204,6 @@ self.addEventListener("message", async (e) => {
scaled.width - pad_right[i] * factor, scaled.width - pad_right[i] * factor,
scaled.height - pad_bottom[j] * factor scaled.height - pad_bottom[j] * factor
); );
// console.log(i, j, x2 - x1, y2 - y1);
current++; current++;
let progress = (current / total) * 100; let progress = (current / total) * 100;
sendprogress(progress); sendprogress(progress);
@@ -340,21 +216,25 @@ self.addEventListener("message", async (e) => {
const start = Date.now(); const start = Date.now();
let output; let output;
try { try {
// if (data?.fixed) {
// output = await enlargeImageWithFixedInput(model, input, factor);
// } else {
// output = await enlargeImage(model, input, factor);
// }
output = await enlargeImageWithFixedInput(model, input, factor); output = await enlargeImageWithFixedInput(model, input, factor);
} catch (e) { } catch (e) {
postMessage({ alertmsg: e.toString() }); postMessage({ alertmsg: e.toString() });
} }
const end = Date.now(); const end = Date.now();
console.log("Time:", end - start); console.log("Time:", end - start);
await new Promise((resolve) => setTimeout(resolve, 10));
postMessage({ postMessage({
progress: 100, progress: 100,
done: true,
output: output,
info: `Processing image...`, info: `Processing image...`,
}); });
postMessage(
{
progress: 100,
done: true,
output: output.data.buffer,
info: `Processing image...`,
},
[output.data.buffer]
);
console.log("output");
}); });