Drop `UnixSocketAddrBound` on `close()`

This commit is contained in:
Ruihan Li 2025-09-06 22:27:23 +08:00 committed by Tate, Hongliang Tian
parent 286d4d4466
commit 87640d4b27
3 changed files with 77 additions and 36 deletions

View File

@ -5,6 +5,8 @@ use core::{
sync::atomic::{AtomicBool, Ordering}, sync::atomic::{AtomicBool, Ordering},
}; };
use spin::Once;
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
fs::utils::{Endpoint, EndpointState}, fs::utils::{Endpoint, EndpointState},
@ -23,10 +25,10 @@ use crate::{
}; };
pub(super) struct Connected { pub(super) struct Connected {
// `addr` should be dropped as soon as the socket file is closed,
// so it must not belong to `Inner`.
addr: Option<UnixSocketAddrBound>,
inner: Endpoint<Inner>, inner: Endpoint<Inner>,
reader: Mutex<RbConsumer<u8>>,
writer: Mutex<RbProducer<u8>>,
peer_cred: SocketCred,
} }
impl Connected { impl Connected {
@ -43,55 +45,71 @@ impl Connected {
let (peer_writer, this_reader) = RingBuffer::new(UNIX_STREAM_DEFAULT_BUF_SIZE).split(); let (peer_writer, this_reader) = RingBuffer::new(UNIX_STREAM_DEFAULT_BUF_SIZE).split();
let this_inner = Inner { let this_inner = Inner {
addr: SpinLock::new(addr), addr: Once::new(),
state, state,
is_pass_cred: AtomicBool::new(options.pass_cred()), reader: Mutex::new(this_reader),
writer: Mutex::new(this_writer),
all_aux: Mutex::new(VecDeque::new()), all_aux: Mutex::new(VecDeque::new()),
has_aux: AtomicBool::new(false), has_aux: AtomicBool::new(false),
is_pass_cred: AtomicBool::new(options.pass_cred()),
cred,
}; };
let peer_inner = Inner { let peer_inner = Inner {
addr: SpinLock::new(peer_addr), addr: Once::new(),
state: peer_state, state: peer_state,
is_pass_cred: AtomicBool::new(false), reader: Mutex::new(peer_reader),
writer: Mutex::new(peer_writer),
all_aux: Mutex::new(VecDeque::new()), all_aux: Mutex::new(VecDeque::new()),
has_aux: AtomicBool::new(false), has_aux: AtomicBool::new(false),
is_pass_cred: AtomicBool::new(false),
cred: peer_cred,
}; };
if let Some(addr) = addr.as_ref() {
this_inner.addr.call_once(|| addr.clone().into());
}
if let Some(peer_addr) = peer_addr.as_ref() {
peer_inner.addr.call_once(|| peer_addr.clone().into());
}
let (this_inner, peer_inner) = Endpoint::new_pair(this_inner, peer_inner); let (this_inner, peer_inner) = Endpoint::new_pair(this_inner, peer_inner);
let this = Connected { let this = Connected {
addr,
inner: this_inner, inner: this_inner,
reader: Mutex::new(this_reader),
writer: Mutex::new(this_writer),
peer_cred,
}; };
let peer = Connected { let peer = Connected {
addr: peer_addr,
inner: peer_inner, inner: peer_inner,
reader: Mutex::new(peer_reader),
writer: Mutex::new(peer_writer),
peer_cred: cred,
}; };
(this, peer) (this, peer)
} }
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> { pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
self.inner.this_end().addr.lock().clone() self.addr.as_ref()
} }
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> { pub(super) fn peer_addr(&self) -> UnixSocketAddr {
self.inner.peer_end().addr.lock().clone() self.inner
.peer_end()
.addr
.get()
.cloned()
.unwrap_or(UnixSocketAddr::Unnamed)
} }
pub(super) fn bind(&self, addr_to_bind: UnixSocketAddr) -> Result<()> { pub(super) fn bind(&mut self, addr_to_bind: UnixSocketAddr) -> Result<()> {
let mut addr = self.inner.this_end().addr.lock(); if self.addr.is_some() {
if addr.is_some() {
return addr_to_bind.bind_unnamed(); return addr_to_bind.bind_unnamed();
} }
let bound_addr = addr_to_bind.bind()?; let bound_addr = addr_to_bind.bind()?;
*addr = Some(bound_addr); self.inner
.this_end()
.addr
.call_once(|| bound_addr.clone().into());
self.addr = Some(bound_addr);
Ok(()) Ok(())
} }
@ -103,19 +121,21 @@ impl Connected {
) -> Result<(usize, Vec<ControlMessage>)> { ) -> Result<(usize, Vec<ControlMessage>)> {
let is_empty = writer.is_empty(); let is_empty = writer.is_empty();
if is_empty && !is_seqpacket { if is_empty && !is_seqpacket {
if self.reader.lock().is_empty() { if self.inner.this_end().reader.lock().is_empty() {
return_errno_with_message!(Errno::EAGAIN, "the channel is empty"); return_errno_with_message!(Errno::EAGAIN, "the channel is empty");
} }
return Ok((0, Vec::new())); return Ok((0, Vec::new()));
} }
let mut reader = self.reader.lock(); let this_end = self.inner.this_end();
let peer_end = self.inner.peer_end();
let mut reader = this_end.reader.lock();
// `reader.len()` is an `Acquire` operation. So it can guarantee that the `has_aux` // `reader.len()` is an `Acquire` operation. So it can guarantee that the `has_aux`
// check below sees the up-to-date value. // check below sees the up-to-date value.
let no_aux_len = reader.len(); let no_aux_len = reader.len();
let peer_end = self.inner.peer_end(); let is_pass_cred = this_end.is_pass_cred.load(Ordering::Relaxed);
let is_pass_cred = self.inner.this_end().is_pass_cred.load(Ordering::Relaxed);
// Fast path: There are no auxiliary data to receive. // Fast path: There are no auxiliary data to receive.
if !peer_end.has_aux.load(Ordering::Relaxed) { if !peer_end.has_aux.load(Ordering::Relaxed) {
@ -233,7 +253,7 @@ impl Connected {
// Fast path: There are no auxiliary data to transmit. // Fast path: There are no auxiliary data to transmit.
if aux_data.is_empty() && !is_seqpacket && !need_pass_cred { if aux_data.is_empty() && !is_seqpacket && !need_pass_cred {
let mut writer = self.writer.lock(); let mut writer = this_end.writer.lock();
return self.inner.write_with(move || { return self.inner.write_with(move || {
if is_seqpacket && writer.free_len() < reader.sum_lens() { if is_seqpacket && writer.free_len() < reader.sum_lens() {
return Ok(0); return Ok(0);
@ -250,7 +270,7 @@ impl Connected {
// Write the payload bytes. // Write the payload bytes.
let (write_start, write_res) = if !is_empty { let (write_start, write_res) = if !is_empty {
let mut writer = self.writer.lock(); let mut writer = this_end.writer.lock();
let write_start = writer.tail(); let write_start = writer.tail();
let write_res = self.inner.write_with(move || { let write_res = self.inner.write_with(move || {
if is_seqpacket && writer.free_len() < reader.sum_lens() { if is_seqpacket && writer.free_len() < reader.sum_lens() {
@ -260,7 +280,7 @@ impl Connected {
}); });
(write_start, write_res) (write_start, write_res)
} else { } else {
(self.writer.lock().tail(), Ok(0)) (this_end.writer.lock().tail(), Ok(0))
}; };
let Ok(write_len) = write_res else { let Ok(write_len) = write_res else {
this_end this_end
@ -310,13 +330,14 @@ impl Connected {
} }
pub(super) fn check_io_events(&self) -> IoEvents { pub(super) fn check_io_events(&self) -> IoEvents {
let this_end = self.inner.this_end();
let mut events = IoEvents::empty(); let mut events = IoEvents::empty();
if !self.reader.lock().is_empty() { if !this_end.reader.lock().is_empty() {
events |= IoEvents::IN; events |= IoEvents::IN;
} }
if !self.writer.lock().is_full() { if !this_end.writer.lock().is_full() {
events |= IoEvents::OUT; events |= IoEvents::OUT;
} }
@ -328,7 +349,7 @@ impl Connected {
} }
pub(super) fn peer_cred(&self) -> &SocketCred { pub(super) fn peer_cred(&self) -> &SocketCred {
&self.peer_cred &self.inner.peer_end().cred
} }
} }
@ -340,12 +361,15 @@ impl Drop for Connected {
} }
struct Inner { struct Inner {
addr: SpinLock<Option<UnixSocketAddrBound>>, addr: Once<UnixSocketAddr>,
state: EndpointState, state: EndpointState,
is_pass_cred: AtomicBool,
// Lock order: `reader` -> `all_aux` & `all_aux` -> `writer` // Lock order: `reader` -> `all_aux` & `all_aux` -> `writer`
reader: Mutex<RbConsumer<u8>>,
writer: Mutex<RbProducer<u8>>,
all_aux: Mutex<VecDeque<RangedAuxiliaryData>>, all_aux: Mutex<VecDeque<RangedAuxiliaryData>>,
has_aux: AtomicBool, has_aux: AtomicBool,
is_pass_cred: AtomicBool,
cred: SocketCred,
} }
impl AsRef<EndpointState> for Inner { impl AsRef<EndpointState> for Inner {

View File

@ -362,7 +362,7 @@ impl Socket for UnixStreamSocket {
let addr = match self.state.read().as_ref() { let addr = match self.state.read().as_ref() {
State::Init(init) => init.addr().cloned(), State::Init(init) => init.addr().cloned(),
State::Listen(listen) => Some(listen.addr().clone()), State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(), State::Connected(connected) => connected.addr().cloned(),
}; };
Ok(addr.into()) Ok(addr.into())

View File

@ -238,11 +238,12 @@ END_TEST()
FN_TEST(bind_connected) FN_TEST(bind_connected)
{ {
int fildes[2]; int fildes[2], sk;
struct sockaddr_un addr; struct sockaddr_un addr;
socklen_t addrlen; socklen_t addrlen;
TEST_SUCC(socketpair(PF_UNIX, SOCK_TYPE, 0, fildes)); TEST_SUCC(socketpair(PF_UNIX, SOCK_TYPE, 0, fildes));
sk = TEST_SUCC(socket(PF_UNIX, SOCK_TYPE, 0));
TEST_SUCC(bind(fildes[0], (struct sockaddr *)&UNIX_ADDR("\0X"), TEST_SUCC(bind(fildes[0], (struct sockaddr *)&UNIX_ADDR("\0X"),
PATH_OFFSET + 2)); PATH_OFFSET + 2));
@ -269,8 +270,24 @@ FN_TEST(bind_connected)
TEST_SUCC(bind(fildes[1], (struct sockaddr *)&UNNAMED_ADDR, TEST_SUCC(bind(fildes[1], (struct sockaddr *)&UNNAMED_ADDR,
UNNAMED_ADDRLEN)); UNNAMED_ADDRLEN));
// Closing the socket will release the bound address.
// So another socket can bind to it again.
TEST_ERRNO(bind(sk, (struct sockaddr *)&UNIX_ADDR("\0X"),
PATH_OFFSET + 2),
EADDRINUSE);
TEST_SUCC(close(fildes[0])); TEST_SUCC(close(fildes[0]));
TEST_SUCC(bind(sk, (struct sockaddr *)&UNIX_ADDR("\0X"),
PATH_OFFSET + 2));
// But the released address is still "visible" from
// the previously connected socket.
addrlen = sizeof(addr);
TEST_RES(getpeername(fildes[1], (struct sockaddr *)&addr, &addrlen),
addrlen == PATH_OFFSET + 2 && memcmp(&addr, &UNIX_ADDR("\0X"),
PATH_OFFSET + 2) == 0);
TEST_SUCC(close(fildes[1])); TEST_SUCC(close(fildes[1]));
TEST_SUCC(close(sk));
} }
END_TEST() END_TEST()