普通视图

发现新文章,点击刷新页面。
昨天 — 2025年7月1日首页

在Swift中运行Silero VAD

作者 ZHANGYU
2025年7月1日 10:26

最近又开始学习Swift了,前段时间在AI的帮助下做了一个可以和大模型聊天的软件,当时VAD的功能很头痛,搜了下有一个付费的Cobra VAD,另外就只有靠音频能量判断了,这种方式不准。

最近做的东西又有VAD需求了,研究了很久后可以在Swift里跑Silero VAD了,直接把代码丢出来。

由于我不知道如何把ONNX模型转成Core ML模型,而官方的ONNX Runtime只提供了Pods的包,所以我用的是另一个Swift Package版本的ONNX Runtime;如果用Pods的包,需要把import OnnxRuntimeBindings换成对应的模块名。

//
//  SileroVAD.swift
//  Real-time Captions
//
//  Created by yu on 2025/6/30.
//

import AVFoundation
import Foundation
import OnnxRuntimeBindings

/// Callbacks reporting the start and end of detected speech.
protocol SileroVADDelegate: AnyObject {
    /// Called when the start of speech has been detected.
    /// - Parameter probability: VAD probability of the frame that triggered the event.
    func vadDidStartSpeech(probability: Float)

    /// Called when the end of speech has been detected.
    /// - Parameter probability: VAD probability of the frame that triggered the event.
    func vadDidEndSpeech(probability: Float)
}

/// Streaming voice-activity detector built on the Silero VAD ONNX model.
///
/// Audio is pushed in through `feed(_:)`; it is cut into fixed-size windows, each
/// window is scored by the model, and a small hysteresis state machine turns the
/// per-window probabilities into start/end-of-speech delegate callbacks.
///
/// The model file `silero_vad.onnx` is loaded from `Bundle.main`.
///
/// This class performs no synchronization of its own: call `feed(_:)` and `reset()`
/// from a single thread/queue. Delegate callbacks are invoked synchronously from
/// inside `feed(_:)`.
final class SileroVAD {
    // MARK: - Tunable parameters

    public struct Config {
        /// Probability at or above which a window counts as speech (enter threshold).
        public var threshold: Float = 0.5
        /// Probability at or below which a window counts as silence (exit threshold).
        /// Derived from `threshold`: 0.15 lower, but never below 0.01.
        public var negThreshold: Float { max(threshold - 0.15, 0.01) }
        /// How long (seconds) the probability must stay at/above `threshold`
        /// before "speech started" is reported.
        public var startSecs: Float = 0.20
        /// How long (seconds) the probability must stay at/below `negThreshold`
        /// before "speech ended" is reported.
        public var stopSecs: Float = 0.80
        /// Sample rate in Hz. Only 8 kHz and 16 kHz are supported; any value other
        /// than 8000 is processed with the 16 kHz window/context sizes.
        public var sampleRate: Int = 16000

        public init() {}
    }

    // MARK: - Internal state

    private enum VADState {
        case silence // no speech
        case speechCandidate // speech may be starting
        case speech // speech in progress
        case silenceCandidate // speech may be ending
    }

    private enum VADError: Error {
        case modelLoadFailed(String)
        case invalidAudioFormat(String)
        case inferenceError(String)
        case tensorCreationFailed(String)
    }

    // MARK: - Core properties

    private let session: ORTSession
    /// Names of every model output; queried once at init instead of per inference.
    private let outputNames: Set<String>
    /// Recurrent model state carried from one window to the next.
    private var state: ORTValue
    private let config: Config
    public weak var delegate: SileroVADDelegate?

    // State machine
    private var vadState: VADState = .silence
    private var speechFrameCount = 0
    private var silenceFrameCount = 0
    private var lastProbability: Float = 0.0

    // Thresholds expressed as a number of consecutive windows
    private let speechFrameThreshold: Int
    private let silenceFrameThreshold: Int

    // Audio buffering
    private var sampleBuffer: [Float] = []
    /// Samples per inference window: 512 at 16 kHz, 256 at 8 kHz.
    private let bufferSize: Int
    /// Samples of the previous window that are prepended to the next model input:
    /// 64 at 16 kHz, 32 at 8 kHz.
    private let contextSize: Int
    /// Tail of the previous model input (zeros before the first window).
    private var context: [Float]

    // MARK: - Public API

