Verified Commit fe4a5fda authored by Rafael Carício's avatar Rafael Carício 🏠
Browse files

Add image comparison element

New image comparison element: finds images in the stream and posts
metadata of found images in the pipeline.
parent 5c00db62
Pipeline #623098 passed with stages
in 24 minutes and 27 seconds
[package]
name = "gst-plugin-videofx"
version = "0.9.0"
authors = ["Sanchayan Maity <sanchayan@asymptotic.io>", "Rafael Caricio <rafael@caricio.com>"]
repository = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs"
license = "MPL-2.0"
description = "Video Effects Plugin"
......@@ -13,6 +13,8 @@ atomic_refcell = "0.1"
once_cell = "1.0"
color-thief = "0.2.2"
color-name = "1.0.0"
image = "0.24.2"
image_hasher = "1.0.0"
[dependencies.gst]
git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs"
......
// Copyright (C) 2022 Rafael Caricio <rafael@caricio.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at
// <https://mozilla.org/MPL/2.0/>.
//
// SPDX-License-Identifier: MPL-2.0
use crate::imgcmp::{HashAlgorithm, ImageDetected};
use gst::subclass::prelude::*;
use gst::{glib, glib::prelude::*, prelude::*, BufferRef};
use gst_base::prelude::*;
use gst_video::subclass::prelude::*;
use gst_video::VideoFormat;
use image::ColorType;
use image_hasher::{Hasher, HasherConfig, ImageHash};
use once_cell::sync::Lazy;
use std::sync::{Arc, Mutex};
// Debug category used by all logging in this element.
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
    gst::DebugCategory::new(
        "imgcmp",
        gst::DebugColorFlags::empty(),
        Some("Image comparison"),
    )
});

