Onnx web runtime 구축하기 with Vite
소개
Vission AI를 사용자 로컬 환경에서 실행하기 위해 WebGPU, WebGL, WASM을 통해 Web환경에서 AI Model을 실행한 방법들과 Model을 캐싱하며 버전관리를 하는 방법에 대해 작성하였다.
배경
의료용 AI 모바일 애플리케이션으로 X-ray, C-arm 등의 이미지를 스마트폰으로 촬영해서 이미지를 원상태로 펼쳐주고 각각의 뼈 각도를 측정할 수 있는 OnDevice AI 서비스를 개발하여 테스트 중이다. 하지만 현업 의사분들이 사용 시 노트북에 있는 이미지 파일을 바로 확인하여 다음날 봐야할 환자들의 정보를 미리 확인할 수 있도록 하는 기능이 필요하였고, 이를 위해 Web runtime onnx를 사용하여 Web browser에 AI Model을 캐싱하고 사용자의 로컬환경에서 실행할 수 있도록 적용하여야했다.
아래 내용은 Web runtime onnx를 실행하기 위해 찾아본 정보들을 정리한 내용이다.
Get Started
ONNX Runtime Web을 적용하는 방법으로 공식사이트에서는 아래의 방법들을 소개한다.
JavaScript import statement
Install
- npm
- yarn
# install latest release version
npm install onnxruntime-web
# install nightly build dev version
npm install onnxruntime-web@dev
# install latest release version
yarn add onnxruntime-web
# install nightly build dev version
yarn add onnxruntime-web@dev
Import
import * as ort from "onnxruntime-web";
하지만 위 방법대로만 적용한다면
no available backend found. ERR: [webgpu] RuntimeError: Aborted(both async and sync fetching of the wasm failed). Build with -sASSERTIONS for more info., [webgl] backend not found., [wasm] Error: previous call to 'initWasm()' failed.
와 같이 no available backend found
에러가 발생한다. 현재 프로젝트 환경은 React Router V7 + Vite
이므로 이에 추가적인 설정이 필요하다
...
export default defineConfig({
...
...
assetsInclude: ["**/*.onnx"],
optimizeDeps: {
exclude: ["onnxruntime-web"],
},
});
-
assetsInclude: ["**/*.onnx"]
.onnx
파일을 정적 자산(asset)으로 처리- 해당 파일들은 빌드 시
dist/
폴더로 복사되고 URL을 통해 접근 가능 - 모델 파일을
import
하거나public
폴더에서fetch
할 수 있도록 해줌
-
optimizeDeps: exclude: ["onnxruntime-web"]
onnxruntime-web
패키지를 Vite 사전 번들링(pre-bundling)에서 제외- 이 패키지는 WASM 파일과 복잡한 로딩 로직을 포함하므로 런타임에 동적으로 로드되어야 함
- Vite가 이 패키지를 ESM으로 변환하려다 실패하는 것을 방지
간단한 프로젝트라면 CDN 에서 script 를 받아와 js
로 실행해주는 것도 가능하다.
HTML script tag
Import
<script src="https://example.com/path/ort.webgpu.min.js"></script>
사용법
이 경우 window.
의 ort
를 가져와서 사용하여야 한다.
예시
const session = await window.ort.InferenceSession.create(modelSource as any, {
executionProviders: ["webgpu", "webgl", "wasm"],
});
하지만 타입 안전성, 번들 최적화, 오프라인 지원, 의존성 관리, 보안 등의 문제로 이번 프로젝트에서는 Package import
방식을 적용하도록 한다.
모델 로드 및 추론 실행
기본 사용법
ONNX Runtime Web을 사용하여 모델을 로드하고 추론을 실행하는 기본적인 방법은 다음과 같다:
import * as ort from "onnxruntime-web";
// 모델 로드
async function loadModel(modelUrl: string) {
try {
const session = await ort.InferenceSession.create(modelUrl, {
executionProviders: ["webgpu", "webgl", "wasm"],
});
return session;
} catch (error) {
console.error("모델 로드 실패:", error);
throw error;
}
}
// 추론 실행
async function runInference(
session: ort.InferenceSession,
inputData: Float32Array,
inputShape: number[]
) {
try {
// 입력 텐서 생성
const inputTensor = new ort.Tensor("float32", inputData, inputShape);
// 추론 실행
const feeds = { input: inputTensor }; // 모델의 입력 이름에 맞게 수정 필요
const results = await session.run(feeds);
return results;
} catch (error) {
console.error("추론 실행 실패:", error);
throw error;
}
}
이미지 전처리 예제
Vision AI 모델의 경우 이미지 전처리가 필요하다:
// 이미지를 모델 입력 형식으로 전처리
function preprocessImage(
imageElement: HTMLImageElement,
targetSize: [number, number] = [224, 224]
): Float32Array {
const canvas = document.createElement("canvas");
const ctx = canvas.getContext("2d")!;
canvas.width = targetSize[0];
canvas.height = targetSize[1];
// 이미지를 캔버스에 그리기 (리사이즈)
ctx.drawImage(imageElement, 0, 0, targetSize[0], targetSize[1]);
// 픽셀 데이터 추출
const imageData = ctx.getImageData(0, 0, targetSize[0], targetSize[1]);
const pixels = imageData.data;
// RGB 정규화 및 채널 분리 (CHW 형식)
const float32Data = new Float32Array(3 * targetSize[0] * targetSize[1]);
for (let i = 0; i < pixels.length; i += 4) {
const pixelIndex = i / 4;
const r = pixels[i] / 255.0;
const g = pixels[i + 1] / 255.0;
const b = pixels[i + 2] / 255.0;
// CHW 형식으로 데이터 배치
float32Data[pixelIndex] = r; // R 채널
float32Data[pixelIndex + targetSize[0] * targetSize[1]] = g; // G 채널
float32Data[pixelIndex + 2 * targetSize[0] * targetSize[1]] = b; // B 채널
}
return float32Data;
}
완전한 사용 예제
async function classifyImage(imageElement: HTMLImageElement) {
try {
// 1. 모델 로드
const session = await loadModel("/models/resnet50.onnx");
// 2. 이미지 전처리
const inputData = preprocessImage(imageElement);
const inputShape = [1, 3, 224, 224]; // [batch, channels, height, width]
// 3. 추론 실행
const results = await runInference(session, inputData, inputShape);
// 4. 결과 처리
const outputTensor = results[Object.keys(results)[0]]; // 첫 번째 출력
const predictions = outputTensor.data as Float32Array;
// 최대값의 인덱스 찾기 (가장 높은 확률의 클래스)
const maxIndex = predictions.indexOf(Math.max(...predictions));
return {
classIndex: maxIndex,
confidence: predictions[maxIndex],
allPredictions: predictions,
};
} catch (error) {
console.error("이미지 분류 실패:", error);
throw error;
}
}