    /// Creates a detector and loads the ONNX model from the main bundle.
    ///
    /// Terminates the process with `fatalError` if the model file is missing, the
    /// ONNX session cannot be created, or the initial state tensor cannot be built.
    ///
    /// - Parameters:
    ///   - config: Thresholds, timings and sample rate.
    ///   - delegate: Receiver of start/end-of-speech events (held weakly).
    public init(config: Config = Config(), delegate: SileroVADDelegate? = nil) {
        self.config = config
        self.delegate = delegate

        // Window and context sizes follow the Silero reference wrapper:
        // 512 + 64 samples at 16 kHz, 256 + 32 samples at 8 kHz.
        let isEightKilohertz = config.sampleRate == 8000
        let windowSize = isEightKilohertz ? 256 : 512
        let contextLength = isEightKilohertz ? 32 : 64
        bufferSize = windowSize
        contextSize = contextLength
        context = [Float](repeating: 0, count: contextLength)

        // Convert the configured durations into a number of windows.
        let windowDurationSecs = Float(windowSize) / Float(config.sampleRate)
        speechFrameThreshold = Int(config.startSecs / windowDurationSecs)
        silenceFrameThreshold = Int(config.stopSecs / windowDurationSecs)

        guard let modelPath = Bundle.main.path(forResource: "silero_vad", ofType: "onnx") else {
            fatalError("SileroVAD: Model file not found in bundle")
        }

        do {
            let env = try ORTEnv(loggingLevel: .warning)
            let sessionOptions = try ORTSessionOptions()

            // Performance-related session options
            try sessionOptions.setGraphOptimizationLevel(.all)
            try sessionOptions.setIntraOpNumThreads(Int32(ProcessInfo.processInfo.processorCount))

            // Try to enable the Core ML execution provider; fall back to CPU if unavailable.
            do {
                let coreMLOptions = ORTCoreMLExecutionProviderOptions()
                try sessionOptions.appendCoreMLExecutionProvider(with: coreMLOptions)
                print("SileroVAD: Using Core ML Execution Provider (Neural Engine/NPU)")
            } catch {
                print("SileroVAD: Using optimized CPU execution with \(ProcessInfo.processInfo.processorCount) cores")
            }

            let ortSession = try ORTSession(env: env, modelPath: modelPath, sessionOptions: sessionOptions)
            let names = try ortSession.outputNames()
            session = ortSession
            outputNames = Set(names)
        } catch {
            fatalError("SileroVAD: Failed to create ONNX session: \(error)")
        }

        do {
            state = try SileroVAD.makeInitialState()
        } catch {
            fatalError("SileroVAD: Failed to create initial state tensor: \(error)")
        }
    }

    /// Appends audio samples and runs detection on every complete window.
    ///
    /// Samples that do not fill a whole window are kept until the next call.
    /// Delegate callbacks fire synchronously from within this method.
    ///
    /// - Parameter samples: Mono float samples at `Config.sampleRate`.
    public func feed(_ samples: [Float]) {
        sampleBuffer.append(contentsOf: samples)

        // The buffer is re-checked on every pass, so a delegate that calls
        // `reset()` from inside a callback simply ends the loop.
        while sampleBuffer.count >= bufferSize {
            if let probability = performDetection() {
                updateVADState(probability: probability)
            }
        }
    }

    /// Resets the state machine, the pending audio and the model's recurrent state.
    ///
    /// If the fresh state tensor cannot be created, the failure is logged and the
    /// previous recurrent state is left in place.
    public func reset() {
        // State machine
        vadState = .silence
        speechFrameCount = 0
        silenceFrameCount = 0
        lastProbability = 0.0

        // Pending audio and carried-over context
        sampleBuffer.removeAll()
        context = [Float](repeating: 0, count: contextSize)

        // Recurrent model state
        do {
            state = try SileroVAD.makeInitialState()
        } catch {
            print("SileroVAD: Failed to reset state tensor: \(error)")
        }
    }

    // MARK: - Private helpers

    /// Wraps `values` in a float tensor of the given shape.
    private static func makeFloatTensor(_ values: [Float], shape: [NSNumber]) throws -> ORTValue {
        let byteCount = values.count * MemoryLayout<Float>.stride
        let data = NSMutableData(data: Data(bytes: values, count: byteCount))
        return try ORTValue(tensorData: data, elementType: .float, shape: shape)
    }

    /// Builds the all-zero recurrent state tensor of shape (2, 1, 128).
    private static func makeInitialState() throws -> ORTValue {
        let zeros = [Float](repeating: 0, count: 2 * 1 * 128)
        return try makeFloatTensor(zeros, shape: [2, 1, 128])
    }

    /// Pops one window from the buffer and scores it.
    ///
    /// The window is removed before inference runs, so a failing inference drops
    /// that window instead of retrying it.
    ///
    /// - Returns: The speech probability, or `nil` if there is not enough audio
    ///   or inference failed (the error is logged).
    private func performDetection() -> Float? {
        guard sampleBuffer.count >= bufferSize else {
            return nil
        }

        let vadInput = Array(sampleBuffer.prefix(bufferSize))
        sampleBuffer.removeFirst(bufferSize)

        do {
            let probability = try runInference(audioData: vadInput)
            lastProbability = probability
            return probability
        } catch {
            print("SileroVAD: Detection error: \(error)")
            return nil
        }
    }