// Default for the `hash-algo` property.
const DEFAULT_HASH_ALGO: HashAlgorithm = HashAlgorithm::Blockhash;
// Default for `max-dist-threshold`: 0 means a message is posted
// only on an exact hash match.
const DEFAULT_MAX_DISTANCE_THRESHOLD: u32 = 0;
/// Element properties, guarded by a `Mutex` inside `ImageCompare`.
struct Settings {
    // Filesystem path of the reference image (`location` property).
    location: Option<String>,
    // Hashing algorithm used for comparisons (`hash-algo` property).
    hash_algo: HashAlgorithm,
    // Maximum hamming distance that still counts as a detection
    // (`max-dist-threshold` property).
    max_distance_threshold: u32,
}
impl Default for Settings {
    /// Start with no reference image configured, the default hashing
    /// algorithm and an exact-match (distance 0) detection threshold.
    fn default() -> Self {
        Self {
            location: None,
            hash_algo: DEFAULT_HASH_ALGO,
            max_distance_threshold: DEFAULT_MAX_DISTANCE_THRESHOLD,
        }
    }
}
/// Per-negotiation streaming state, created in `set_caps()`.
struct State {
    // Hash of the reference image loaded from the `location` property.
    reference_image: ImageHash,
    // Hasher configured from the `hash-algo` property; the same hasher
    // configuration produced `reference_image`.
    hasher: Hasher,
    // `image` crate color type matching the negotiated video format.
    out_color_type: ColorType,
    // Negotiated video info, used to map incoming buffers as frames.
    out_info: gst_video::VideoInfo,
}
// SAFETY: NOTE(review): `Hasher`/`ImageHash` are not automatically `Send`;
// this assumes they hold no thread-affine resources and are only ever
// accessed behind the `Mutex` in `ImageCompare` — confirm against the
// `image_hasher` crate before relying on this.
unsafe impl Send for State {}
impl State {
    /// Build the streaming state: loads the image at `image_location`
    /// and hashes it with the algorithm selected by `hash_algo`.
    fn new(
        hash_algo: HashAlgorithm,
        image_location: &str,
        out_color_type: ColorType,
        out_info: gst_video::VideoInfo,
    ) -> Result<Self, gst::LoggableError> {
        // The reference image must be hashed with exactly the same hasher
        // configuration that will later hash the incoming video frames.
        let hasher = HasherConfig::new().hash_alg(hash_algo.into()).to_hasher();

        let reference = image::open(image_location).map_err(|err| {
            gst::loggable_error!(CAT, "Failed to load image from location: {}", err)
        })?;
        let reference_image = hasher.hash_image(&reference);

        Ok(Self {
            reference_image,
            hasher,
            out_color_type,
            out_info,
        })
    }
}
/// `imgcmp` element: compares incoming video frames against a reference
/// image using perceptual hashing and, on a match, posts an
/// `image-detected` element message on the bus.
#[derive(Default, Clone)]
pub struct ImageCompare {
    // Mutable properties, shared with the GObject property system.
    settings: Arc<Mutex<Settings>>,
    // Streaming state; `None` until caps have been negotiated.
    state: Arc<Mutex<Option<State>>>,
}
impl ImageCompare {
fn compare_frame(
&self,
element: &super::ImageCompare,
buf: &mut gst::BufferRef,
) -> Result<Option<ImageDetected>, gst::FlowError> {
let state_guard = self.state.lock().unwrap();
let state = state_guard.as_ref().ok_or_else(|| {
gst::error!(CAT, obj: element, "Element state not created");
gst::FlowError::Error
})?;
let frame = self.hash_frame(
element,
buf,
&state.hasher,
&state.out_info,
state.out_color_type,
)?;
gst::debug!(
CAT,
obj: element,
"Loaded current buffer, base64: {}",
frame.to_base64()
);
let max_distance_threshold = {
let settings = self.settings.lock().unwrap();
settings.max_distance_threshold
};
let segment = element.segment().downcast::<gst::ClockTime>().ok();
let running_time = segment
.as_ref()
.and_then(|s| s.to_running_time(buf.pts().unwrap()));
let distance = state.reference_image.dist(&frame);
if distance <= max_distance_threshold {
Ok(Some(ImageDetected::new(
distance,
state.reference_image.to_base64(),
running_time,
buf.pts(),
)))
} else {
gst::debug!(
CAT,
obj: element,
"Compared image {} distance is {} but below threshold {}",
state.reference_image.to_base64(),
distance,
max_distance_threshold,
);
Ok(None)
}
}
fn hash_frame(
&self,
element: &super::ImageCompare,
buf: &mut BufferRef,
hasher: &Hasher,
video_info: &gst_video::VideoInfo,
out_color_type: ColorType,
) -> Result<ImageHash, gst::FlowError> {
let frame = gst_video::VideoFrameRef::from_buffer_ref_readable(buf, video_info).unwrap();
let height = frame.height();
let width = frame.width();
let buf = tightly_packed_framebuffer(&frame);
match out_color_type {
ColorType::Rgb8 => {
let frame_buf = image::RgbImage::from_raw(width, height, buf).ok_or_else(|| {
gst::error!(CAT, obj: element, "Could not load gst buffer as RGB image");
gst::FlowError::Error
})?;
Ok(hasher.hash_image(&frame_buf))
}
ColorType::Rgba8 => {
let frame_buf =
image::RgbaImage::from_raw(width, height, buf).ok_or_else(|| {
gst::error!(CAT, obj: element, "Could not load gst buffer as RGBA image");
gst::FlowError::Error
})?;
Ok(hasher.hash_image(&frame_buf))
}
non_supported_color => {
gst::error!(
CAT,
obj: element,
"Color type {:?} is not supported",
non_supported_color
);
Err(gst::FlowError::NotSupported)
}
}
}
fn publish_image_detection(&self, element: &super::ImageCompare, info: &ImageDetected) {
gst::info!(CAT, obj: element, "Image detected {:?}", info);
element
.post_message(
gst::message::Element::builder(
gst::structure::Structure::builder("image-detected")
.field("distance", info.distance)
.field("image-base64-hash", info.image_base64_hash.clone())
.field("running-time", info.running_time)
.field("pts", info.pts)
.build(),
)
.src(element)
.build(),
)
.expect("Failed to publish message to bus")
}
}
#[glib::object_subclass]
impl ObjectSubclass for ImageCompare {
    // GType name under which the subclass is registered.
    const NAME: &'static str = "GstImageCompare";
    type Type = super::ImageCompare;
    // In-place analysis filter, so BaseTransform is the parent class.
    type ParentType = gst_base::BaseTransform;
}
impl ObjectImpl for ImageCompare {
    /// Declare the element's GObject properties.
    fn properties() -> &'static [glib::ParamSpec] {
        static PROPERTIES: Lazy<Vec<glib::ParamSpec>> = Lazy::new(|| {
            vec![
                // Path of the reference image; changeable only in READY
                // or below, because set_caps() loads and hashes it.
                glib::ParamSpecString::new(
                    "location",
                    "Image Location",
                    "Filesystem location of image to be compared to video frames",
                    None,
                    glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
                ),
                // Perceptual hash algorithm; also MUTABLE_READY since the
                // reference image is hashed once at caps negotiation.
                glib::ParamSpecEnum::new(
                    "hash-algo",
                    "Hashing Algorithm",
                    "Which hashing algorithm to use for image comparisons",
                    HashAlgorithm::static_type(),
                    DEFAULT_HASH_ALGO as i32,
                    glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
                ),
                // Threshold is read per-buffer, so it may change while PLAYING.
                glib::ParamSpecUInt::new(
                    "max-dist-threshold",
                    "Maximum Distance Threshold",
                    "Maximum distance threshold to emit messages when an image is detected, by default emits only on exact match",
                    0,
                    u32::MAX,
                    DEFAULT_MAX_DISTANCE_THRESHOLD,
                    glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_PLAYING,
                ),
            ]
        });
        PROPERTIES.as_ref()
    }

    /// Store a property value; logs every effective change at INFO level.
    fn set_property(
        &self,
        obj: &Self::Type,
        _id: usize,
        value: &glib::Value,
        pspec: &glib::ParamSpec,
    ) {
        match pspec.name() {
            "location" => {
                let mut settings = self.settings.lock().unwrap();
                let image_location = value.get().expect("type checked upstream");
                if settings.location != image_location {
                    gst::info!(
                        CAT,
                        obj: obj,
                        "Changing location from {:?} to {:?}",
                        settings.location,
                        image_location
                    );
                    settings.location = image_location;
                }
            }
            "hash-algo" => {
                let mut settings = self.settings.lock().unwrap();
                let hash_algo = value.get().expect("type checked upstream");
                if settings.hash_algo != hash_algo {
                    gst::info!(
                        CAT,
                        obj: obj,
                        "Changing hash-algo from {:?} to {:?}",
                        settings.hash_algo,
                        hash_algo
                    );
                    settings.hash_algo = hash_algo;
                }
            }
            "max-dist-threshold" => {
                let mut settings = self.settings.lock().unwrap();
                let max_distance_threshold = value.get().expect("type checked upstream");
                if settings.max_distance_threshold != max_distance_threshold {
                    gst::info!(
                        CAT,
                        obj: obj,
                        "Changing max-dist-threshold from {} to {}",
                        settings.max_distance_threshold,
                        max_distance_threshold
                    );
                    settings.max_distance_threshold = max_distance_threshold;
                }
            }
            // Reaching this arm means a property was declared above but
            // not handled here — a programming error.
            _ => unimplemented!(),
        }
    }

    /// Read back a property value from the settings.
    fn property(&self, _obj: &Self::Type, _id: usize, pspec: &glib::ParamSpec) -> glib::Value {
        match pspec.name() {
            "location" => {
                let settings = self.settings.lock().unwrap();
                settings.location.to_value()
            }
            "hash-algo" => {
                let settings = self.settings.lock().unwrap();
                settings.hash_algo.to_value()
            }
            "max-dist-threshold" => {
                let settings = self.settings.lock().unwrap();
                settings.max_distance_threshold.to_value()
            }
            _ => unimplemented!(),
        }
    }
}
// Nothing to customize at the GstObject level; defaults are sufficient.
impl GstObjectImpl for ImageCompare {}
impl ElementImpl for ImageCompare {
    /// Static element metadata shown in `gst-inspect-1.0`.
    fn metadata() -> Option<&'static gst::subclass::ElementMetadata> {
        static ELEMENT_METADATA: Lazy<gst::subclass::ElementMetadata> = Lazy::new(|| {
            gst::subclass::ElementMetadata::new(
                "Image comparison",
                "Filter/Video",
                "Detects images in a video",
                "Rafael Caricio <rafael@caricio.com>",
            )
        });
        Some(&*ELEMENT_METADATA)
    }

    /// Always-present sink and src pads sharing identical raw-video caps.
    fn pad_templates() -> &'static [gst::PadTemplate] {
        static PAD_TEMPLATES: Lazy<Vec<gst::PadTemplate>> = Lazy::new(|| {
            // Only tightly packed RGB/RGBA frames are handled by the
            // hashing code, so advertise exactly those formats.
            let supported_formats =
                gst::List::new([VideoFormat::Rgb.to_str(), VideoFormat::Rgba.to_str()]);
            let caps = gst::Caps::builder("video/x-raw")
                .field("format", &supported_formats)
                .field("width", gst::IntRange::new(1, i32::MAX))
                .field("height", gst::IntRange::new(1, i32::MAX))
                .field(
                    "framerate",
                    gst::FractionRange::new(
                        gst::Fraction::new(0, 1),
                        gst::Fraction::new(i32::MAX, 1),
                    ),
                )
                .build();

            // The element never modifies data, so both pads use the same caps.
            let template = |name: &str, direction| {
                gst::PadTemplate::new(name, direction, gst::PadPresence::Always, &caps)
                    .unwrap()
            };

            vec![
                template("sink", gst::PadDirection::Sink),
                template("src", gst::PadDirection::Src),
            ]
        });
        PAD_TEMPLATES.as_ref()
    }
}
impl BaseTransformImpl for ImageCompare {
    // The element only inspects frames; buffers pass through unmodified.
    const MODE: gst_base::subclass::BaseTransformMode =
        gst_base::subclass::BaseTransformMode::AlwaysInPlace;
    // Passthrough would skip transform_ip(), so it must stay disabled.
    const PASSTHROUGH_ON_SAME_CAPS: bool = false;
    const TRANSFORM_IP_ON_PASSTHROUGH: bool = false;

