Skip to content

Commit

Permalink
fix: memoization not working properly.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielHauschildt committed Jul 12, 2023
1 parent 1baba2d commit a3c722f
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 31 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@imgly/background-removal",
"version": "1.0.4",
"version": "1.0.5",
"description": "Background Removal in the Browser",
"keywords": [
"background-removal",
Expand Down
33 changes: 26 additions & 7 deletions src/browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,41 @@ export type { ImageSource, Config };
// Imports
import { runInference } from './inference';
import { Config, validateConfig } from './schema';

import { createOnnxRuntime } from './ort-web-rt';
import * as utils from './utils';
import * as Bundle from './bundle';
import { Imports } from './tensor';

import { memoize } from 'lodash';

type ImageSource = ImageData | ArrayBuffer | Uint8Array | Blob | URL | string;

const memoizedCreateOnnxRuntime = memoize(createOnnxRuntime);
async function createSession(config: Config, imports: Imports) {
if (config.debug) console.debug('Loading model...');
console.log('createSession called');
console.log(config);
console.log(imports);
const model = config.model;
const blob = await Bundle.load(model, config);
const arrayBuffer = await blob.arrayBuffer();
const session = await imports.createSession(arrayBuffer);
return session;
}

async function _init(config?: Config) {
config = validateConfig(config);
const imports = createOnnxRuntime(config);
const session = await createSession(config, imports);
return { config, imports, session };
}

const init = memoize(_init, (config) => JSON.stringify(config));

async function removeBackground(
image: ImageSource,
config?: Config
configuration?: Config
): Promise<Blob> {
config = validateConfig(config);
const { config, imports, session } = await init(configuration);

if (config.debug) {
config.progress =
Expand All @@ -34,8 +55,6 @@ async function removeBackground(
}
}

const imports = memoizedCreateOnnxRuntime(config);

image = await utils.imageSourceToImageData(image);

if (!(image instanceof ImageData)) {
Expand All @@ -44,7 +63,7 @@ async function removeBackground(
);
}

const imageData = await runInference(image, config, imports);
const imageData = await runInference(image, config, imports, session);

return await utils.imageEncode(imageData);
}
7 changes: 3 additions & 4 deletions src/bundle.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
export { preload, fetchKey as fetch };
export { preload, load };

import { Config } from './schema';
import _ from 'lodash';

type Entry = {
url: string;
Expand Down Expand Up @@ -83,7 +82,7 @@ const bundle: Map<string, Entry> = new Map([
]
]);

async function fetchKey(key: string, config: Config) {
async function load(key: string, config: Config) {
const entry = bundle.get(key)!;
let url = entry.url;
if (config.publicPath) {
Expand Down Expand Up @@ -134,7 +133,7 @@ async function preload(config: Config) {
// This will warmup the caches
let result = new Map(bundle);
result.forEach(async (_, key) => {
await fetchKey(key, config);
await load(key, config);
});
return result;
}
15 changes: 2 additions & 13 deletions src/inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,13 @@ import { imageDataResize, imageDataToFloat32Array } from './utils';
import { Imports } from './tensor';
import { calculateProportionalSize } from './utils';
import { Config } from './schema';
import * as Bundle from './bundle';

import { memoize } from 'lodash';

export async function runInference(
imageData: ImageData,
config: Config,
imports: Imports
imports: Imports,
session: any
): Promise<ImageData> {
const session = await memoize(async (config: Config, imports: Imports) => {
if (config.debug) console.debug('Loading model...');
const model = config.model;
const blob = await Bundle.fetch(model, config);
const arrayBuffer = await blob.arrayBuffer();
const session = await imports.createSession(arrayBuffer);
return session;
})(config, imports);

if (config.progress) config.progress('compute:inference', 0, 1);
const resolution = 1024;
const src_width = imageData.width;
Expand Down
12 changes: 6 additions & 6 deletions src/ort-web-rt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,22 @@ function createOnnxRuntime(config: any): Imports {
ort.env.wasm.proxy = config.proxyToWorker;
ort.env.wasm.wasmPaths = {
// 'ort-wasm-simd-threaded.jsep.wasm': URL.createObjectURL(
// await Bundle.fetch('ort-wasm-simd-threaded.jsep.wasm', config)
// await Bundle.load('ort-wasm-simd-threaded.jsep.wasm', config)
// ),
// 'ort-wasm-simd.jsep.wasm': URL.createObjectURL(
// await Bundle.fetch('ort-wasm-simd.jsep.wasm', config)
// await Bundle.load('ort-wasm-simd.jsep.wasm', config)
// ),
'ort-wasm-simd-threaded.wasm': URL.createObjectURL(
await Bundle.fetch('ort-wasm-simd-threaded.wasm', config)
await Bundle.load('ort-wasm-simd-threaded.wasm', config)
),
'ort-wasm-simd.wasm': URL.createObjectURL(
await Bundle.fetch('ort-wasm-simd.wasm', config)
await Bundle.load('ort-wasm-simd.wasm', config)
),
'ort-wasm-threaded.wasm': URL.createObjectURL(
await Bundle.fetch('ort-wasm-threaded.wasm', config)
await Bundle.load('ort-wasm-threaded.wasm', config)
),
'ort-wasm.wasm': URL.createObjectURL(
await Bundle.fetch('ort-wasm.wasm', config)
await Bundle.load('ort-wasm.wasm', config)
)
};

Expand Down

0 comments on commit a3c722f

Please sign in to comment.