阅读视图

发现新文章,点击刷新页面。

【纯前端推理】纯端侧 AI 对象检测:用浏览器就能跑的深度学习模型

🚀 纯端侧 AI 对象检测:用浏览器就能跑的深度学习模型

前言

随着 WebAssembly 和浏览器 GPU 加速技术的快速发展,在浏览器中直接运行深度学习模型已经成为现实。这一突破的关键在于 ONNX Web Runtime,它作为微软开源的跨平台机器学习推理引擎,能够将训练好的 ONNX 模型高效地运行在浏览器环境中。

ONNX Web Runtime 通过以下技术实现了浏览器端的高性能 AI 推理:

  • WebAssembly (WASM):提供接近原生性能的计算能力
  • WebGL:利用 GPU 进行并行计算加速
  • WebGPU:下一代 GPU 计算标准,性能更强劲
  • CPU 优化:针对不同架构的专门优化

基于这些底层技术,结合 Hugging Face Transformers.js 的高级封装,我们今天来实现一个完全运行在浏览器端的对象检测应用。整个过程无需服务器、无需上传数据,在保护用户隐私的同时,让每个人都能便捷地体验 AI 的强大能力。

🔬 技术原理:Transformers.js 在底层调用 ONNX Web Runtime,将 PyTorch/TensorFlow 训练的模型转换为 ONNX 格式,然后在浏览器中进行推理计算。

🎯 效果展示

效果图.png

上传一张图片,AI 立即识别出图中的对象,并用彩色边框标注出来

🔧 技术栈

  • 前端框架: TypeScript + Vite
  • AI 推理: Hugging Face Transformers.js
  • 模型: DETR ResNet-50 (Facebook 开源)

📖 端侧推理的优势

graph TB
    A[用户上传图片] --> B[浏览器加载模型]
    B --> C[本地推理计算]
    C --> D[返回检测结果]
    D --> E[Canvas可视化]

    F[传统方案] --> G[上传到服务器]
    G --> H[服务器推理]
    H --> I[返回结果]

    style A fill:#e1f5fe
    style E fill:#c8e6c9
    style F fill:#ffebee
    style I fill:#ffcdd2

端侧推理的优势:

  • 隐私保护:数据不离开设备
  • 实时响应:无网络延迟
  • 成本控制:无服务器费用
  • 离线可用:无需网络连接

端侧推理 VS 服务端推理:

特性 端侧推理 服务器推理
隐私保护 ✅ 完全本地 ❌ 需要上传
成本 ✅ 零服务器成本 ❌ 需要 GPU 服务器
离线使用 ✅ 支持离线 ❌ 需要网络
模型更新 ❌ 需要重新下载 ✅ 服务器更新
计算资源 ❌ 依赖设备性能 ✅ 专业 GPU
响应速度 ❌ 依赖端侧计算能力 ✅ 专业 GPU

🔨 核心代码实现

使用 Vite 新建项目

pnpm create vite

安装 transformers

pnpm i @huggingface/transformers

调用模型,实现推理 模型地址: huggingface.co/Xenova/detr…

/** 加载模型 */
const detector = await pipeline("object-detection", "Xenova/detr-resnet-50", {
  dtype: "auto",
});

/** 设置阈值 分数小于 0.9 的不展示 */
const output = await detector(imageUrl, { threshold: 0.9 });

console.log("Detection results:", output);

/** 推理的结果是一个由下面接口组成的数组 */
interface DetectionResult {
  label: string;
  score: number;
  box: {
    xmin: number;
    ymin: number;
    xmax: number;
    ymax: number;
  };
}

完整代码

import { pipeline } from "@huggingface/transformers";

interface CreateCanvasOptions {
  width: number;
  height: number;
}

interface DetectionResult {
  label: string;
  score: number;
  box: {
    xmin: number;
    ymin: number;
    xmax: number;
    ymax: number;
  };
}

type Detector = (
  image: string,
  options: { threshold: number }
) => Promise<DetectionResult[]>;

let detector: unknown | null = null;

async function initializeModel() {
  if (!detector) {
    console.time("Model loading");
    detector = await pipeline("object-detection", "Xenova/detr-resnet-50", {
      dtype: "auto",
    });
    console.timeEnd("Model loading");
  }

  return detector;
}

async function detectObjects(imageUrl: string) {
  const model = await initializeModel();
  const detector = model as Detector;

  console.time("Detection");
  const output = await detector(imageUrl, { threshold: 0.9 });
  console.timeEnd("Detection");

  console.log("Detection results:", output);

  return output.map((detection: DetectionResult) => ({
    label: detection.label,
    score: detection.score,
    box: {
      xmin: detection.box.xmin,
      ymin: detection.box.ymin,
      xmax: detection.box.xmax,
      ymax: detection.box.ymax,
    },
  }));
}

function loadImage(url: string) {
  return new Promise<HTMLImageElement>((resolve, reject) => {
    const image = new Image();
    image.crossOrigin = "anonymous";
    image.src = url;

    image.addEventListener("load", () => {
      resolve(image);
    });

    image.addEventListener("error", () => {
      reject(new Error("Failed to load image"));
    });
  });
}