    /// Validate the negotiated caps and (re)build the streaming state,
    /// loading and hashing the reference image.
    fn set_caps(
        &self,
        element: &Self::Type,
        incaps: &gst::Caps,
        outcaps: &gst::Caps,
    ) -> Result<(), gst::LoggableError> {
        // `?` + map_err replaces the original match-on-Result boilerplate.
        let in_info = gst_video::VideoInfo::from_caps(incaps)
            .map_err(|_| gst::loggable_error!(CAT, "Failed to parse input caps"))?;
        let out_info = gst_video::VideoInfo::from_caps(outcaps)
            .map_err(|_| gst::loggable_error!(CAT, "Failed to parse output caps"))?;
        gst::debug!(
            CAT,
            obj: element,
            "Configured for caps {} to {}",
            incaps,
            outcaps
        );

        // Translate the GStreamer pixel format into an `image` color type.
        let color_format = match in_info.format() {
            VideoFormat::Rgb => ColorType::Rgb8,
            VideoFormat::Rgba => ColorType::Rgba8,
            video_fmt => {
                return Err(gst::loggable_error!(
                    CAT,
                    "Not supported video format: {}",
                    video_fmt
                ))
            }
        };

        // Copy what we need out of the settings so the lock is not held
        // while State::new() reads the reference image from disk.
        let (hash_algo, image_location) = {
            let settings = self.settings.lock().unwrap();
            let location = settings.location.clone().ok_or_else(|| {
                gst::loggable_error!(CAT, "Location of the reference image to compare not defined, set the property \"location\" to a valid file")
            })?;
            (settings.hash_algo, location)
        };

        let new_state = State::new(hash_algo, &image_location, color_format, out_info)?;
        gst::debug!(
            CAT,
            obj: element,
            "Loaded reference image hash: {:?}",
            new_state.reference_image.to_base64()
        );
        *self.state.lock().unwrap() = Some(new_state);
        Ok(())
    }

