【纯前端推理】纯端侧 AI 对象检测:用浏览器就能跑的深度学习模型
🚀 纯端侧 AI 对象检测:用浏览器就能跑的深度学习模型
前言
随着 WebAssembly 和浏览器 GPU 加速技术的快速发展,在浏览器中直接运行深度学习模型已经成为现实。这一突破的关键在于 ONNX Web Runtime,它作为微软开源的跨平台机器学习推理引擎,能够将训练好的 ONNX 模型高效地运行在浏览器环境中。
ONNX Web Runtime 通过以下技术实现了浏览器端的高性能 AI 推理:
- WebAssembly (WASM):提供接近原生性能的计算能力
- WebGL:利用 GPU 进行并行计算加速
- WebGPU:下一代 GPU 计算标准,性能更强劲
- CPU 优化:针对不同架构的专门优化
基于这些底层技术,结合 Hugging Face Transformers.js 的高级封装,我们今天来实现一个完全运行在浏览器端的对象检测应用。整个过程无需服务器、无需上传数据,在保护用户隐私的同时,让每个人都能便捷地体验 AI 的强大能力。
🔬 技术原理:Transformers.js 在底层调用 ONNX Web Runtime,将 PyTorch/TensorFlow 训练的模型转换为 ONNX 格式,然后在浏览器中进行推理计算。
🎯 效果展示
上传一张图片,AI 立即识别出图中的对象,并用彩色边框标注出来
🔧 技术栈
- 前端框架: TypeScript + Vite
- AI 推理: Hugging Face Transformers.js
- 模型: DETR ResNet-50 (Facebook 开源)
📖 端侧推理的优势
graph TB
A[用户上传图片] --> B[浏览器加载模型]
B --> C[本地推理计算]
C --> D[返回检测结果]
D --> E[Canvas可视化]
F[传统方案] --> G[上传到服务器]
G --> H[服务器推理]
H --> I[返回结果]
style A fill:#e1f5fe
style E fill:#c8e6c9
style F fill:#ffebee
style I fill:#ffcdd2
端侧推理的优势:
- ✅ 隐私保护:数据不离开设备
- ✅ 实时响应:无网络延迟
- ✅ 成本控制:无服务器费用
- ✅ 离线可用:无需网络连接
端侧推理 VS 服务端推理:
特性 | 端侧推理 | 服务器推理 |
---|---|---|
隐私保护 | ✅ 完全本地 | ❌ 需要上传 |
成本 | ✅ 零服务器成本 | ❌ 需要 GPU 服务器 |
离线使用 | ✅ 支持离线 | ❌ 需要网络 |
模型更新 | ❌ 需要重新下载 | ✅ 服务器更新 |
计算资源 | ❌ 依赖设备性能 | ✅ 专业 GPU |
响应速度 | ❌ 依赖端侧计算能力 | ✅ 专业 GPU |
🔨 核心代码实现
使用 Vite 新建项目
pnpm create vite
安装 transformers
pnpm i @huggingface/transformers
调用模型,实现推理 模型地址: huggingface.co/Xenova/detr…
/** 加载模型 */
const detector = await pipeline("object-detection", "Xenova/detr-resnet-50", {
dtype: "auto",
});
/** 设置阈值 分数小于 0.9 的不展示 */
const output = await detector(imageUrl, { threshold: 0.9 });
console.log("Detection results:", output);
/** 推理的结果是一个由下面接口组成的数组 */
interface DetectionResult {
label: string;
score: number;
box: {
xmin: number;
ymin: number;
xmax: number;
ymax: number;
};
}
完整代码
import { pipeline } from "@huggingface/transformers";
interface CreateCanvasOptions {
width: number;
height: number;
}
interface DetectionResult {
label: string;
score: number;
box: {
xmin: number;
ymin: number;
xmax: number;
ymax: number;
};
}
type Detector = (
image: string,
options: { threshold: number }
) => Promise<DetectionResult[]>;
let detector: unknown | null = null;
async function initializeModel() {
if (!detector) {
console.time("Model loading");
detector = await pipeline("object-detection", "Xenova/detr-resnet-50", {
dtype: "auto",
});
console.timeEnd("Model loading");
}
return detector;
}
async function detectObjects(imageUrl: string) {
const model = await initializeModel();
const detector = model as Detector;
console.time("Detection");
const output = await detector(imageUrl, { threshold: 0.9 });
console.timeEnd("Detection");
console.log("Detection results:", output);
return output.map((detection: DetectionResult) => ({
label: detection.label,
score: detection.score,
box: {
xmin: detection.box.xmin,
ymin: detection.box.ymin,
xmax: detection.box.xmax,
ymax: detection.box.ymax,
},
}));
}
function loadImage(url: string) {
return new Promise<HTMLImageElement>((resolve, reject) => {
const image = new Image();
image.crossOrigin = "anonymous";
image.src = url;
image.addEventListener("load", () => {
resolve(image);
});
image.addEventListener("error", () => {
reject(new Error("Failed to load image"));
});
});
}
function createCanvas(options: CreateCanvasOptions) {
const canvas = document.createElement("canvas");
const context = canvas.getContext("2d");
if (!context) {
throw new Error("Failed to get canvas context");
}
const devicePixelRatio = window.devicePixelRatio || 1;
canvas.width = options.width * devicePixelRatio;
canvas.height = options.height * devicePixelRatio;
context.scale(devicePixelRatio, devicePixelRatio);
return {
canvas,
context,
devicePixelRatio,
};
}
function drawDetections(
context: CanvasRenderingContext2D,
detections: DetectionResult[]
) {
for (const [index, detection] of detections.entries()) {
const { box, label, score } = detection;
const colors = [
"#FF6B6B",
"#4ECDC4",
"#45B7D1",
"#96CEB4",
"#FFEAA7",
"#DDA0DD",
"#98D8C8",
];
const color = colors[index % colors.length];
/** 画框 */
const lineWidth = 2;
const halfLineWidth = lineWidth / 2;
const textPadding = 10;
context.strokeStyle = color;
context.lineWidth = lineWidth;
context.strokeRect(
box.xmin,
box.ymin,
box.xmax - box.xmin,
box.ymax - box.ymin
);
/** 画标签 */
context.fillStyle = color;
const labelText = `${label.toUpperCase()} (${(score * 100).toFixed(1)}%)`;
context.font = "16px Arial";
const metrics = context.measureText(labelText);
const textHeight =
metrics.actualBoundingBoxAscent + metrics.actualBoundingBoxDescent;
const textWidth = metrics.width;
context.fillRect(
box.xmin - halfLineWidth,
box.ymin - textHeight - halfLineWidth - textPadding,
textWidth + textPadding,
textHeight + textPadding
);
/** 画标签文本 */
context.fillStyle = "white";
context.textAlign = "left";
context.textBaseline = "top";
context.fillText(
labelText,
box.xmin + textPadding / 2,
box.ymin - textHeight - textPadding / 2
);
}
}
async function detectImage(url: string) {
const image = await loadImage(url);
const { canvas, context } = createCanvas({
width: image.naturalWidth,
height: image.naturalHeight,
});
context.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight);
canvas.style.display = "block";
canvas.style.margin = "50px auto";
if (image.naturalWidth > image.naturalHeight) {
canvas.style.width = `${window.innerWidth * 0.85}px`;
} else {
canvas.style.height = `${window.innerHeight * 0.85}px`;
}
document.body.appendChild(canvas);
const detections = await detectObjects(image.src);
drawDetections(context, detections);
}
async function selectImageAndCreateBlobUrl(options?: {
accept?: string;
multiple?: boolean;
}): Promise<string> {
const { accept = "image/*", multiple = false } = options || {};
return new Promise((resolve, reject) => {
const fileInput = document.createElement("input");
fileInput.type = "file";
fileInput.accept = accept;
fileInput.multiple = multiple;
fileInput.style.display = "none";
fileInput.onchange = (event) => {
const target = event.target as HTMLInputElement;
const file = target.files?.[0];
if (!file) {
reject(new Error("No file selected"));
return;
}
// 验证文件类型
if (!file.type.startsWith("image/")) {
reject(new Error("Selected file is not an image"));
return;
}
// 创建 blob URL
const blobUrl = URL.createObjectURL(file);
// 清理 DOM
document.body.removeChild(fileInput);
resolve(blobUrl);
};
fileInput.oncancel = () => {
document.body.removeChild(fileInput);
reject(new Error("File selection cancelled"));
};
// 添加到 DOM 并触发点击
document.body.appendChild(fileInput);
fileInput.click();
});
}
const selectImageButton = document.getElementById("select-image")!;
selectImageButton.addEventListener("click", async () => {
const blobUrl = await selectImageAndCreateBlobUrl();
detectImage(blobUrl);
});
🎯 总结
通过 Hugging Face Transformers.js 和现代浏览器的强大能力,我们成功实现了完全运行在浏览器端的对象检测应用。这种方案具有以下优势:
- 🔒 隐私安全:数据不离开设备
- 💰 成本效益:无服务器费用
- 📱 易于部署:静态网站即可
随着 WebGPU、WebAssembly 等技术的发展,端侧 AI 推理将会变得更加高效和普及。这不仅为开发者提供了新的可能性,也为用户带来了更好的隐私保护和使用体验。
🔗 相关资源
💡 提示:如果你觉得这篇文章对你有帮助,请点个赞支持一下!有任何问题欢迎在评论区讨论。
Tags: JavaScript
TypeScript
机器学习
人工智能
Web开发
前端
Transformers.js
对象检测