function createCanvas(options: CreateCanvasOptions) {
  const canvas = document.createElement("canvas");
  const context = canvas.getContext("2d");

  if (!context) {
    throw new Error("Failed to get canvas context");
  }

  const devicePixelRatio = window.devicePixelRatio || 1;
  canvas.width = options.width * devicePixelRatio;
  canvas.height = options.height * devicePixelRatio;

  context.scale(devicePixelRatio, devicePixelRatio);

  return {
    canvas,
    context,
    devicePixelRatio,
  };
}

function drawDetections(
  context: CanvasRenderingContext2D,
  detections: DetectionResult[]
) {
  for (const [index, detection] of detections.entries()) {
    const { box, label, score } = detection;
    const colors = [
      "#FF6B6B",
      "#4ECDC4",
      "#45B7D1",
      "#96CEB4",
      "#FFEAA7",
      "#DDA0DD",
      "#98D8C8",
    ];
    const color = colors[index % colors.length];

    /** 画框 */
    const lineWidth = 2;
    const halfLineWidth = lineWidth / 2;
    const textPadding = 10;
    context.strokeStyle = color;
    context.lineWidth = lineWidth;
    context.strokeRect(
      box.xmin,
      box.ymin,
      box.xmax - box.xmin,
      box.ymax - box.ymin
    );

    /** 画标签 */
    context.fillStyle = color;
    const labelText = `${label.toUpperCase()} (${(score * 100).toFixed(1)}%)`;
    context.font = "16px Arial";
    const metrics = context.measureText(labelText);
    const textHeight =
      metrics.actualBoundingBoxAscent + metrics.actualBoundingBoxDescent;
    const textWidth = metrics.width;
    context.fillRect(
      box.xmin - halfLineWidth,
      box.ymin - textHeight - halfLineWidth - textPadding,
      textWidth + textPadding,
      textHeight + textPadding
    );

    /** 画标签文本 */
    context.fillStyle = "white";
    context.textAlign = "left";
    context.textBaseline = "top";
    context.fillText(
      labelText,
      box.xmin + textPadding / 2,
      box.ymin - textHeight - textPadding / 2
    );
  }
}

async function detectImage(url: string) {
  const image = await loadImage(url);
  const { canvas, context } = createCanvas({
    width: image.naturalWidth,
    height: image.naturalHeight,
  });

  context.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight);
  canvas.style.display = "block";
  canvas.style.margin = "50px auto";
  if (image.naturalWidth > image.naturalHeight) {
    canvas.style.width = `${window.innerWidth * 0.85}px`;
  } else {
    canvas.style.height = `${window.innerHeight * 0.85}px`;
  }

  document.body.appendChild(canvas);

  const detections = await detectObjects(image.src);

  drawDetections(context, detections);
}

async function selectImageAndCreateBlobUrl(options?: {
  accept?: string;
  multiple?: boolean;
}): Promise<string> {
  const { accept = "image/*", multiple = false } = options || {};

  return new Promise((resolve, reject) => {
    const fileInput = document.createElement("input");
    fileInput.type = "file";
    fileInput.accept = accept;
    fileInput.multiple = multiple;
    fileInput.style.display = "none";

    fileInput.onchange = (event) => {
      const target = event.target as HTMLInputElement;
      const file = target.files?.[0];

      if (!file) {
        reject(new Error("No file selected"));
        return;
      }

      // 验证文件类型
      if (!file.type.startsWith("image/")) {
        reject(new Error("Selected file is not an image"));
        return;
      }

      // 创建 blob URL
      const blobUrl = URL.createObjectURL(file);

      // 清理 DOM
      document.body.removeChild(fileInput);

      resolve(blobUrl);
    };

    fileInput.oncancel = () => {
      document.body.removeChild(fileInput);
      reject(new Error("File selection cancelled"));
    };

    // 添加到 DOM 并触发点击
    document.body.appendChild(fileInput);
    fileInput.click();
  });
}

const selectImageButton = document.getElementById("select-image")!;

selectImageButton.addEventListener("click", async () => {
  const blobUrl = await selectImageAndCreateBlobUrl();
  detectImage(blobUrl);
});

🎯 总结

通过 Hugging Face Transformers.js 和现代浏览器的强大能力,我们成功实现了完全运行在浏览器端的对象检测应用。这种方案具有以下优势:

  • 🔒 隐私安全:数据不离开设备
  • 💰 成本效益:无服务器费用
  • 📱 易于部署:静态网站即可

随着 WebGPU、WebAssembly 等技术的发展,端侧 AI 推理将会变得更加高效和普及。这不仅为开发者提供了新的可能性,也为用户带来了更好的隐私保护和使用体验。


🔗 相关资源

💡 提示:如果你觉得这篇文章对你有帮助,请点个赞支持一下!有任何问题欢迎在评论区讨论。

Tags: JavaScript TypeScript 机器学习 人工智能 Web开发 前端 Transformers.js 对象检测

❌