    /// Per-buffer hook: compare the frame and post a message on detection.
    /// The buffer itself is never modified.
    fn transform_ip(
        &self,
        element: &Self::Type,
        buf: &mut gst::BufferRef,
    ) -> Result<gst::FlowSuccess, gst::FlowError> {
        if let Some(detection) = self.compare_frame(element, buf)? {
            self.publish_image_detection(element, &detection)
        }
        Ok(gst::FlowSuccess::Ok)
    }

    /// Advertise VideoMeta support so upstream can keep stride metadata.
    fn propose_allocation(
        &self,
        element: &Self::Type,
        decide_query: Option<&gst::query::Allocation>,
        query: &mut gst::query::Allocation,
    ) -> Result<(), gst::LoggableError> {
        query.add_allocation_meta::<gst_video::VideoMeta>(None);
        self.parent_propose_allocation(element, decide_query, query)
    }
}
/// Copy a GStreamer video frame into a tightly packed RGB(A) buffer,
/// stripping the per-line stride padding, ready for the `image` crate.
fn tightly_packed_framebuffer(frame: &gst_video::VideoFrameRef<&gst::BufferRef>) -> Vec<u8> {
    // RGB and RGBA are single-plane, tightly packed formats.
    assert_eq!(frame.n_planes(), 1);

    let bytes_per_line = (frame.width() * frame.n_components()) as usize;
    let stride = frame.plane_stride()[0] as usize;
    // NOTE(review): assumes the plane data length is an exact multiple of
    // the stride; a short final line would be dropped by chunks_exact —
    // confirm against GStreamer's raw-video buffer layout.
    let plane = frame.plane_data(0).unwrap();

    let mut packed = Vec::with_capacity(bytes_per_line * frame.info().height() as usize);
    for padded_line in plane.chunks_exact(stride) {
        // Keep only the visible pixels of each line, drop the padding.
        packed.extend_from_slice(&padded_line[..bytes_per_line]);
    }
    packed
}
// Copyright (C) 2022 Rafael Caricio <rafael@caricio.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at
// <https://mozilla.org/MPL/2.0/>.
//
// SPDX-License-Identifier: MPL-2.0
use gst::prelude::*;
use gst::{glib, ClockTime};
use image_hasher::HashAlg;
mod imp;
/// Perceptual hash algorithms exposed through the `hash-algo` property.
///
/// Mirrors `image_hasher::HashAlg`; the mapping lives in
/// `From<HashAlgorithm> for HashAlg`.
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)]
#[repr(u32)]
#[enum_type(name = "GstImageCompareHashAlgorithm")]
#[non_exhaustive]
pub enum HashAlgorithm {
    #[enum_value(name = "Mean: The Mean hashing algorithm.", nick = "mean")]
    Mean = 0,
    #[enum_value(name = "Gradient: The Gradient hashing algorithm.", nick = "gradient")]
    Gradient = 1,
    #[enum_value(
        name = "VertGradient: The Vertical-Gradient hashing algorithm.",
        nick = "vertgradient"
    )]
    VertGradient = 2,
    #[enum_value(
        name = "DoubleGradient: The Double-Gradient hashing algorithm.",
        nick = "doublegradient"
    )]
    DoubleGradient = 3,
    #[enum_value(
        name = "Blockhash: The [Blockhash](https://github.com/commonsmachinery/blockhash-rfc) algorithm.",
        nick = "blockhash"
    )]
    Blockhash = 4,
}
impl From<HashAlgorithm> for HashAlg {
fn from(ha: HashAlgorithm) -> Self {
use HashAlgorithm::*;
match ha {
Mean => Self::Mean,
Gradient => Self::Gradient,
VertGradient => Self::VertGradient,
DoubleGradient => Self::DoubleGradient,
Blockhash => Self::Blockhash,
}
}
}
/// Detection data carried in the `image-detected` element message.
#[derive(Debug, Clone, PartialEq)]
pub struct ImageDetected {
    // Hamming distance between the frame hash and the reference hash.
    distance: u32,
    // Base64 serialization of the reference image hash.
    image_base64_hash: String,
    // Running time of the matching buffer, if it had a PTS.
    running_time: Option<gst::ClockTime>,
    // Presentation timestamp of the matching buffer.
    pts: Option<gst::ClockTime>,
}
impl ImageDetected {
    /// Create a new detection record.
    ///
    /// `image_base64_hash` is the base64 form of the *reference* hash the
    /// frame matched; `running_time` and `pts` locate the matching buffer.
    pub fn new(
        distance: u32,
        image_base64_hash: String,
        running_time: Option<gst::ClockTime>,
        // Consistency fix: spelled `Option<gst::ClockTime>` like the
        // sibling parameter instead of the bare `ClockTime` import.
        pts: Option<gst::ClockTime>,
    ) -> Self {
        Self {
            distance,
            image_base64_hash,
            running_time,
            pts,
        }
    }
}
// Public wrapper type for the element, exposing the
// BaseTransform -> Element -> Object class hierarchy.
glib::wrapper! {
    pub struct ImageCompare(ObjectSubclass<imp::ImageCompare>) @extends gst_base::BaseTransform, gst::Element, gst::Object;
}
/// Register the `imgcmp` element with the given plugin.
pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
    gst::Element::register(
        Some(plugin),
        "imgcmp",
        // Rank::None: never auto-plugged, must be requested explicitly.
        gst::Rank::None,
        ImageCompare::static_type(),
    )
}
......@@ -10,10 +10,12 @@
mod border;
mod colordetect;
mod imgcmp;
/// Register all elements provided by this plugin.
///
/// Note: the diff rendering left both the old `colordetect::register(plugin)`
/// line and its replacement in the visible text, which is invalid Rust —
/// only the `?`-propagating form is kept.
fn plugin_init(plugin: &gst::Plugin) -> Result<(), gst::glib::BoolError> {
    border::register(plugin)?;
    colordetect::register(plugin)?;
    imgcmp::register(plugin)
}
gst::plugin_define!(
......
Supports Markdown