From 90a0197f7cea6ec7a0407caa8c6e24ece0310acf Mon Sep 17 00:00:00 2001 From: Rafael Caricio Date: Sun, 17 Apr 2022 11:44:41 +0200 Subject: [PATCH] Add image comparison element New image comparison element, find images in the stream and post metadata of found image in the pipeline. --- video/videofx/Cargo.toml | 4 +- video/videofx/src/imgcmp/imp.rs | 479 ++++++++++++++++++++++++++++++++ video/videofx/src/imgcmp/mod.rs | 89 ++++++ video/videofx/src/lib.rs | 4 +- 4 files changed, 574 insertions(+), 2 deletions(-) create mode 100644 video/videofx/src/imgcmp/imp.rs create mode 100644 video/videofx/src/imgcmp/mod.rs diff --git a/video/videofx/Cargo.toml b/video/videofx/Cargo.toml index b56c9e0f..8145f282 100644 --- a/video/videofx/Cargo.toml +++ b/video/videofx/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "gst-plugin-videofx" version = "0.9.0" -authors = ["Sanchayan Maity "] +authors = ["Sanchayan Maity ", "Rafael Caricio "] repository = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs" license = "MPL-2.0" description = "Video Effects Plugin" @@ -13,6 +13,8 @@ atomic_refcell = "0.1" once_cell = "1.0" color-thief = "0.2.2" color-name = "1.0.0" +image = "0.24.2" +image_hasher = "1.0.0" [dependencies.gst] git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" diff --git a/video/videofx/src/imgcmp/imp.rs b/video/videofx/src/imgcmp/imp.rs new file mode 100644 index 00000000..99c2eb88 --- /dev/null +++ b/video/videofx/src/imgcmp/imp.rs @@ -0,0 +1,479 @@ +// Copyright (C) 2022 Rafael Caricio +// +// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. +// If a copy of the MPL was not distributed with this file, You can obtain one at +// . +// +// SPDX-License-Identifier: MPL-2.0 + +use crate::imgcmp::{HashAlgorithm, ImageDetected}; +use gst::subclass::prelude::*; +use gst::{glib, glib::prelude::*, prelude::*, BufferRef}; +use gst_base::prelude::*; +use gst_video::subclass::prelude::*; +use gst_video::VideoFormat; +use image::ColorType; +use image_hasher::{Hasher, HasherConfig, ImageHash}; +use once_cell::sync::Lazy; +use std::sync::{Arc, Mutex}; + +static CAT: Lazy = Lazy::new(|| { + gst::DebugCategory::new( + "imgcmp", + gst::DebugColorFlags::empty(), + Some("Image comparison"), + ) +}); + +const DEFAULT_HASH_ALGO: HashAlgorithm = HashAlgorithm::Blockhash; +const DEFAULT_MAX_DISTANCE_THRESHOLD: u32 = 0; + +struct Settings { + location: Option, + hash_algo: HashAlgorithm, + max_distance_threshold: u32, +} + +impl Default for Settings { + fn default() -> Self { + Settings { + location: None, + hash_algo: DEFAULT_HASH_ALGO, + max_distance_threshold: DEFAULT_MAX_DISTANCE_THRESHOLD, + } + } +} + +struct State { + reference_image: ImageHash, + hasher: Hasher, + out_color_type: ColorType, + out_info: gst_video::VideoInfo, +} + +unsafe impl Send for State {} + +impl State { + fn new( + hash_algo: HashAlgorithm, + image_location: &str, + out_color_type: ColorType, + out_info: gst_video::VideoInfo, + ) -> Result { + let hasher = HasherConfig::new().hash_alg(hash_algo.into()).to_hasher(); + + // Load all images using the same hashing algorithm + let image = image::open(image_location).map_err(|err| { + gst::loggable_error!(CAT, "Failed to load image from location: {}", err) + })?; + + Ok(Self { + reference_image: hasher.hash_image(&image), + hasher, + out_color_type, + out_info, + }) + } +} + +#[derive(Default, Clone)] +pub struct ImageCompare { + settings: Arc>, + state: Arc>>, +} + +impl ImageCompare { + fn compare_frame( + &self, + element: &super::ImageCompare, + buf: &mut gst::BufferRef, + ) -> Result, gst::FlowError> { + let state_guard = self.state.lock().unwrap(); + let state = state_guard.as_ref().ok_or_else(|| { + gst::error!(CAT, obj: element, "Element state not created"); + gst::FlowError::Error + })?; + + let frame = self.hash_frame( + element, + buf, + &state.hasher, + &state.out_info, + state.out_color_type, + )?; + + gst::debug!( + CAT, + obj: element, + "Loaded current buffer, base64: {}", + frame.to_base64() + ); + + let max_distance_threshold = { + let settings = self.settings.lock().unwrap(); + settings.max_distance_threshold + }; + + let segment = element.segment().downcast::().ok(); + let running_time = segment + .as_ref() + .and_then(|s| s.to_running_time(buf.pts().unwrap())); + + let distance = state.reference_image.dist(&frame); + if distance <= max_distance_threshold { + Ok(Some(ImageDetected::new( + distance, + state.reference_image.to_base64(), + running_time, + buf.pts(), + ))) + } else { + gst::debug!( + CAT, + obj: element, + "Compared image {} distance is {} but below threshold {}", + state.reference_image.to_base64(), + distance, + max_distance_threshold, + ); + Ok(None) + } + } + + fn hash_frame( + &self, + element: &super::ImageCompare, + buf: &mut BufferRef, + hasher: &Hasher, + video_info: &gst_video::VideoInfo, + out_color_type: ColorType, + ) -> Result { + let frame = gst_video::VideoFrameRef::from_buffer_ref_readable(buf, video_info).unwrap(); + + let height = frame.height(); + let width = frame.width(); + let buf = tightly_packed_framebuffer(&frame); + + match out_color_type { + ColorType::Rgb8 => { + let frame_buf = image::RgbImage::from_raw(width, height, buf).ok_or_else(|| { + gst::error!(CAT, obj: element, "Could not load gst buffer as RGB image"); + gst::FlowError::Error + })?; + + Ok(hasher.hash_image(&frame_buf)) + } + ColorType::Rgba8 => { + let frame_buf = + image::RgbaImage::from_raw(width, height, buf).ok_or_else(|| { + gst::error!(CAT, obj: element, "Could not load gst buffer as RGBA image"); + gst::FlowError::Error + })?; + + Ok(hasher.hash_image(&frame_buf)) + } + non_supported_color => { + gst::error!( + CAT, + obj: element, + "Color type {:?} is not supported", + non_supported_color + ); + Err(gst::FlowError::NotSupported) + } + } + } + + fn publish_image_detection(&self, element: &super::ImageCompare, info: &ImageDetected) { + gst::info!(CAT, obj: element, "Image detected {:?}", info); + element + .post_message( + gst::message::Element::builder( + gst::structure::Structure::builder("image-detected") + .field("distance", info.distance) + .field("image-base64-hash", info.image_base64_hash.clone()) + .field("running-time", info.running_time) + .field("pts", info.pts) + .build(), + ) + .src(element) + .build(), + ) + .expect("Failed to publish message to bus") + } +} + +#[glib::object_subclass] +impl ObjectSubclass for ImageCompare { + const NAME: &'static str = "GstImageCompare"; + type Type = super::ImageCompare; + type ParentType = gst_base::BaseTransform; +} + +impl ObjectImpl for ImageCompare { + fn properties() -> &'static [glib::ParamSpec] { + static PROPERTIES: Lazy> = Lazy::new(|| { + vec![ + glib::ParamSpecString::new( + "location", + "Image Location", + "Filesystem location of image to be compared to video frames", + None, + glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY, + ), + glib::ParamSpecEnum::new( + "hash-algo", + "Hashing Algorithm", + "Which hashing algorithm to use for image comparisons", + HashAlgorithm::static_type(), + DEFAULT_HASH_ALGO as i32, + glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY, + ), + glib::ParamSpecUInt::new( + "max-dist-threshold", + "Maximum Distance Threshold", + "Maximum distance threshold to emit messages when an image is detected, by default emits only on exact match", + 0, + u32::MAX, + DEFAULT_MAX_DISTANCE_THRESHOLD, + glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_PLAYING, + ), + ] + }); + + PROPERTIES.as_ref() + } + + fn set_property( + &self, + obj: &Self::Type, + _id: usize, + value: &glib::Value, + pspec: &glib::ParamSpec, + ) { + match pspec.name() { + "location" => { + let mut settings = self.settings.lock().unwrap(); + let image_location = value.get().expect("type checked upstream"); + if settings.location != image_location { + gst::info!( + CAT, + obj: obj, + "Changing location from {:?} to {:?}", + settings.location, + image_location + ); + settings.location = image_location; + } + } + "hash-algo" => { + let mut settings = self.settings.lock().unwrap(); + let hash_algo = value.get().expect("type checked upstream"); + if settings.hash_algo != hash_algo { + gst::info!( + CAT, + obj: obj, + "Changing hash-algo from {:?} to {:?}", + settings.hash_algo, + hash_algo + ); + settings.hash_algo = hash_algo; + } + } + "max-dist-threshold" => { + let mut settings = self.settings.lock().unwrap(); + let max_distance_threshold = value.get().expect("type checked upstream"); + if settings.max_distance_threshold != max_distance_threshold { + gst::info!( + CAT, + obj: obj, + "Changing max-dist-threshold from {} to {}", + settings.max_distance_threshold, + max_distance_threshold + ); + settings.max_distance_threshold = max_distance_threshold; + } + } + _ => unimplemented!(), + } + } + + fn property(&self, _obj: &Self::Type, _id: usize, pspec: &glib::ParamSpec) -> glib::Value { + match pspec.name() { + "location" => { + let settings = self.settings.lock().unwrap(); + settings.location.to_value() + } + "hash-algo" => { + let settings = self.settings.lock().unwrap(); + settings.hash_algo.to_value() + } + "max-dist-threshold" => { + let settings = self.settings.lock().unwrap(); + settings.max_distance_threshold.to_value() + } + _ => unimplemented!(), + } + } +} + +impl GstObjectImpl for ImageCompare {} + +impl ElementImpl for ImageCompare { + fn metadata() -> Option<&'static gst::subclass::ElementMetadata> { + static ELEMENT_METADATA: Lazy = Lazy::new(|| { + gst::subclass::ElementMetadata::new( + "Image comparison", + "Filter/Video", + "Detects images in a video", + "Rafael Caricio ", + ) + }); + + Some(&*ELEMENT_METADATA) + } + + fn pad_templates() -> &'static [gst::PadTemplate] { + static PAD_TEMPLATES: Lazy> = Lazy::new(|| { + let formats = gst::List::new([VideoFormat::Rgb.to_str(), VideoFormat::Rgba.to_str()]); + + let caps = gst::Caps::builder("video/x-raw") + .field("format", &formats) + .field("width", gst::IntRange::new(1, i32::MAX)) + .field("height", gst::IntRange::new(1, i32::MAX)) + .field( + "framerate", + gst::FractionRange::new( + gst::Fraction::new(0, 1), + gst::Fraction::new(i32::MAX, 1), + ), + ) + .build(); + + let sink_pad_template = gst::PadTemplate::new( + "sink", + gst::PadDirection::Sink, + gst::PadPresence::Always, + &caps, + ) + .unwrap(); + + let src_pad_template = gst::PadTemplate::new( + "src", + gst::PadDirection::Src, + gst::PadPresence::Always, + &caps, + ) + .unwrap(); + + vec![sink_pad_template, src_pad_template] + }); + + PAD_TEMPLATES.as_ref() + } +} + +impl BaseTransformImpl for ImageCompare { + const MODE: gst_base::subclass::BaseTransformMode = + gst_base::subclass::BaseTransformMode::AlwaysInPlace; + const PASSTHROUGH_ON_SAME_CAPS: bool = false; + const TRANSFORM_IP_ON_PASSTHROUGH: bool = false; + + fn set_caps( + &self, + element: &Self::Type, + incaps: &gst::Caps, + outcaps: &gst::Caps, + ) -> Result<(), gst::LoggableError> { + let in_info = match gst_video::VideoInfo::from_caps(incaps) { + Err(_) => return Err(gst::loggable_error!(CAT, "Failed to parse input caps")), + Ok(info) => info, + }; + + let out_info = match gst_video::VideoInfo::from_caps(outcaps) { + Err(_) => return Err(gst::loggable_error!(CAT, "Failed to parse output caps")), + Ok(info) => info, + }; + + gst::debug!( + CAT, + obj: element, + "Configured for caps {} to {}", + incaps, + outcaps + ); + + let color_format = match in_info.format() { + VideoFormat::Rgb => ColorType::Rgb8, + VideoFormat::Rgba => ColorType::Rgba8, + video_fmt => { + return Err(gst::loggable_error!( + CAT, + "Not supported video format: {}", + video_fmt + )) + } + }; + + let settings = self.settings.lock().unwrap(); + let image_location = settings.location.as_ref().ok_or_else(|| { + gst::loggable_error!(CAT, "Location of the reference image to compare not defined, set the property \"location\" to a valid file") + })?; + let mut state = self.state.lock().unwrap(); + *state = Some(State::new( + settings.hash_algo, + image_location, + color_format, + out_info, + )?); + + gst::debug!( + CAT, + obj: element, + "Loaded reference image hash: {:?}", + state.as_ref().unwrap().reference_image.to_base64() + ); + + Ok(()) + } + + fn transform_ip( + &self, + element: &Self::Type, + buf: &mut gst::BufferRef, + ) -> Result { + if let Some(detection) = self.compare_frame(element, buf)? { + self.publish_image_detection(element, &detection) + } + Ok(gst::FlowSuccess::Ok) + } + + fn propose_allocation( + &self, + element: &Self::Type, + decide_query: Option<&gst::query::Allocation>, + query: &mut gst::query::Allocation, + ) -> Result<(), gst::LoggableError> { + query.add_allocation_meta::(None); + self.parent_propose_allocation(element, decide_query, query) + } +} + +/// Helper method that takes a gstreamer video-frame and copies it into a +/// tightly packed rgb(a) buffer, ready for consumption. +fn tightly_packed_framebuffer(frame: &gst_video::VideoFrameRef<&gst::BufferRef>) -> Vec { + assert_eq!(frame.n_planes(), 1); // RGB and RGBA are tightly packed + let line_size = (frame.width() * frame.n_components()) as usize; + let line_stride = frame.plane_stride()[0] as usize; + let mut raw_frame = Vec::with_capacity(line_size * frame.info().height() as usize); + + // copy gstreamer frame to tightly packed rgb(a) frame. + frame + .plane_data(0) + .unwrap() + .chunks_exact(line_stride) + .map(|padded_line| &padded_line[..line_size]) + .for_each(|line| raw_frame.extend_from_slice(line)); + + raw_frame +} diff --git a/video/videofx/src/imgcmp/mod.rs b/video/videofx/src/imgcmp/mod.rs new file mode 100644 index 00000000..4dcd343e --- /dev/null +++ b/video/videofx/src/imgcmp/mod.rs @@ -0,0 +1,89 @@ +// Copyright (C) 2022 Rafael Caricio +// +// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. +// If a copy of the MPL was not distributed with this file, You can obtain one at +// . +// +// SPDX-License-Identifier: MPL-2.0 + +use gst::prelude::*; +use gst::{glib, ClockTime}; +use image_hasher::HashAlg; + +mod imp; + +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)] +#[repr(u32)] +#[enum_type(name = "GstImageCompareHashAlgorithm")] +#[non_exhaustive] +pub enum HashAlgorithm { + #[enum_value(name = "Mean: The Mean hashing algorithm.", nick = "mean")] + Mean = 0, + #[enum_value(name = "Gradient: The Gradient hashing algorithm.", nick = "gradient")] + Gradient = 1, + #[enum_value( + name = "VertGradient: The Vertical-Gradient hashing algorithm.", + nick = "vertgradient" + )] + VertGradient = 2, + #[enum_value( + name = "DoubleGradient: The Double-Gradient hashing algorithm.", + nick = "doublegradient" + )] + DoubleGradient = 3, + #[enum_value( + name = "Blockhash: The [Blockhash](https://github.com/commonsmachinery/blockhash-rfc) algorithm.", + nick = "blockhash" + )] + Blockhash = 4, +} + +impl From for HashAlg { + fn from(ha: HashAlgorithm) -> Self { + use HashAlgorithm::*; + match ha { + Mean => Self::Mean, + Gradient => Self::Gradient, + VertGradient => Self::VertGradient, + DoubleGradient => Self::DoubleGradient, + Blockhash => Self::Blockhash, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ImageDetected { + distance: u32, + image_base64_hash: String, + running_time: Option, + pts: Option, +} + +impl ImageDetected { + pub fn new( + distance: u32, + image_base64_hash: String, + running_time: Option, + pts: Option, + ) -> Self { + Self { + distance, + image_base64_hash, + running_time, + pts, + } + } +} + +glib::wrapper! { + pub struct ImageCompare(ObjectSubclass) @extends gst_base::BaseTransform, gst::Element, gst::Object; +} + +pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { + gst::Element::register( + Some(plugin), + "imgcmp", + gst::Rank::None, + ImageCompare::static_type(), + ) +} diff --git a/video/videofx/src/lib.rs b/video/videofx/src/lib.rs index 2b8d4c5b..4e013d50 100644 --- a/video/videofx/src/lib.rs +++ b/video/videofx/src/lib.rs @@ -10,10 +10,12 @@ mod border; mod colordetect; +mod imgcmp; fn plugin_init(plugin: &gst::Plugin) -> Result<(), gst::glib::BoolError> { border::register(plugin)?; - colordetect::register(plugin) + colordetect::register(plugin)?; + imgcmp::register(plugin) } gst::plugin_define!( -- GitLab