Commit cf489ca9 authored by Zeeshan Ali's avatar Zeeshan Ali
Browse files

Merge branch 'preliminary' into 'main'

Preliminary patches for signal broadcast

See merge request !295
parents ca5da70c e796a7f0
Pipeline #315862 passed with stages
in 4 minutes and 51 seconds
......@@ -43,6 +43,8 @@ zbus_polkit = { path = "../zbus_polkit", version = "1" }
doc-comment = "0.3.3"
futures-util = "0.3.8" # activate default features
ntest = "0.7.1"
env_logger = "*"
test-env-log = "0.2.6"
[package.metadata.docs.rs]
all-features = true
......
......@@ -120,6 +120,7 @@ mod tests {
use super::Address;
use crate::Error;
use std::str::FromStr;
use test_env_log::test;
#[test]
fn parse_dbus_addresses() {
......
......@@ -410,7 +410,7 @@ impl Connection {
/// not. All messages received during this call that are not returned by it, are pushed to the
/// queue to be picked by the susubsequent or awaiting call to this method or by the
/// `MessageStream`.
pub async fn receive_specific<P>(&self, predicate: P) -> Result<Message>
pub async fn receive_specific<P>(&self, predicate: P) -> Result<Arc<Message>>
where
for<'msg> P: Fn(&'msg Message) -> BoxFuture<'msg, Result<bool>>,
{
......@@ -427,7 +427,7 @@ impl Connection {
} {
// SAFETY: we got the index from the queue enumerator so this shouldn't ever
// fail.
return queue.remove(i).expect("removing queue item");
return queue.remove(i).expect("removing queue item").map(Arc::new);
}
}
}
......@@ -470,7 +470,7 @@ impl Connection {
}
}
msg
msg.map(Arc::new)
}
/// Send `msg` to the peer.
......@@ -500,7 +500,7 @@ impl Connection {
interface: Option<&str>,
method_name: &str,
body: &B,
) -> Result<Message>
) -> Result<Arc<Message>>
where
B: serde::ser::Serialize + zvariant::Type,
MessageError: From<E>,
......@@ -606,10 +606,9 @@ impl Connection {
///
/// This method can fail if `msg` is corrupt.
pub fn assign_serial_num(&self, msg: &mut Message) -> Result<u32> {
let serial = self.next_serial();
let mut serial = 0;
msg.modify_primary_header(|primary| {
primary.set_serial_num(serial);
serial = *primary.serial_num_or_init(|| self.next_serial());
Ok(())
})?;
......@@ -884,11 +883,11 @@ impl futures_sink::Sink<Message> for MessageSink<'_> {
/// the `receive_specific` is waiting for and end up in a deadlock situation. It is therefore highly
/// recommended not to use such a combination.
pub struct MessageStream {
stream: stream::BoxStream<'static, Result<Message>>,
stream: stream::BoxStream<'static, Result<Arc<Message>>>,
}
impl stream::Stream for MessageStream {
type Item = Result<Message>;
type Item = Result<Arc<Message>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
stream::Stream::poll_next(self.get_mut().stream.as_mut(), cx)
......@@ -933,6 +932,7 @@ impl From<crate::Connection> for Connection {
mod tests {
use futures_util::stream::TryStreamExt;
use std::os::unix::net::UnixStream;
use test_env_log::test;
use super::*;
......@@ -952,7 +952,7 @@ mod tests {
let (client_conn, server_conn) = futures_util::try_join!(client, server)?;
let server_future = async {
let mut method: Option<Message> = None;
let mut method: Option<Arc<Message>> = None;
while let Some(m) = server_conn.stream().await.try_next().await? {
if m.to_string() == "Method call Test" {
method.replace(m);
......
......@@ -144,6 +144,7 @@ where
mod tests {
use nix::unistd::Uid;
use std::os::unix::net::UnixStream;
use test_env_log::test;
use super::*;
......
......@@ -236,7 +236,7 @@ impl<'a> Proxy<'a> {
/// allocation/copying, by deserializing the reply to an unowned type).
///
/// [`call`]: struct.Proxy.html#method.call
pub async fn call_method<B>(&self, method_name: &str, body: &B) -> Result<Message>
pub async fn call_method<B>(&self, method_name: &str, body: &B) -> Result<Arc<Message>>
where
B: serde::ser::Serialize + zvariant::Type,
{
......@@ -262,7 +262,7 @@ impl<'a> Proxy<'a> {
B: serde::ser::Serialize + zvariant::Type,
R: serde::de::DeserializeOwned + zvariant::Type,
{
let mut reply = self.call_method(method_name, body).await?;
let reply = self.call_method(method_name, body).await?;
// Since we don't keep the reply msg around and user still might use the FDs after this
// call returns, we must disown the FDs so we don't end up closing them after the call.
reply.disown_fds();
......@@ -422,7 +422,7 @@ impl<'a> Proxy<'a> {
/// # Errors
///
/// This method returns the same errors as [`Self::receive_signal`].
pub async fn next_signal(&self) -> Result<Option<Message>> {
pub async fn next_signal(&self) -> Result<Option<Arc<Message>>> {
let msg = {
// We want to keep a lock on the handlers during `receive_specific` call but we also
// want to avoid using `handlers` directly as that somehow makes this call (or rather
......@@ -550,7 +550,7 @@ impl<'a> Proxy<'a> {
#[derivative(Debug)]
pub struct SignalStream<'s> {
#[derivative(Debug = "ignore")]
stream: stream::BoxStream<'s, Message>,
stream: stream::BoxStream<'s, Arc<Message>>,
conn: Connection,
subscription_id: Option<u64>,
}
......@@ -574,7 +574,7 @@ impl SignalStream<'_> {
}
impl stream::Stream for SignalStream<'_> {
type Item = Message;
type Item = Arc<Message>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
stream::Stream::poll_next(self.get_mut().stream.as_mut(), cx)
......@@ -601,6 +601,7 @@ mod tests {
use super::*;
use futures_util::future::FutureExt;
use std::{future::ready, sync::Arc};
use test_env_log::test;
#[test]
fn signal_stream() {
......@@ -669,13 +670,7 @@ mod tests {
let name_acquired_signaled = Arc::new(Mutex::new(false));
let name_acquired_signaled2 = Arc::new(Mutex::new(false));
let proxy = Proxy::new(
&conn,
"org.freedesktop.DBus",
"/org/freedesktop/DBus",
"org.freedesktop.DBus",
)?;
let proxy = fdo::AsyncDBusProxy::new(&conn);
let well_known = "org.freedesktop.zbus.async.ProxySignalConnectTest";
let unique_name = conn.unique_name().unwrap().to_string();
let name_owner_changed_id = {
......
......@@ -6,6 +6,7 @@ use std::{
io::{AsRawFd, RawFd},
net::UnixStream,
},
sync::Arc,
};
use zvariant::ObjectPath;
......@@ -53,11 +54,13 @@ use crate::{azync, Guid, Message, MessageError, Result};
/// [`receive_message`]: struct.Connection.html#method.receive_message
/// [`set_max_queued`]: struct.Connection.html#method.set_max_queued
#[derive(Debug, Clone)]
pub struct Connection(azync::Connection);
pub struct Connection {
inner: azync::Connection,
}
impl AsRawFd for Connection {
fn as_raw_fd(&self) -> RawFd {
block_on(self.0.as_raw_fd())
block_on(self.inner.as_raw_fd())
}
}
......@@ -70,24 +73,24 @@ impl Connection {
/// Upon successful return, the connection is fully established and negotiated: D-Bus messages
/// can be sent and received.
pub fn new_unix_client(stream: UnixStream, bus_connection: bool) -> Result<Self> {
block_on(azync::Connection::new_unix_client(stream, bus_connection)).map(Self)
block_on(azync::Connection::new_unix_client(stream, bus_connection)).map(Self::from)
}
/// Create a `Connection` to the session/user message bus.
pub fn new_session() -> Result<Self> {
block_on(azync::Connection::new_session()).map(Self)
block_on(azync::Connection::new_session()).map(Self::from)
}
/// Create a `Connection` to the system-wide message bus.
pub fn new_system() -> Result<Self> {
block_on(azync::Connection::new_system()).map(Self)
block_on(azync::Connection::new_system()).map(Self::from)
}
/// Create a `Connection` for the given [D-Bus address].
///
/// [D-Bus address]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses
pub fn new_for_address(address: &str, bus_connection: bool) -> Result<Self> {
block_on(azync::Connection::new_for_address(address, bus_connection)).map(Self)
block_on(azync::Connection::new_for_address(address, bus_connection)).map(Self::from)
}
/// Create a server `Connection` for the given `UnixStream` and the server `guid`.
......@@ -98,12 +101,12 @@ impl Connection {
/// Upon successful return, the connection is fully established and negotiated: D-Bus messages
/// can be sent and received.
pub fn new_unix_server(stream: UnixStream, guid: &Guid) -> Result<Self> {
block_on(azync::Connection::new_unix_server(stream, guid)).map(Self)
block_on(azync::Connection::new_unix_server(stream, guid)).map(Self::from)
}
/// Max number of messages to queue.
pub fn max_queued(&self) -> usize {
self.0.max_queued()
self.inner.max_queued()
}
/// Set the max number of messages to queue.
......@@ -124,17 +127,17 @@ impl Connection {
///# Ok::<_, Box<dyn Error + Send + Sync>>(())
/// ```
pub fn set_max_queued(self, max: usize) -> Self {
Self(self.0.set_max_queued(max))
Self::from(self.inner.set_max_queued(max))
}
/// The server's GUID.
pub fn server_guid(&self) -> &str {
self.0.server_guid()
self.inner.server_guid()
}
/// The unique name as assigned by the message bus or `None` if not a message bus connection.
pub fn unique_name(&self) -> Option<&str> {
self.0.unique_name()
self.inner.unique_name()
}
/// Fetch the next message from the connection.
......@@ -150,8 +153,8 @@ impl Connection {
/// with situation where this method takes away the message the other API is awaiting for and
/// end up in a deadlock situation. It is therefore highly recommended not to use such a
/// combination.
pub fn receive_message(&self) -> Result<Message> {
block_on(self.0.receive_specific(|_| ready(Ok(true)).boxed()))
pub fn receive_message(&self) -> Result<Arc<Message>> {
self.receive_specific(|_| Ok(true))
}
/// Receive a specific message.
......@@ -160,11 +163,14 @@ impl Connection {
/// decides if the message received should be returned by this method or not. Message received
/// during this call that are not returned by it, are pushed to the queue to be picked by the
/// susubsequent call to `receive_message`] or this method.
pub fn receive_specific<P>(&self, predicate: P) -> Result<Message>
pub fn receive_specific<P>(&self, predicate: P) -> Result<Arc<Message>>
where
P: Fn(&Message) -> Result<bool>,
{
block_on(self.0.receive_specific(|msg| ready(predicate(msg)).boxed()))
block_on(
self.inner
.receive_specific(|msg| ready(predicate(msg)).boxed()),
)
}
/// Send `msg` to the peer.
......@@ -179,7 +185,7 @@ impl Connection {
///
/// [`flush`]: struct.Connection.html#method.flush
pub fn send_message(&self, msg: Message) -> Result<u32> {
block_on(self.0.send_message(msg))
block_on(self.inner.send_message(msg))
}
/// Send a method call.
......@@ -204,13 +210,13 @@ impl Connection {
iface: Option<&str>,
method_name: &str,
body: &B,
) -> Result<Message>
) -> Result<Arc<Message>>
where
B: serde::ser::Serialize + zvariant::Type,
MessageError: From<E>,
{
block_on(
self.0
self.inner
.call_method(destination, path, iface, method_name, body),
)
}
......@@ -231,7 +237,7 @@ impl Connection {
MessageError: From<E>,
{
block_on(
self.0
self.inner
.emit_signal(destination, path, iface, signal_name, body),
)
}
......@@ -246,7 +252,7 @@ impl Connection {
where
B: serde::ser::Serialize + zvariant::Type,
{
block_on(self.0.reply(call, body))
block_on(self.inner.reply(call, body))
}
/// Reply an error to a message.
......@@ -259,36 +265,37 @@ impl Connection {
where
B: serde::ser::Serialize + zvariant::Type,
{
block_on(self.0.reply_error(call, error_name, body))
block_on(self.inner.reply_error(call, error_name, body))
}
/// Checks if `self` is a connection to a message bus.
///
/// This will return `false` for p2p connections.
pub fn is_bus(&self) -> bool {
self.0.is_bus()
self.inner.is_bus()
}
/// Get a reference to the underlying async Connection.
pub fn inner(&self) -> &azync::Connection {
&self.0
&self.inner
}
/// Get the underlying async Connection, consuming `self`.
pub fn into_inner(self) -> azync::Connection {
self.0
self.inner
}
}
impl From<azync::Connection> for Connection {
fn from(conn: azync::Connection) -> Self {
Self(conn)
Self { inner: conn }
}
}
#[cfg(test)]
mod tests {
use std::{os::unix::net::UnixStream, thread};
use test_env_log::test;
use crate::{Connection, Error, Guid};
#[test]
......
use std::{convert::Infallible, error, fmt, io};
use std::{convert::Infallible, error, fmt, io, sync::Arc};
use zvariant::Error as VariantError;
use crate::{fdo, Message, MessageError, MessageType};
......@@ -26,7 +26,7 @@ pub enum Error {
/// A D-Bus method error reply.
// According to the spec, there can be all kinds of details in D-Bus errors but nobody adds anything more than a
// string description.
MethodError(String, Option<String>, Message),
MethodError(String, Option<String>, Arc<Message>),
/// Invalid D-Bus GUID.
InvalidGUID,
/// Unsupported function, or support currently lacking.
......@@ -148,6 +148,12 @@ impl From<Infallible> for Error {
// For messages that are D-Bus error returns
impl From<Message> for Error {
fn from(message: Message) -> Error {
Self::from(Arc::new(message))
}
}
impl From<Arc<Message>> for Error {
fn from(message: Arc<Message>) -> Error {
// FIXME: Instead of checking this, we should have Method as trait and specific types for
// each message type.
let header = match message.header() {
......
......@@ -604,6 +604,7 @@ mod tests {
convert::TryInto,
sync::{Arc, Mutex},
};
use test_env_log::test;
#[test]
fn error_from_zerror() {
......
......@@ -72,6 +72,7 @@ impl FromStr for Guid {
#[cfg(test)]
mod tests {
use crate::Guid;
use test_env_log::test;
#[test]
fn generate() {
......
......@@ -808,6 +808,7 @@ impl FromStr for Command {
#[cfg(test)]
mod tests {
use std::os::unix::net::UnixStream;
use test_env_log::test;
use super::*;
......
......@@ -214,6 +214,7 @@ mod tests {
use enumflags2::BitFlags;
use ntest::timeout;
use test_env_log::test;
use zvariant::{Fd, OwnedObjectPath, OwnedValue, Type};
......@@ -236,13 +237,13 @@ mod tests {
.unwrap();
m.modify_primary_header(|primary| {
primary.set_flags(BitFlags::from(MessageFlags::NoAutoStart));
primary.set_serial_num(11);
primary.serial_num_or_init(|| 11);
Ok(())
})
.unwrap();
let primary = m.primary_header().unwrap();
assert!(primary.serial_num() == 11);
let primary = m.primary_header();
assert!(*primary.serial_num().unwrap() == 11);
assert!(primary.flags() == MessageFlags::NoAutoStart);
}
......@@ -300,7 +301,7 @@ mod tests {
fn fdpass_systemd() {
let connection = crate::Connection::new_system().unwrap();
let mut reply = connection
let reply = connection
.call_method(
Some("org.freedesktop.systemd1"),
"/org/freedesktop/systemd1",
......@@ -553,7 +554,7 @@ mod tests {
loop {
let msg = conn.receive_message().unwrap();
if msg.primary_header().unwrap().serial_num() == serial {
if *msg.primary_header().serial_num().unwrap() == serial {
break;
}
}
......
......@@ -3,6 +3,7 @@ use std::{
error, fmt,
io::{Cursor, Error as IOError},
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
sync::{Arc, RwLock},
};
use zvariant::{EncodingContext, Error as VariantError, ObjectPath, Signature, Type};
......@@ -14,6 +15,8 @@ use crate::{
};
const FIELDS_LEN_START_OFFSET: usize = 12;
const LOCK_PANIC_MSG: &str = "lock poisoned";
macro_rules! dbus_context {
($n_bytes_before: expr) => {
EncodingContext::<byteorder::NativeEndian>::new_dbus($n_bytes_before)
......@@ -165,8 +168,11 @@ where
} = self;
if let Some(reply_to) = reply_to.as_ref() {
let serial = reply_to.primary().serial_num();
fields.add(MessageField::ReplySerial(serial));
let serial = reply_to
.primary()
.serial_num()
.ok_or(MessageError::MissingField)?;
fields.add(MessageField::ReplySerial(*serial));
if let Some(sender) = reply_to.sender()? {
fields.add(MessageField::Destination(sender.into()));
......@@ -186,8 +192,9 @@ where
let (_, fds) = zvariant::to_writer_fds(&mut cursor, ctxt, body)?;
Ok(Message {
primary_header: header.into_primary(),
bytes,
fds: Fds::Raw(fds),
fds: Arc::new(RwLock::new(Fds::Raw(fds))),
})
}
......@@ -279,8 +286,9 @@ impl Clone for Fds {
/// [`Connection`]: struct.Connection#method.call_method
#[derive(Clone)]
pub struct Message {
primary_header: MessagePrimaryHeader,
bytes: Vec<u8>,
fds: Fds,
fds: Arc<RwLock<Fds>>,
}
// TODO: Handle non-native byte order: https://gitlab.freedesktop.org/dbus/zbus/-/issues/19
......@@ -370,9 +378,15 @@ impl Message {
return Err(MessageError::IncorrectEndian);
}
let primary_header =
zvariant::from_slice(bytes, dbus_context!(0)).map_err(MessageError::from)?;
let bytes = bytes.to_vec();
let fds = Fds::Raw(vec![]);
Ok(Self { bytes, fds })
let fds = Arc::new(RwLock::new(Fds::Raw(vec![])));
Ok(Self {
primary_header,
bytes,
fds,
})
}
pub(crate) fn add_bytes(&mut self, bytes: &[u8]) -> Result<(), MessageError> {
......@@ -385,8 +399,8 @@ impl Message {
Ok(())
}
pub(crate) fn set_owned_fds(&mut self, fds: Vec<OwnedFd>) {
self.fds = Fds::Owned(fds);
pub(crate) fn set_owned_fds(&self, fds: Vec<OwnedFd>) {
*self.fds.write().expect(LOCK_PANIC_MSG) = Fds::Owned(fds);
}
/// Disown the associated file descriptors.
......@@ -395,17 +409,18 @@ impl Message {
/// contain associated FDs. To prevent the message from closing
/// those FDs on drop, you may remove the ownership thanks to this
/// method, after that you are responsible for closing them.
pub fn disown_fds(&mut self) {
if let Fds::Owned(ref mut fds) = &mut self.fds {
pub fn disown_fds(&self) {
let mut fds_lock = self.fds.write().expect(LOCK_PANIC_MSG);
if let Fds::Owned(ref mut fds) = *fds_lock {
// From now on, it's the caller responsability to close the fds
self.fds = Fds::Raw(fds.drain(..).map(|fd| fd.into_raw_fd()).collect());
*fds_lock = Fds::Raw(fds.drain(..).map(|fd| fd.into_raw_fd()).collect());
}
}
pub(crate) fn bytes_to_completion(&self) -> Result<usize, MessageError> {
let header_len = MIN_MESSAGE_SIZE + self.fields_len()?;
let body_padding = padding_for_8_bytes(header_len);
let body_len = self.primary_header()?.body_len();
let body_len = self.primary_header().body_len();
let required = header_len + body_padding + body_len as usize;
Ok(required - self.bytes.len())
......@@ -429,20 +444,18 @@ impl Message {
}
}
/// Deserialize the primary header.
pub fn primary_header(&self) -> Result<MessagePrimaryHeader, MessageError> {
zvariant::from_slice(&self.bytes, dbus_context!(0)).map_err(MessageError::from)
pub fn primary_header(&self) -> &MessagePrimaryHeader {
&self.primary_header
}
pub(crate) fn modify_primary_header<F>(&mut self, mut modifier: F) -> Result<(), MessageError>
where
F: FnMut(&mut MessagePrimaryHeader) -> Result<(), MessageError>,
{
let mut primary = self.primary_header()?;
modifier(&mut primary)?;
modifier(&mut self.primary_header)?;
let mut cursor = Cursor::new(&mut self.bytes);
zvariant::to_writer(&mut cursor, dbus_context!(0), &primary)
zvariant::to_writer(&mut cursor, dbus_context!(0), &self.primary_header)
.map(|_| ())
.map_err(MessageError::from)
}
......@@ -505,7 +518,7 @@ impl Message {
}
pub(crate) fn fds(&self) -> Vec<RawFd> {
match &self.fds {
match &*self.fds.read().expect(LOCK_PANIC_MSG) {
Fds::Raw(fds) => fds.clone(),
Fds::Owned(fds) => fds.iter().map(|f| f.as_raw_fd()).collect(),
}
......@@ -549,8 +562,9 @@ impl fmt::Debug for Message {
if let Ok(s) = self.body_signature() {
msg.field("body", &s);
}
if !self.fds().is_empty() {
msg.field("fds", &self.fds);
let fds = self.fds();
if !fds.is_empty() {
msg.field("fds", &fds);
}
msg.finish()
}
......@@ -614,6 +628,7 @@ impl fmt::Display for Message {
mod tests {