diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0d02e1e3f141813a33c4caa96c9a7ed2362ef6f4..530356a7b25d0540c08d4dd60a0504deb0050f74 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -131,13 +131,42 @@ update-nightly: - .dist-debian-container - .debian:11-nightly + #.download-vad-test-files: + #stage: "test" + #script: + #- wget $VAD_LINK + #- mkdir -p $VAD_TEST_FILES + #- tar xf $VAD_SOURCE $VAD_MODEL + #- mv $VAD_MODEL $VAD_TEST_FILES/ + #- rm -rf $VAD_SOURCE $VAD_NAME-$VAD_VERSION + .cargo test: stage: "test" variables: # csound-sys only looks at /usr/lib and /usr/local top levels CSOUND_LIB_DIR: '/usr/lib/x86_64-linux-gnu/' RUST_BACKTRACE: 'full' + #VAD_VERSION: !reference [.download-vad-test-files, variables, VAD_VERSION] + #VAD_SOURCE: !reference [.download-vad-test-files, variables, VAD_SOURCE] + #VAD_NAME: !reference [.download-vad-test-files, variables, VAD_NAME] + #VAD_LINK: !reference [.download-vad-test-files, variables, VAD_LINK] + #VAD_TEST_FILES: !reference [.download-vad-test-files, variables, VAD_TEST_FILES] + #VAD_MODEL: !reference [.download-vad-test-files, variables, VAD_MODEL] + VAD_VERSION: '3.1' + VAD_SOURCE: 'v3.1.tar.gz' + VAD_NAME: 'silero-vad' + VAD_LINK: 'https://github.com/snakers4/silero-vad/archive/refs/tags/v3.1.tar.gz' + VAD_TEST_FILES: 'audio/vadonnx/tests/test_files' + VAD_MODEL: 'silero-vad-3.1/files/silero_vad.onnx' + script: + # - !reference [.download-vad-test-files, script] + - wget $VAD_LINK + - mkdir -p $VAD_TEST_FILES + - tar xf $VAD_SOURCE $VAD_MODEL + - mv $VAD_MODEL $VAD_TEST_FILES/ + - rm -rf $VAD_SOURCE $VAD_NAME-$VAD_VERSION + - rustc --version - cargo build --locked --color=always --workspace --all-targets diff --git a/Cargo.toml b/Cargo.toml index 9633c6c4e5b118737c41de6e9e9478d3ce9886b7..a9ecf836cbe153f26b89e5f6a47b2b7f043b9be8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "audio/csound", "audio/lewton", "audio/spotify", + "audio/vadonnx", "generic/fmp4", "generic/file", "generic/sodium", diff --git a/audio/vadonnx/Cargo.toml b/audio/vadonnx/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..7d750fb42a227b58554ffae57fb7cf1c738bb488 --- /dev/null +++ b/audio/vadonnx/Cargo.toml @@ -0,0 +1,53 @@ +[package] +name = "gst-plugin-vadonnx" +version = "0.9.0" +authors = ["Abdul Rehman <7r3nzy@gmail.com>"] +repository = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs" +license = "MPL-2.0" +description = "Rust VAD Plugin" +edition = "2021" +rust-version = "1.57" + +[dependencies] +gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] } +gst-base = { package = "gstreamer-base", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] } +gst-audio = { package = "gstreamer-audio", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_20"] } +anyhow = "1" +tract-onnx = { git = "https://github.com/sonos/tract" } +once_cell = "1.10" +byte-slice-cast = "1.2.1" + +[lib] +name = "gstrsvadonnx" +crate-type = ["cdylib", "rlib"] +path = "src/lib.rs" + +[dev-dependencies] +gst-check = { package = "gstreamer-check", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_18"] } +hound = "3" + +[build-dependencies] +gst-plugin-version-helper = { path="../../version-helper" } + +[features] +static = [] +capi = [] +vad-tests = [] + + +[[test]] +name = "vadonnx" +required-features = ["vad-tests"] + +[package.metadata.capi] +min_version = "0.8.0" + +[package.metadata.capi.header] +enabled = false + +[package.metadata.capi.library] +install_subdir = "gstreamer-1.0" +versioning = false + +[package.metadata.capi.pkg_config] +requires_private = "gstreamer-1.0, gstreamer-base-1.0, gstreamer-audio-1.0, gobject-2.0, glib-2.0, gmodule-2.0" diff --git a/audio/vadonnx/LICENSE-MPL-2.0 b/audio/vadonnx/LICENSE-MPL-2.0 new file mode 120000 index 0000000000000000000000000000000000000000..eb5d24fe91cf96433d7a8685343bd62765b53fbb --- /dev/null +++ b/audio/vadonnx/LICENSE-MPL-2.0 @@ -0,0 +1 @@ +../../LICENSE-MPL-2.0 \ No newline at end of file diff --git a/audio/vadonnx/README.md b/audio/vadonnx/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6fa9bf76296d5bae272c58d0f98208a20b4ede65 --- /dev/null +++ b/audio/vadonnx/README.md @@ -0,0 +1,15 @@ +# Trying out VAD plugin + +## Getting the model + +```bash +curl -L0 -O https://github.com/snakers4/silero-vad/archive/refs/tags/v3.1.tar.gz +tar xvf v3.1.tar.gz +cp silero-vad-3.1/files/silero_vad.onnx . +``` + +## Running the pipeline + +```bash +gst-launch-1.0 filesrc location="your_audio.wav" ! decodebin ! audioconvert ! audioresample ! rsvadonnx model="silero_vad.onnx" ! fakesink sync=false async=false +``` diff --git a/audio/vadonnx/build.rs b/audio/vadonnx/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..cda12e57e199933e6ce75a6bd48e2ef443392eb2 --- /dev/null +++ b/audio/vadonnx/build.rs @@ -0,0 +1,3 @@ +fn main() { + gst_plugin_version_helper::info() +} diff --git a/audio/vadonnx/src/lib.rs b/audio/vadonnx/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..5d45aa0fb305361c3865afc21c26324fb0c0e81d --- /dev/null +++ b/audio/vadonnx/src/lib.rs @@ -0,0 +1,30 @@ +// Copyright (C) 2017 Sebastian Dröge +// +// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. +// If a copy of the MPL was not distributed with this file, You can obtain one at +// . +// +// SPDX-License-Identifier: MPL-2.0 +#![allow(clippy::non_send_fields_in_send_ty)] + +use gst::glib; + +mod vadonnx; + +fn plugin_init(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { + vadonnx::register(plugin)?; + Ok(()) +} + +gst::plugin_define!( + rsvadonnx, + env!("CARGO_PKG_DESCRIPTION"), + plugin_init, + concat!(env!("CARGO_PKG_VERSION"), "-", env!("COMMIT_ID")), + // FIXME: MPL-2.0 is only allowed since 1.18.3 (as unknown) and 1.20 (as known) + "MPL", + env!("CARGO_PKG_NAME"), + env!("CARGO_PKG_NAME"), + env!("CARGO_PKG_REPOSITORY"), + env!("BUILD_REL_DATE") +); diff --git a/audio/vadonnx/src/vadonnx/imp.rs b/audio/vadonnx/src/vadonnx/imp.rs new file mode 100644 index 0000000000000000000000000000000000000000..1c919172541cc3772c369633f65629ec7ad87fd9 --- /dev/null +++ b/audio/vadonnx/src/vadonnx/imp.rs @@ -0,0 +1,481 @@ +#![deny(warnings)] + +use gst::glib; +use gst::prelude::*; +use gst::subclass::prelude::*; +use gst_base::subclass::prelude::*; +use std::borrow::BorrowMut; +use std::fmt::Debug; +use std::usize; +use tract_onnx::prelude::*; + +use byte_slice_cast::AsSliceOf; +use gst::{debug, element_error, error, error_msg, info, trace, warning}; +use std::sync::Mutex; + +use gst::glib::subclass::Signal; +use once_cell::sync::Lazy; + +static CAT: Lazy = Lazy::new(|| { + gst::DebugCategory::new("rsvadonnx", gst::DebugColorFlags::empty(), Some("Rust VAD")) +}); + +const DEFAULT_VOICE_DETECTION_LIKELIHOOD: f32 = 0.5; +const DEFAULT_VOICE_DETECTION_FRAME_SIZE: usize = 1536; +const DEFAULT_MIN_SILENCE_DURATION: gst::ClockTime = gst::ClockTime::from_mseconds(100); +const DEFAULT_WEBRTC_DSP_MODE: bool = true; + +#[derive(Debug)] +struct Settings { + model: Option, + voice_detection_likelihood: f32, + voice_detection_frame_size: usize, + min_silence_duration: gst::ClockTime, + webrtc_dsp_mode: bool, +} + +impl Default for Settings { + fn default() -> Self { + Settings { + model: None, + voice_detection_likelihood: DEFAULT_VOICE_DETECTION_LIKELIHOOD, + voice_detection_frame_size: DEFAULT_VOICE_DETECTION_FRAME_SIZE, + min_silence_duration: DEFAULT_MIN_SILENCE_DURATION, + webrtc_dsp_mode: DEFAULT_WEBRTC_DSP_MODE, + } + } +} + +type Model = RunnableModel, Graph>>; + +struct State { + runnable_model: Option, + hn: Option, + cn: Option, + remaining_buffers: Option>, + stream_has_voice: bool, + time_elapsed_without_voice: gst::ClockTime, +} + +impl Default for State { + fn default() -> State { + Self { + runnable_model: None, + hn: Some(Tensor::zero::(&[2, 1, 64]).unwrap()), + cn: Some(Tensor::zero::(&[2, 1, 64]).unwrap()), + remaining_buffers: None, + stream_has_voice: false, + time_elapsed_without_voice: gst::ClockTime::ZERO, + } + } +} + +#[derive(Default)] +pub struct VADOnnx { + settings: Mutex, + state: Mutex, +} + +impl VADOnnx { + fn load_model( + &self, + _element: &super::VADOnnx, + model_path: &str, + voice_detection_frame_size: usize, + ) -> TractResult { + onnx() + // load the model + .model_for_path(model_path)? + // specify input type and shape + .with_input_names(["input", "h0", "c0"])? + .with_input_fact( + 0, + InferenceFact::dt_shape(f32::datum_type(), tvec!(1, voice_detection_frame_size)), + )? + .with_input_fact( + 1, + InferenceFact::dt_shape(f32::datum_type(), tvec!(2, 1, 64)), + )? + .with_input_fact( + 2, + InferenceFact::dt_shape(f32::datum_type(), tvec!(2, 1, 64)), + )? + .with_output_names(["output", "hn", "cn"])? + .with_output_fact( + 0, + InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 2, 1)), + )? + .with_output_fact( + 1, + InferenceFact::dt_shape(f32::datum_type(), tvec!(2, 1, 64)), + )? + .with_output_fact( + 2, + InferenceFact::dt_shape(f32::datum_type(), tvec!(2, 1, 64)), + )? + // optimize the model + .into_optimized()? + // make the model runnable and fix its inputs and outputs + .into_runnable() + } + + fn get_speech_probability( + &self, + element: &super::VADOnnx, + buffer: &[i16], + ) -> anyhow::Result { + let mut state = self.state.lock().unwrap(); + + let h0 = state.hn.take().unwrap(); + let c0 = state.cn.take().unwrap(); + + let model = state.runnable_model.as_ref(); + let model = model.unwrap(); + + let input: Tensor = + tract_ndarray::Array2::::from_shape_fn((1, buffer.len()), |(_, i)| { + // SAFETY: Length is checked already + unsafe { *buffer.get_unchecked(i) as f32 } + }) + .into(); + + let mut result = model.run(tvec!(input, h0, c0))?; + let speech_probability = result[0].as_slice::()?[1]; + debug!(CAT, obj: element, "probability {:?}", speech_probability); + state.cn = Some(result.remove(2).into_tensor()); + state.hn = Some(result.remove(1).into_tensor()); + Ok(speech_probability) + } + fn webrtc_dsp_mode(&self, element: &super::VADOnnx, stream_has_voice: bool) { + //TODO: use stream-time + //TODO: check for a way to add voice-activity in buffer meta + debug!( + CAT, + obj: element, + "Posting voice activity message, stream {} voice", + if stream_has_voice { + "now has" + } else { + "no longer has" + } + ); + let s = gst::Structure::new("voice-activity", &[("stream-has-voice", &stream_has_voice)]); + let msg = gst::message::Element::new(s); + let _ = element.post_message(msg); + } +} + +#[glib::object_subclass] +impl ObjectSubclass for VADOnnx { + const NAME: &'static str = "RsVADOnnx"; + type Type = super::VADOnnx; + type ParentType = gst_base::BaseTransform; +} + +impl ObjectImpl for VADOnnx { + fn properties() -> &'static [glib::ParamSpec] { + static PROPERTIES: Lazy> = Lazy::new(|| { + vec![ + glib::ParamSpecString::new( + "model", + "Model", + "Silero VAD model(.onnx) must be set only before start", + None, + glib::ParamFlags::READWRITE, + ), + glib::ParamSpecFloat::new( + "voice-detection-likelihood", + "Voice Detection Likelihood", + "Voice Detection Likelihood", + 0.0, + 1.0, + DEFAULT_VOICE_DETECTION_LIKELIHOOD, + glib::ParamFlags::READWRITE, + ), + glib::ParamSpecUInt64::new( + "voice-detection-frame-size-bytes", + "Voice Detection Frame Size (bytes)", + "Voice Detection Frame Size (bytes) must be set only before start", + 0, + u64::MAX, + DEFAULT_VOICE_DETECTION_FRAME_SIZE as _, + glib::ParamFlags::READWRITE, + ), + glib::ParamSpecUInt64::new( + "min-silence-duration-ms", + "Minimum Silence Duration in Milliseconds", + "Silence duration helps overcome below VAD likelihood values between speech", + 100, + u64::MAX, + DEFAULT_MIN_SILENCE_DURATION.mseconds(), + glib::ParamFlags::READWRITE, + ), + ] + }); + PROPERTIES.as_ref() + } + fn signals() -> &'static [Signal] { + static SIGNALS: Lazy> = Lazy::new(|| { + vec![Signal::builder( + "voice-activity", + &[bool::static_type().into()], + glib::Type::UNIT.into(), + ) + .build()] + }); + SIGNALS.as_ref() + } + + fn set_property( + &self, + obj: &Self::Type, + _id: usize, + value: &glib::Value, + pspec: &glib::ParamSpec, + ) { + let mut settings = self.settings.lock().unwrap(); + match pspec.name() { + "model" => { + let model = value.get().expect("type checked upstream"); + settings.model = model; + } + "voice-detection-likelihood" => { + let voice_detection_likelihood: f32 = value.get().expect("type checked upstream"); + info!( + CAT, + obj: obj, + "Changing voice_detection_likelihood from {} to {}", + settings.voice_detection_likelihood, + voice_detection_likelihood + ); + settings.voice_detection_likelihood = voice_detection_likelihood; + } + "voice-detection-frame-size-bytes" => { + let voice_detection_frame_size: u64 = value.get().expect("type checked upstream"); + info!( + CAT, + obj: obj, + "Changing window-size-samples from {} to {}", + settings.voice_detection_frame_size, + voice_detection_frame_size + ); + settings.voice_detection_frame_size = voice_detection_frame_size as _; + } + "min-silence-duration-ms" => { + let min_silence_duration: u64 = value.get().expect("type checked upstream"); + info!( + CAT, + obj: obj, + "Changing min-silence-duration-ms from {} to {}", + settings.min_silence_duration, + min_silence_duration + ); + settings.min_silence_duration = gst::ClockTime::from_mseconds(min_silence_duration); + } + "webrtc-dsp-mode" => { + let webrtc_dsp_mode: bool = value.get().expect("type checked upstream"); + info!( + CAT, + obj: obj, + "Changing webrtc-dsp-mode from {} to {}", + settings.webrtc_dsp_mode, + webrtc_dsp_mode + ); + settings.webrtc_dsp_mode = webrtc_dsp_mode; + } + _ => unimplemented!(), + } + } + fn property(&self, _obj: &Self::Type, _id: usize, pspec: &glib::ParamSpec) -> glib::Value { + let settings = self.settings.lock().unwrap(); + match pspec.name() { + "model" => settings.model.to_value(), + "voice_detection_likelihood" => settings.voice_detection_likelihood.to_value(), + "window-size-samples" => (settings.voice_detection_frame_size as u64).to_value(), + "min-silence-duration-ms" => settings.min_silence_duration.mseconds().to_value(), + "webrtc-dsp-mode" => settings.webrtc_dsp_mode.to_value(), + _ => unimplemented!(), + } + } +} + +impl GstObjectImpl for VADOnnx {} + +impl ElementImpl for VADOnnx { + fn metadata() -> Option<&'static gst::subclass::ElementMetadata> { + static ELEMENT_METADATA: Lazy = Lazy::new(|| { + gst::subclass::ElementMetadata::new( + "Voice Activity Detection", + "Filter/Effect/Audio", + "VAD based on Silero ONNX model and Tract ONNX Runtime", + "Abdul Rehman <7r3nzy@gmail.com>", + ) + }); + + Some(&*ELEMENT_METADATA) + } + fn change_state( + &self, + element: &Self::Type, + transition: gst::StateChange, + ) -> Result { + trace!(CAT, obj: element, "Changing state {:?}", transition); + self.parent_change_state(element, transition) + } + fn pad_templates() -> &'static [gst::PadTemplate] { + static PAD_TEMPLATES: Lazy> = Lazy::new(|| { + let caps = gst::Caps::new_simple( + "audio/x-raw", + &[ + ("format", &"S16LE"), + ("rate", &gst::List::new(&[&16000])), + ("channels", &1), + ("layout", &"interleaved"), + ], + ); + let sink_pad_template = gst::PadTemplate::new( + "sink", + gst::PadDirection::Sink, + gst::PadPresence::Always, + &caps, + ) + .unwrap(); + + let src_pad_template = gst::PadTemplate::new( + "src", + gst::PadDirection::Src, + gst::PadPresence::Always, + &caps, + ) + .unwrap(); + + vec![src_pad_template, sink_pad_template] + }); + + PAD_TEMPLATES.as_ref() + } +} + +impl BaseTransformImpl for VADOnnx { + const MODE: gst_base::subclass::BaseTransformMode = + gst_base::subclass::BaseTransformMode::AlwaysInPlace; + const PASSTHROUGH_ON_SAME_CAPS: bool = false; + const TRANSFORM_IP_ON_PASSTHROUGH: bool = false; + + fn start(&self, element: &Self::Type) -> Result<(), gst::ErrorMessage> { + let settings = self.settings.lock().unwrap(); + + let model_path = settings + .model + .as_ref() + .ok_or_else(|| error_msg!(gst::ResourceError::Settings, ["VADOnnx not started"]))?; + + let voice_detection_frame_size = settings.voice_detection_frame_size; + let model = self + .load_model(element, model_path, voice_detection_frame_size) + .map_err(|e| { + error!(CAT, obj: element, "Error {:?}", e); + error_msg!(gst::ResourceError::NotFound, ["Failed to open model"]) + })?; + + *self.state.lock().unwrap() = State { + runnable_model: Some(model), + ..Default::default() + }; + + Ok(()) + } + + fn stop(&self, element: &Self::Type) -> Result<(), gst::ErrorMessage> { + trace!(CAT, obj: element, "Stopping"); + **self.state.lock().as_mut().unwrap() = State::default(); + trace!(CAT, obj: element, "Stopped"); + Ok(()) + } + + fn transform_ip( + &self, + element: &Self::Type, + buf: &mut gst::BufferRef, + ) -> Result { + trace!(CAT, obj: element, "Transforming {:?}", buf); + { + let map = buf.map_readable().map_err(|_| { + element_error!(element, gst::CoreError::Failed, ["Failed to map buffer"]); + gst::FlowError::Error + })?; + + let slice = map.as_slice_of::().map_err(|e| { + element_error!( + element, + gst::CoreError::Failed, + [format!("Failed to get slice: {}", e).as_str()] + ); + gst::FlowError::Error + })?; + + let mut remaining_buffers = self.state.lock().unwrap().remaining_buffers.take(); + let slice = if let Some(remaining_buffers) = &mut remaining_buffers { + remaining_buffers.extend_from_slice(slice); + remaining_buffers.as_slice() + } else { + slice + }; + trace!(CAT, obj: element, "{:?}", slice.len()); + let voice_detection_frame_size = + self.settings.lock().unwrap().voice_detection_frame_size; + let voice_detection_likelihood = + self.settings.lock().unwrap().voice_detection_likelihood; + let webrtc_dsp_mode = self.settings.lock().unwrap().webrtc_dsp_mode; + let mut chunks = slice.chunks_exact(voice_detection_frame_size); + for chunk in chunks.by_ref() { + let speech_probability = + self.get_speech_probability(element, chunk).map_err(|e| { + element_error!( + element, + gst::CoreError::Failed, + [format!("Failed to get speech probability, error: {}", e).as_str()] + ); + gst::FlowError::Error + })?; + let stream_has_voice = speech_probability >= voice_detection_likelihood; + + let self_stream_has_voice = self.state.lock().unwrap().stream_has_voice; + + if webrtc_dsp_mode { + if self_stream_has_voice != stream_has_voice { + self.webrtc_dsp_mode(element, stream_has_voice); + } + } else { + //TODO: Implement for min_silence_duration (in progress) + //this is incomplete + let min_silence_duration = self.settings.lock().unwrap().min_silence_duration; + let time_elapsed_without_voice = + self.state.lock().unwrap().time_elapsed_without_voice; + if time_elapsed_without_voice >= min_silence_duration && !stream_has_voice { + element.emit_by_name::<()>("voice-activity", &[&stream_has_voice]); + } + } + self.state.lock().unwrap().stream_has_voice = stream_has_voice; + } + + let chunk = chunks.remainder(); + if !chunk.is_empty() { + warning!(CAT, obj: element, "remaining buffer size {} is smaller than voice detection frame {}, will be combined with next buffer", chunk.len(), voice_detection_frame_size); + self.state.lock().unwrap().remaining_buffers = Some(chunk.to_vec()); + } + } + let meta = buf.meta::(); + if let Some(meta) = meta { + //TODO: update existing voice activity + meta.voice_activity(); + } else { + //FIXME: If any chunk of the buffer contains activity, + // whole buffer should be marked as buffer has voice activity + + let stream_has_voice = self.state.lock().unwrap().stream_has_voice; + gst_audio::AudioLevelMeta::add(buf.borrow_mut(), 0, stream_has_voice); + } + Ok(gst::FlowSuccess::Ok) + } +} diff --git a/audio/vadonnx/src/vadonnx/mod.rs b/audio/vadonnx/src/vadonnx/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..d7f391f2f10d87b3107b50e122ce962eabd61022 --- /dev/null +++ b/audio/vadonnx/src/vadonnx/mod.rs @@ -0,0 +1,17 @@ +use gst::glib; +use gst::prelude::*; + +mod imp; + +glib::wrapper! { + pub struct VADOnnx(ObjectSubclass) @extends gst_base::BaseTransform, gst::Element, gst::Object; +} + +pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { + gst::Element::register( + Some(plugin), + "rsvadonnx", + gst::Rank::None, + VADOnnx::static_type(), + ) +} diff --git a/audio/vadonnx/tests/vadonnx.rs b/audio/vadonnx/tests/vadonnx.rs new file mode 100644 index 0000000000000000000000000000000000000000..412aa167fd047f34a261184a0680c4cdd737e9f9 --- /dev/null +++ b/audio/vadonnx/tests/vadonnx.rs @@ -0,0 +1,133 @@ +// Copyright (C) 2022 Abdul Rehman <7r3nzy@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. +// If a copy of the MPL was not distributed with this file, You can obtain one at +// . +// +// SPDX-License-Identifier: MPL-2.0 + +use gst::prelude::*; + +macro_rules! verify_wav_spec { + ($spec:expr) => { + if $spec.channels != 1 { + panic!("VAD test wav file doesn't contain mono channel"); + } + + if $spec.sample_rate != 16000 { + panic!("VAD test wav file sample rate is not 16000"); + } + + if $spec.bits_per_sample != 16 { + panic!("VAD test wav file bits per sample is not 16"); + } + + if $spec.sample_format != hound::SampleFormat::Int { + panic!("VAD test wav file sample format is not Int"); + } + }; +} + +macro_rules! read_samples { + ($reader:expr) => { + $reader + .into_samples::() + .map(|s| s.expect("Failed to read samples from testwav file")) + .collect::>() + }; +} + +macro_rules! prepare_file { + ($file_name:expr) => {{ + let input_path = { + let mut r = std::path::PathBuf::new(); + r.push(env!("CARGO_MANIFEST_DIR")); + r.push("tests"); + r.push("test_files"); + r.push($file_name); + r + }; + + if !input_path.exists() { + panic!("VAD test file {} not found", input_path.display()); + } + + input_path + }}; +} + +macro_rules! prepare_wav_file { + ($file_name:expr) => {{ + let input_path = prepare_file!($file_name); + + let reader = hound::WavReader::open(&input_path).expect("Failed to read test wav file"); + + let spec = reader.spec(); + verify_wav_spec!(spec); + let samples = read_samples!(reader); + + samples + }}; +} + +fn init() { + use std::sync::Once; + static INIT: Once = Once::new(); + + INIT.call_once(|| { + gst::init().unwrap(); + gstrsvadonnx::plugin_register_static().expect("Failed to register vadonnx plugin"); + }); +} + +fn build_harness() -> (gst_check::Harness, gst::Element) { + let vad = gst::ElementFactory::make("rsvadonnx", None).unwrap(); + vad.set_property( + "model", + prepare_file!("silero_vad.onnx") + .to_str() + .expect("Failed to convert path to str"), + ); + + let mut h = gst_check::Harness::with_element(&vad, Some("sink"), Some("src")); + let caps = gst::Caps::builder("audio/x-raw") + .field("format", "S16LE") + .field("rate", 16000i32) + .field("channels", 1i32) + .field("layout", "interleaved") + .build(); + + h.set_caps(caps.clone(), caps); + + (h, vad) +} + +#[test] +fn test_vadonnx_speech_trigger_with_signal() { + init(); + + let (mut h, _) = build_harness(); + h.play(); + let _ = prepare_wav_file!("test.wav"); +} + +#[test] +fn test_vadonnx_speech_trigger_with_event() { + init(); + let (mut h, _) = build_harness(); + h.play(); +} + +#[test] +fn test_vadonnx_speech_trigger_with_message() { + init(); + let (mut h, _) = build_harness(); + h.play(); +} + +#[test] +fn test_vadonnx_speech_trigger_with_all() { + init(); + let (mut h, _) = build_harness(); + h.play(); +} diff --git a/ci/utils.py b/ci/utils.py index 7775355164b30310a39be9ed1c9487cfa7185e34..1a41d9cfbcafa2a535eac0b86284978b2b9bd0f4 100644 --- a/ci/utils.py +++ b/ci/utils.py @@ -3,7 +3,7 @@ import os DIRS = ['audio', 'generic', 'net', 'text', 'utils', 'video'] # Plugins whose name is prefixed by 'rs' RS_PREFIXED = ['audiofx', 'closedcaption', - 'dav1d', 'file', 'json', 'onvif', 'regex', 'webp'] + 'dav1d', 'file', 'json', 'onvif', 'regex', 'webp', 'vadonnx'] OVERRIDE = {'wrap': 'rstextwrap', 'flavors': 'rsflv', 'ahead': 'textahead'} diff --git a/meson.build b/meson.build index 1cc88311a7d38e3710b3a21c2e8c4f095d36b8e8..dfa917c556167a09f83bd5f803d53f66f486ba53 100644 --- a/meson.build +++ b/meson.build @@ -62,6 +62,7 @@ plugins = { 'gst-plugin-spotify': 'libgstspotify', 'gst-plugin-textahead': 'libgsttextahead', 'gst-plugin-onvif': 'libgstrsonvif', + 'gst-plugin-vadonnx': 'libgstrsvadonnx', } extra_env = {}