Commit bc3bde99 authored by Zeeshan Ali's avatar Zeeshan Ali Committed by Marc-André Lureau
Browse files

zb: Interior mutability for FDs in a Message



Users shouldn't need to have mutable ref to incoming messages and
disowning FDs was the only real reason a user would need it so let's
just use interior mutability for FDs so users rarely need mutable ref to
a message.

This change becomes more important later on when we switch to
`Arc<Message>` for incoming messages.
Signed-off-by: default avatarMarc-André Lureau <marcandre.lureau@redhat.com>
parent 3a412a17
......@@ -262,7 +262,7 @@ impl<'a> Proxy<'a> {
B: serde::ser::Serialize + zvariant::Type,
R: serde::de::DeserializeOwned + zvariant::Type,
{
let mut reply = self.call_method(method_name, body).await?;
let reply = self.call_method(method_name, body).await?;
// Since we don't keep the reply msg around and user still might use the FDs after this
// call returns, we must disown the FDs so we don't end up closing them after the call.
reply.disown_fds();
......
......@@ -301,7 +301,7 @@ mod tests {
fn fdpass_systemd() {
let connection = crate::Connection::new_system().unwrap();
let mut reply = connection
let reply = connection
.call_method(
Some("org.freedesktop.systemd1"),
"/org/freedesktop/systemd1",
......
......@@ -3,6 +3,7 @@ use std::{
error, fmt,
io::{Cursor, Error as IOError},
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
sync::{Arc, RwLock},
};
use zvariant::{EncodingContext, Error as VariantError, ObjectPath, Signature, Type};
......@@ -14,6 +15,8 @@ use crate::{
};
const FIELDS_LEN_START_OFFSET: usize = 12;
const LOCK_PANIC_MSG: &str = "lock poisoned";
macro_rules! dbus_context {
($n_bytes_before: expr) => {
EncodingContext::<byteorder::NativeEndian>::new_dbus($n_bytes_before)
......@@ -187,7 +190,7 @@ where
Ok(Message {
bytes,
fds: Fds::Raw(fds),
fds: Arc::new(RwLock::new(Fds::Raw(fds))),
})
}
......@@ -280,7 +283,7 @@ impl Clone for Fds {
#[derive(Clone)]
pub struct Message {
bytes: Vec<u8>,
fds: Fds,
fds: Arc<RwLock<Fds>>,
}
// TODO: Handle non-native byte order: https://gitlab.freedesktop.org/dbus/zbus/-/issues/19
......@@ -371,7 +374,7 @@ impl Message {
}
let bytes = bytes.to_vec();
let fds = Fds::Raw(vec![]);
let fds = Arc::new(RwLock::new(Fds::Raw(vec![])));
Ok(Self { bytes, fds })
}
......@@ -385,8 +388,8 @@ impl Message {
Ok(())
}
pub(crate) fn set_owned_fds(&mut self, fds: Vec<OwnedFd>) {
self.fds = Fds::Owned(fds);
pub(crate) fn set_owned_fds(&self, fds: Vec<OwnedFd>) {
*self.fds.write().expect(LOCK_PANIC_MSG) = Fds::Owned(fds);
}
/// Disown the associated file descriptors.
......@@ -395,10 +398,11 @@ impl Message {
/// contain associated FDs. To prevent the message from closing
/// those FDs on drop, you may remove the ownership thanks to this
/// method, after that you are responsible for closing them.
pub fn disown_fds(&mut self) {
if let Fds::Owned(ref mut fds) = &mut self.fds {
pub fn disown_fds(&self) {
let mut fds_lock = self.fds.write().expect(LOCK_PANIC_MSG);
if let Fds::Owned(ref mut fds) = *fds_lock {
// From now on, it's the caller responsability to close the fds
self.fds = Fds::Raw(fds.drain(..).map(|fd| fd.into_raw_fd()).collect());
*fds_lock = Fds::Raw(fds.drain(..).map(|fd| fd.into_raw_fd()).collect());
}
}
......@@ -505,7 +509,7 @@ impl Message {
}
pub(crate) fn fds(&self) -> Vec<RawFd> {
match &self.fds {
match &*self.fds.read().expect(LOCK_PANIC_MSG) {
Fds::Raw(fds) => fds.clone(),
Fds::Owned(fds) => fds.iter().map(|f| f.as_raw_fd()).collect(),
}
......@@ -549,8 +553,9 @@ impl fmt::Debug for Message {
if let Ok(s) = self.body_signature() {
msg.field("body", &s);
}
if !self.fds().is_empty() {
msg.field("fds", &self.fds);
let fds = self.fds();
if !fds.is_empty() {
msg.field("fds", &fds);
}
msg.finish()
}
......@@ -630,7 +635,7 @@ mod tests {
)
.unwrap();
assert_eq!(m.body_signature().unwrap().to_string(), "hs");
assert_eq!(m.fds, Fds::Raw(vec![stdout.as_raw_fd()]));
assert_eq!(*m.fds.read().unwrap(), Fds::Raw(vec![stdout.as_raw_fd()]));
let body: Result<u32, MessageError> = m.body();
assert_eq!(body.unwrap_err(), MessageError::UnmatchedBodySignature);
......
......@@ -144,7 +144,7 @@ impl<S: Socket> Connection<S> {
}
// If we reach here, the message is complete, return it
let mut msg = self.msg_in_buffer.take().unwrap();
let msg = self.msg_in_buffer.take().unwrap();
msg.set_owned_fds(std::mem::take(&mut self.raw_in_fds));
Ok(msg)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment