diff --git a/kernel/src/net/socket/unix/stream/connected.rs b/kernel/src/net/socket/unix/stream/connected.rs index 1f1e2dd50..7a3665b8b 100644 --- a/kernel/src/net/socket/unix/stream/connected.rs +++ b/kernel/src/net/socket/unix/stream/connected.rs @@ -5,6 +5,8 @@ use core::{ sync::atomic::{AtomicBool, Ordering}, }; +use spin::Once; + use crate::{ events::IoEvents, fs::utils::{Endpoint, EndpointState}, @@ -23,10 +25,10 @@ use crate::{ }; pub(super) struct Connected { + // `addr` should be dropped as soon as the socket file is closed, + // so it must not belong to `Inner`. + addr: Option, inner: Endpoint, - reader: Mutex>, - writer: Mutex>, - peer_cred: SocketCred, } impl Connected { @@ -43,55 +45,71 @@ impl Connected { let (peer_writer, this_reader) = RingBuffer::new(UNIX_STREAM_DEFAULT_BUF_SIZE).split(); let this_inner = Inner { - addr: SpinLock::new(addr), + addr: Once::new(), state, - is_pass_cred: AtomicBool::new(options.pass_cred()), + reader: Mutex::new(this_reader), + writer: Mutex::new(this_writer), all_aux: Mutex::new(VecDeque::new()), has_aux: AtomicBool::new(false), + is_pass_cred: AtomicBool::new(options.pass_cred()), + cred, }; let peer_inner = Inner { - addr: SpinLock::new(peer_addr), + addr: Once::new(), state: peer_state, - is_pass_cred: AtomicBool::new(false), + reader: Mutex::new(peer_reader), + writer: Mutex::new(peer_writer), all_aux: Mutex::new(VecDeque::new()), has_aux: AtomicBool::new(false), + is_pass_cred: AtomicBool::new(false), + cred: peer_cred, }; + if let Some(addr) = addr.as_ref() { + this_inner.addr.call_once(|| addr.clone().into()); + } + if let Some(peer_addr) = peer_addr.as_ref() { + peer_inner.addr.call_once(|| peer_addr.clone().into()); + } + let (this_inner, peer_inner) = Endpoint::new_pair(this_inner, peer_inner); let this = Connected { + addr, inner: this_inner, - reader: Mutex::new(this_reader), - writer: Mutex::new(this_writer), - peer_cred, }; let peer = Connected { + addr: peer_addr, inner: peer_inner, - reader: Mutex::new(peer_reader), - writer: Mutex::new(peer_writer), - peer_cred: cred, }; (this, peer) } - pub(super) fn addr(&self) -> Option { - self.inner.this_end().addr.lock().clone() + pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { + self.addr.as_ref() } - pub(super) fn peer_addr(&self) -> Option { - self.inner.peer_end().addr.lock().clone() + pub(super) fn peer_addr(&self) -> UnixSocketAddr { + self.inner + .peer_end() + .addr + .get() + .cloned() + .unwrap_or(UnixSocketAddr::Unnamed) } - pub(super) fn bind(&self, addr_to_bind: UnixSocketAddr) -> Result<()> { - let mut addr = self.inner.this_end().addr.lock(); - - if addr.is_some() { + pub(super) fn bind(&mut self, addr_to_bind: UnixSocketAddr) -> Result<()> { + if self.addr.is_some() { return addr_to_bind.bind_unnamed(); } let bound_addr = addr_to_bind.bind()?; - *addr = Some(bound_addr); + self.inner + .this_end() + .addr + .call_once(|| bound_addr.clone().into()); + self.addr = Some(bound_addr); Ok(()) } @@ -103,19 +121,21 @@ impl Connected { ) -> Result<(usize, Vec)> { let is_empty = writer.is_empty(); if is_empty && !is_seqpacket { - if self.reader.lock().is_empty() { + if self.inner.this_end().reader.lock().is_empty() { return_errno_with_message!(Errno::EAGAIN, "the channel is empty"); } return Ok((0, Vec::new())); } - let mut reader = self.reader.lock(); + let this_end = self.inner.this_end(); + let peer_end = self.inner.peer_end(); + + let mut reader = this_end.reader.lock(); // `reader.len()` is an `Acquire` operation. So it can guarantee that the `has_aux` // check below sees the up-to-date value. let no_aux_len = reader.len(); - let peer_end = self.inner.peer_end(); - let is_pass_cred = self.inner.this_end().is_pass_cred.load(Ordering::Relaxed); + let is_pass_cred = this_end.is_pass_cred.load(Ordering::Relaxed); // Fast path: There are no auxiliary data to receive. if !peer_end.has_aux.load(Ordering::Relaxed) { @@ -233,7 +253,7 @@ impl Connected { // Fast path: There are no auxiliary data to transmit. if aux_data.is_empty() && !is_seqpacket && !need_pass_cred { - let mut writer = self.writer.lock(); + let mut writer = this_end.writer.lock(); return self.inner.write_with(move || { if is_seqpacket && writer.free_len() < reader.sum_lens() { return Ok(0); @@ -250,7 +270,7 @@ impl Connected { // Write the payload bytes. let (write_start, write_res) = if !is_empty { - let mut writer = self.writer.lock(); + let mut writer = this_end.writer.lock(); let write_start = writer.tail(); let write_res = self.inner.write_with(move || { if is_seqpacket && writer.free_len() < reader.sum_lens() { @@ -260,7 +280,7 @@ impl Connected { }); (write_start, write_res) } else { - (self.writer.lock().tail(), Ok(0)) + (this_end.writer.lock().tail(), Ok(0)) }; let Ok(write_len) = write_res else { this_end @@ -310,13 +330,14 @@ impl Connected { } pub(super) fn check_io_events(&self) -> IoEvents { + let this_end = self.inner.this_end(); let mut events = IoEvents::empty(); - if !self.reader.lock().is_empty() { + if !this_end.reader.lock().is_empty() { events |= IoEvents::IN; } - if !self.writer.lock().is_full() { + if !this_end.writer.lock().is_full() { events |= IoEvents::OUT; } @@ -328,7 +349,7 @@ impl Connected { } pub(super) fn peer_cred(&self) -> &SocketCred { - &self.peer_cred + &self.inner.peer_end().cred } } @@ -340,12 +361,15 @@ impl Drop for Connected { } struct Inner { - addr: SpinLock>, + addr: Once, state: EndpointState, - is_pass_cred: AtomicBool, // Lock order: `reader` -> `all_aux` & `all_aux` -> `writer` + reader: Mutex>, + writer: Mutex>, all_aux: Mutex>, has_aux: AtomicBool, + is_pass_cred: AtomicBool, + cred: SocketCred, } impl AsRef for Inner { diff --git a/kernel/src/net/socket/unix/stream/socket.rs b/kernel/src/net/socket/unix/stream/socket.rs index 476cb0765..2c7ebccce 100644 --- a/kernel/src/net/socket/unix/stream/socket.rs +++ b/kernel/src/net/socket/unix/stream/socket.rs @@ -362,7 +362,7 @@ impl Socket for UnixStreamSocket { let addr = match self.state.read().as_ref() { State::Init(init) => init.addr().cloned(), State::Listen(listen) => Some(listen.addr().clone()), - State::Connected(connected) => connected.addr(), + State::Connected(connected) => connected.addr().cloned(), }; Ok(addr.into()) diff --git a/test/src/apps/network/unix_streamlike_prologue.h b/test/src/apps/network/unix_streamlike_prologue.h index 4261d93d2..e0d885f84 100644 --- a/test/src/apps/network/unix_streamlike_prologue.h +++ b/test/src/apps/network/unix_streamlike_prologue.h @@ -238,11 +238,12 @@ END_TEST() FN_TEST(bind_connected) { - int fildes[2]; + int fildes[2], sk; struct sockaddr_un addr; socklen_t addrlen; TEST_SUCC(socketpair(PF_UNIX, SOCK_TYPE, 0, fildes)); + sk = TEST_SUCC(socket(PF_UNIX, SOCK_TYPE, 0)); TEST_SUCC(bind(fildes[0], (struct sockaddr *)&UNIX_ADDR("\0X"), PATH_OFFSET + 2)); @@ -269,8 +270,24 @@ FN_TEST(bind_connected) TEST_SUCC(bind(fildes[1], (struct sockaddr *)&UNNAMED_ADDR, UNNAMED_ADDRLEN)); + // Closing the socket will release the bound address. + // So another socket can bind to it again. + TEST_ERRNO(bind(sk, (struct sockaddr *)&UNIX_ADDR("\0X"), + PATH_OFFSET + 2), + EADDRINUSE); TEST_SUCC(close(fildes[0])); + TEST_SUCC(bind(sk, (struct sockaddr *)&UNIX_ADDR("\0X"), + PATH_OFFSET + 2)); + + // But the released address is still "visible" from + // the previously connected socket. + addrlen = sizeof(addr); + TEST_RES(getpeername(fildes[1], (struct sockaddr *)&addr, &addrlen), + addrlen == PATH_OFFSET + 2 && memcmp(&addr, &UNIX_ADDR("\0X"), + PATH_OFFSET + 2) == 0); + TEST_SUCC(close(fildes[1])); + TEST_SUCC(close(sk)); } END_TEST()