    /// Runs the model on one window and advances the recurrent state.
    ///
    /// - Parameter audioData: Exactly one window (`bufferSize` samples).
    /// - Returns: The speech probability reported by the model.
    /// - Throws: `VADError.invalidAudioFormat` for a wrong window length,
    ///   `VADError.inferenceError` for missing/empty outputs, or any error
    ///   raised by ONNX Runtime.
    private func runInference(audioData: [Float]) throws -> Float {
        guard audioData.count == bufferSize else {
            throw VADError.invalidAudioFormat("Audio data must be exactly \(bufferSize) samples")
        }

        // The model input is the tail of the previous input followed by the new
        // window, mirroring the Silero reference wrapper.
        let modelInput = context + audioData
        let inputTensor = try SileroVAD.makeFloatTensor(
            modelInput,
            shape: [NSNumber(value: 1), NSNumber(value: modelInput.count)]
        )

        // Sample-rate tensor
        var srData = Int64(config.sampleRate)
        let srTensor = try ORTValue(
            tensorData: NSMutableData(data: Data(bytes: &srData, count: MemoryLayout<Int64>.size)),
            elementType: .int64,
            shape: [1]
        )

        let inputs: [String: ORTValue] = [
            "input": inputTensor,
            "state": state,
            "sr": srTensor,
        ]

        let outputs = try session.run(withInputs: inputs, outputNames: outputNames, runOptions: nil)

        guard let outputTensor = outputs["output"] else {
            throw VADError.inferenceError("Missing 'output' tensor")
        }

        guard let newStateTensor = outputs["stateN"] else {
            throw VADError.inferenceError("Missing 'stateN' tensor")
        }

        // Read the probability, refusing to load from a buffer that is too small.
        let tensorData = try outputTensor.tensorData() as Data
        guard tensorData.count >= MemoryLayout<Float>.size else {
            throw VADError.inferenceError("'output' tensor is empty")
        }
        let probability = tensorData.withUnsafeBytes { bytes in
            bytes.load(as: Float.self)
        }

        // Commit the recurrent state and context only once the frame fully succeeded.
        state = newStateTensor
        context = Array(modelInput.suffix(contextSize))

        return probability
    }

    /// Advances the hysteresis state machine with one window's probability and
    /// notifies the delegate on confirmed transitions.
    ///
    /// Probabilities strictly between `negThreshold` and `threshold` count as
    /// neither speech nor silence: they leave `.speech` / `.silenceCandidate`
    /// untouched, and cancel a `.speechCandidate`.
    private func updateVADState(probability: Float) {
        let isHighProbability = probability >= config.threshold
        let isLowProbability = probability <= config.negThreshold

        switch vadState {
        case .silence:
            if isHighProbability {
                speechFrameCount = 1
                silenceFrameCount = 0
                // Check right away so a threshold of one window (or less) fires
                // on the first speech frame instead of needing a second one.
                confirmSpeechIfReady(probability: probability)
            }

        case .speechCandidate:
            if isHighProbability {
                speechFrameCount += 1
                confirmSpeechIfReady(probability: probability)
            } else {
                vadState = .silence
                speechFrameCount = 0
            }

        case .speech:
            if isLowProbability {
                silenceFrameCount = 1
                speechFrameCount = 0
                // Same immediate check as for the start of speech.
                confirmSilenceIfReady(probability: probability)
            } else if isHighProbability {
                // Still speaking: clear the silence counter.
                silenceFrameCount = 0
            }

        case .silenceCandidate:
            if isLowProbability {
                silenceFrameCount += 1
                confirmSilenceIfReady(probability: probability)
            } else if isHighProbability {
                vadState = .speech
                silenceFrameCount = 0
            }
        }
    }

    /// Enters `.speech` and fires the start event once enough speech frames were
    /// counted; otherwise stays in / moves to `.speechCandidate`.
    private func confirmSpeechIfReady(probability: Float) {
        if speechFrameCount >= speechFrameThreshold {
            // Update the state before the callback so a re-entrant call sees it.
            vadState = .speech
            delegate?.vadDidStartSpeech(probability: probability)
        } else {
            vadState = .speechCandidate
        }
    }

    /// Enters `.silence` and fires the end event once enough silence frames were
    /// counted; otherwise stays in / moves to `.silenceCandidate`.
    private func confirmSilenceIfReady(probability: Float) {
        if silenceFrameCount >= silenceFrameThreshold {
            // Update the state before the callback so a re-entrant call sees it.
            vadState = .silence
            delegate?.vadDidEndSpeech(probability: probability)
        } else {
            vadState = .silenceCandidate
        }
    }
}

使用前需要先下载模型文件silero_vad.onnx,并把它加入项目(代码是从Bundle.main里加载模型的,所以要确保它被打包进App)。

当然这个代码也是Claude帮我写的。

❌
❌