diff --git a/kernel/src/fs/pipe.rs b/kernel/src/fs/pipe.rs index 0a933f8f2..f593b68ba 100644 --- a/kernel/src/fs/pipe.rs +++ b/kernel/src/fs/pipe.rs @@ -4,7 +4,7 @@ use core::sync::atomic::{AtomicU32, Ordering}; use super::{ file_handle::FileLike, - utils::{AccessMode, Channel, Consumer, InodeMode, InodeType, Metadata, Producer, StatusFlags}, + utils::{AccessMode, Endpoint, EndpointState, InodeMode, InodeType, Metadata, StatusFlags}, }; use crate::{ events::IoEvents, @@ -14,47 +14,85 @@ use crate::{ Gid, Uid, }, time::clocks::RealTimeCoarseClock, + util::ring_buffer::{RbConsumer, RbProducer, RingBuffer}, }; const DEFAULT_PIPE_BUF_SIZE: usize = 65536; -pub fn new_pair() -> Result<(Arc, Arc)> { - let (producer, consumer) = Channel::with_capacity(DEFAULT_PIPE_BUF_SIZE).split(); +/// Maximum number of bytes guaranteed to be written to a pipe atomically. +/// +/// If the number of bytes to be written is less than the threshold, the write must be atomic. +/// A non-blocking atomic write may fail with `EAGAIN`, even if there is room for a partial write. +/// In other words, a partial write is not allowed for an atomic write. +/// +/// For more details, see the description of `PIPE_BUF` in +/// . +#[cfg(not(ktest))] +const PIPE_BUF: usize = 4096; +#[cfg(ktest)] +const PIPE_BUF: usize = 2; - Ok(( - PipeReader::new(consumer, StatusFlags::empty())?, - PipeWriter::new(producer, StatusFlags::empty())?, - )) +pub fn new_pair() -> Result<(Arc, Arc)> { + new_pair_with_capacity(DEFAULT_PIPE_BUF_SIZE) } pub fn new_pair_with_capacity(capacity: usize) -> Result<(Arc, Arc)> { - let (producer, consumer) = Channel::with_capacity(capacity).split(); + let (producer, consumer) = RingBuffer::new(capacity).split(); + let (producer_state, consumer_state) = + Endpoint::new_pair(EndpointState::default(), EndpointState::default()); Ok(( - PipeReader::new(consumer, StatusFlags::empty())?, - PipeWriter::new(producer, StatusFlags::empty())?, + PipeReader::new(consumer, consumer_state, StatusFlags::empty())?, + PipeWriter::new(producer, producer_state, StatusFlags::empty())?, )) } pub struct PipeReader { - consumer: Consumer, + consumer: Mutex>, + state: Endpoint, status_flags: AtomicU32, } impl PipeReader { - pub fn new(consumer: Consumer, status_flags: StatusFlags) -> Result> { + fn new( + consumer: RbConsumer, + state: Endpoint, + status_flags: StatusFlags, + ) -> Result> { check_status_flags(status_flags)?; Ok(Arc::new(Self { - consumer, + consumer: Mutex::new(consumer), + state, status_flags: AtomicU32::new(status_flags.bits()), })) } + + fn try_read(&self, writer: &mut VmWriter) -> Result { + let read = || { + let mut consumer = self.consumer.lock(); + consumer.read_fallible(writer) + }; + + self.state.read_with(read) + } + + fn check_io_events(&self) -> IoEvents { + let mut events = IoEvents::empty(); + if self.state.is_peer_shutdown() { + events |= IoEvents::HUP; + } + if !self.consumer.lock().is_empty() { + events |= IoEvents::IN; + } + events + } } impl Pollable for PipeReader { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.consumer.poll(mask, poller) + self.state + .poll_with(mask, poller, || self.check_io_events()) } } @@ -67,9 +105,9 @@ impl FileLike for PipeReader { } if self.status_flags().contains(StatusFlags::O_NONBLOCK) { - self.consumer.try_read(writer) + self.try_read(writer) } else { - self.wait_events(IoEvents::IN, None, || self.consumer.try_read(writer)) + self.wait_events(IoEvents::IN, None, || self.try_read(writer)) } } @@ -111,25 +149,61 @@ impl FileLike for 
PipeReader { } } +impl Drop for PipeReader { + fn drop(&mut self) { + self.state.peer_shutdown(); + } +} + pub struct PipeWriter { - producer: Producer, + producer: Mutex>, + state: Endpoint, status_flags: AtomicU32, } impl PipeWriter { - pub fn new(producer: Producer, status_flags: StatusFlags) -> Result> { + fn new( + producer: RbProducer, + state: Endpoint, + status_flags: StatusFlags, + ) -> Result> { check_status_flags(status_flags)?; Ok(Arc::new(Self { - producer, + producer: Mutex::new(producer), + state, status_flags: AtomicU32::new(status_flags.bits()), })) } + + fn try_write(&self, reader: &mut VmReader) -> Result { + let write = || { + let mut producer = self.producer.lock(); + if reader.remain() <= PIPE_BUF && producer.free_len() < reader.remain() { + // No sufficient space for an atomic write + return Ok(0); + } + producer.write_fallible(reader) + }; + + self.state.write_with(write) + } + + fn check_io_events(&self) -> IoEvents { + if self.state.is_shutdown() { + IoEvents::ERR | IoEvents::OUT + } else if self.producer.lock().free_len() >= PIPE_BUF { + IoEvents::OUT + } else { + IoEvents::empty() + } + } } impl Pollable for PipeWriter { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.producer.poll(mask, poller) + self.state + .poll_with(mask, poller, || self.check_io_events()) } } @@ -142,9 +216,9 @@ impl FileLike for PipeWriter { } if self.status_flags().contains(StatusFlags::O_NONBLOCK) { - self.producer.try_write(reader) + self.try_write(reader) } else { - self.wait_events(IoEvents::OUT, None, || self.producer.try_write(reader)) + self.wait_events(IoEvents::OUT, None, || self.try_write(reader)) } } @@ -201,6 +275,12 @@ fn check_status_flags(status_flags: StatusFlags) -> Result<()> { Ok(()) } +impl Drop for PipeWriter { + fn drop(&mut self) { + self.state.shutdown(); + } +} + #[cfg(ktest)] mod test { use alloc::sync::Arc; @@ -209,10 +289,7 @@ mod test { use ostd::prelude::*; use super::*; - use crate::{ - fs::utils::Channel, - thread::{kernel_thread::ThreadOptions, Thread}, - }; + use crate::thread::{kernel_thread::ThreadOptions, Thread}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum Ordering { @@ -225,11 +302,7 @@ mod test { W: FnOnce(Arc) + Send + 'static, R: FnOnce(Arc) + Send + 'static, { - let channel = Channel::with_capacity(2); - let (writer, readr) = channel.split(); - - let writer = PipeWriter::new(writer, StatusFlags::empty()).unwrap(); - let reader = PipeReader::new(readr, StatusFlags::empty()).unwrap(); + let (reader, writer) = new_pair_with_capacity(2).unwrap(); let signal_writer = Arc::new(AtomicBool::new(false)); let signal_reader = signal_writer.clone(); diff --git a/kernel/src/fs/utils/channel.rs b/kernel/src/fs/utils/channel.rs deleted file mode 100644 index 97944b202..000000000 --- a/kernel/src/fs/utils/channel.rs +++ /dev/null @@ -1,436 +0,0 @@ -// SPDX-License-Identifier: MPL-2.0 - -use core::sync::atomic::{AtomicBool, Ordering}; - -use aster_rights::{Read, ReadOp, TRights, Write, WriteOp}; -use aster_rights_proc::require; - -use crate::{ - events::IoEvents, - prelude::*, - process::signal::{PollHandle, Pollee}, - util::{ - ring_buffer::{RbConsumer, RbProducer, RingBuffer}, - MultiRead, MultiWrite, - }, -}; - -/// A unidirectional communication channel, intended to implement IPC, e.g., pipe, -/// unix domain sockets, etc. -pub struct Channel { - producer: Producer, - consumer: Consumer, -} - -/// Maximum number of bytes guaranteed to be written to a pipe atomically. 
-/// -/// If the number of bytes to be written is less than the threshold, the write must be atomic. -/// A non-blocking atomic write may fail with `EAGAIN`, even if there is room for a partial write. -/// In other words, a partial write is not allowed for an atomic write. -/// -/// For more details, see the description of `PIPE_BUF` in -/// . -#[cfg(not(ktest))] -const PIPE_BUF: usize = 4096; -#[cfg(ktest)] -const PIPE_BUF: usize = 2; - -impl Channel { - /// Creates a new channel with the given capacity. - /// - /// # Panics - /// - /// This method will panic if the given capacity is zero. - pub fn with_capacity(capacity: usize) -> Self { - Self::with_capacity_and_pollees(capacity, None, None) - } - - /// Creates a new channel with the given capacity and pollees. - /// - /// # Panics - /// - /// This method will panic if the given capacity is zero. - pub fn with_capacity_and_pollees( - capacity: usize, - producer_pollee: Option, - consumer_pollee: Option, - ) -> Self { - let common = Arc::new(Common::new(capacity, producer_pollee, consumer_pollee)); - - let producer = Producer(Fifo::new(common.clone())); - let consumer = Consumer(Fifo::new(common)); - - Self { producer, consumer } - } - - pub fn split(self) -> (Producer, Consumer) { - let Self { producer, consumer } = self; - (producer, consumer) - } - - pub fn producer(&self) -> &Producer { - &self.producer - } - - pub fn consumer(&self) -> &Consumer { - &self.consumer - } - - pub fn capacity(&self) -> usize { - self.producer.0.common.capacity() - } -} - -pub struct Producer(Fifo); - -pub struct Consumer(Fifo); - -macro_rules! impl_common_methods_for_channel { - () => { - pub fn shutdown(&self) { - self.0.common.shutdown() - } - - pub fn is_shutdown(&self) -> bool { - self.0.common.is_shutdown() - } - - pub fn is_full(&self) -> bool { - self.this_end().rb().is_full() - } - - pub fn is_empty(&self) -> bool { - self.this_end().rb().is_empty() - } - - pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.this_end() - .pollee - .poll_with(mask, poller, || self.check_io_events()) - } - }; -} - -impl Producer { - fn this_end(&self) -> &FifoInner> { - &self.0.common.producer - } - - fn peer_end(&self) -> &FifoInner> { - &self.0.common.consumer - } - - fn check_io_events(&self) -> IoEvents { - let this_end = self.this_end(); - let rb = this_end.rb(); - - if self.is_shutdown() { - IoEvents::ERR | IoEvents::OUT - } else if rb.free_len() > PIPE_BUF { - IoEvents::OUT - } else { - IoEvents::empty() - } - } - - impl_common_methods_for_channel!(); -} - -impl Producer { - /// Tries to write `buf` to the channel. - /// - /// - Returns `Ok(_)` with the number of bytes written if successful. - /// - Returns `Err(EPIPE)` if the channel is shut down. - /// - Returns `Err(EAGAIN)` if the channel is full. - /// - /// The caller should not pass an empty `reader` to this method. - pub fn try_write(&self, reader: &mut dyn MultiRead) -> Result { - debug_assert!(!reader.is_empty()); - - if self.is_shutdown() { - return_errno_with_message!(Errno::EPIPE, "the channel is shut down"); - } - - let written_len = self.0.write(reader)?; - self.peer_end().pollee.notify(IoEvents::IN); - self.this_end().pollee.invalidate(); - - if written_len > 0 { - Ok(written_len) - } else { - return_errno_with_message!(Errno::EAGAIN, "the channel is full"); - } - } -} - -impl Producer { - /// Tries to push `item` to the channel. - /// - /// - Returns `Ok(())` if successful. - /// - Returns `Err(EPIPE)` if the channel is shut down. 
- /// - Returns `Err(EAGAIN)` if the channel is full. - pub fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> { - if self.is_shutdown() { - let err = Error::with_message(Errno::EPIPE, "the channel is shut down"); - return Err((err, item)); - } - - self.0.push(item).map_err(|item| { - let err = Error::with_message(Errno::EAGAIN, "the channel is full"); - (err, item) - })?; - self.peer_end().pollee.notify(IoEvents::IN); - self.this_end().pollee.invalidate(); - - Ok(()) - } -} - -impl Drop for Producer { - fn drop(&mut self) { - self.shutdown(); - } -} - -impl Consumer { - fn this_end(&self) -> &FifoInner> { - &self.0.common.consumer - } - - fn peer_end(&self) -> &FifoInner> { - &self.0.common.producer - } - - fn check_io_events(&self) -> IoEvents { - let this_end = self.this_end(); - let rb = this_end.rb(); - - let mut events = IoEvents::empty(); - if self.is_shutdown() { - events |= IoEvents::HUP; - } - if !rb.is_empty() { - events |= IoEvents::IN; - } - events - } - - impl_common_methods_for_channel!(); -} - -impl Consumer { - /// Tries to read `buf` from the channel. - /// - /// - Returns `Ok(_)` with the number of bytes read if successful. - /// - Returns `Ok(0)` if the channel is shut down and there is no data left. - /// - Returns `Err(EAGAIN)` if the channel is empty. - /// - /// The caller should not pass an empty `writer` to this method. - pub fn try_read(&self, writer: &mut dyn MultiWrite) -> Result { - debug_assert!(!writer.is_empty()); - - // This must be recorded before the actual operation to avoid race conditions. - let is_shutdown = self.is_shutdown(); - - let read_len = self.0.read(writer)?; - self.peer_end().pollee.notify(IoEvents::OUT); - self.this_end().pollee.invalidate(); - - if read_len > 0 { - Ok(read_len) - } else if is_shutdown { - Ok(0) - } else { - return_errno_with_message!(Errno::EAGAIN, "the channel is empty"); - } - } -} - -impl Consumer { - /// Tries to read an item from the channel. - /// - /// - Returns `Ok(Some(_))` with the popped item if successful. - /// - Returns `Ok(None)` if the channel is shut down and there is no data left. - /// - Returns `Err(EAGAIN)` if the channel is empty. - pub fn try_pop(&self) -> Result> { - // This must be recorded before the actual operation to avoid race conditions. - let is_shutdown = self.is_shutdown(); - - let item = self.0.pop(); - self.peer_end().pollee.notify(IoEvents::OUT); - self.this_end().pollee.invalidate(); - - if let Some(item) = item { - Ok(Some(item)) - } else if is_shutdown { - Ok(None) - } else { - return_errno_with_message!(Errno::EAGAIN, "the channel is empty") - } - } -} - -impl Drop for Consumer { - fn drop(&mut self) { - self.shutdown(); - } -} - -struct Fifo { - common: Arc>, - _rights: R, -} - -impl Fifo { - pub fn new(common: Arc>) -> Self { - Self { - common, - _rights: R::new(), - } - } -} - -impl Fifo { - #[require(R > Read)] - pub fn read(&self, writer: &mut dyn MultiWrite) -> Result { - let mut rb = self.common.consumer.rb(); - rb.read_fallible(writer) - } - - #[require(R > Write)] - pub fn write(&self, reader: &mut dyn MultiRead) -> Result { - let mut rb = self.common.producer.rb(); - if rb.free_len() < reader.sum_lens() && reader.sum_lens() <= PIPE_BUF { - // No sufficient space for an atomic write - return Ok(0); - } - rb.write_fallible(reader) - } -} - -impl Fifo { - /// Pushes an item into the endpoint. 
- /// If the `push` method fails, this method will return - /// `Err` containing the item that hasn't been pushed - #[require(R > Write)] - pub fn push(&self, item: T) -> core::result::Result<(), T> { - let mut rb = self.common.producer.rb(); - rb.push(item).ok_or(item) - } - - /// Pops an item from the endpoint. - #[require(R > Read)] - pub fn pop(&self) -> Option { - let mut rb = self.common.consumer.rb(); - rb.pop() - } -} - -struct Common { - producer: FifoInner>, - consumer: FifoInner>, - is_shutdown: AtomicBool, -} - -impl Common { - fn new( - capacity: usize, - producer_pollee: Option, - consumer_pollee: Option, - ) -> Self { - let rb: RingBuffer = RingBuffer::new(capacity); - let (rb_producer, rb_consumer) = rb.split(); - - let producer = { - let pollee = producer_pollee - .inspect(|pollee| pollee.invalidate()) - .unwrap_or_default(); - FifoInner::new(rb_producer, pollee) - }; - - let consumer = { - let pollee = consumer_pollee - .inspect(|pollee| pollee.invalidate()) - .unwrap_or_default(); - FifoInner::new(rb_consumer, pollee) - }; - - Self { - producer, - consumer, - is_shutdown: AtomicBool::new(false), - } - } - - pub fn capacity(&self) -> usize { - self.producer.rb().capacity() - } - - pub fn is_shutdown(&self) -> bool { - self.is_shutdown.load(Ordering::Relaxed) - } - - pub fn shutdown(&self) { - if self.is_shutdown.swap(true, Ordering::Relaxed) { - return; - } - - // The POLLHUP event indicates that the write end is shut down. - self.consumer.pollee.notify(IoEvents::HUP); - - // The POLLERR event indicates that the read end is shut down (so any subsequent writes - // will fail with an `EPIPE` error). - self.producer.pollee.notify(IoEvents::ERR | IoEvents::OUT); - } -} - -struct FifoInner { - rb: Mutex, - pollee: Pollee, -} - -impl FifoInner { - pub fn new(rb: T, pollee: Pollee) -> Self { - Self { - rb: Mutex::new(rb), - pollee, - } - } - - pub fn rb(&self) -> MutexGuard { - self.rb.lock() - } -} - -#[cfg(ktest)] -mod test { - use ostd::prelude::*; - - use super::*; - - #[ktest] - fn test_channel_basics() { - let channel = Channel::with_capacity(16); - let (producer, consumer) = channel.split(); - - let data = [1u8, 3, 7]; - - for d in &data { - producer.try_push(*d).unwrap(); - } - for d in &data { - let popped = consumer.try_pop().unwrap().unwrap(); - assert_eq!(*d, popped); - } - - let mut expected_data = [0u8; 3]; - let write_len = producer - .try_write(&mut VmReader::from(data.as_slice()).to_fallible()) - .unwrap(); - assert_eq!(write_len, 3); - consumer - .try_read(&mut VmWriter::from(expected_data.as_mut_slice()).to_fallible()) - .unwrap(); - assert_eq!(data, expected_data); - } -} diff --git a/kernel/src/fs/utils/endpoint.rs b/kernel/src/fs/utils/endpoint.rs new file mode 100644 index 000000000..b79cc4838 --- /dev/null +++ b/kernel/src/fs/utils/endpoint.rs @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::sync::Arc; +use core::sync::atomic::{AtomicBool, Ordering}; + +use crate::{ + events::IoEvents, + prelude::*, + process::signal::{PollHandle, Pollee}, +}; + +/// One of two connected endpoints. +/// +/// There is a `T` on the local endpoint and another `T` on the remote endpoint. This type allows +/// users to access the local and remote `T`s from both endpoints. +pub struct Endpoint { + inner: Arc>, + endpoint: Location, +} + +enum Location { + Client, + Server, +} + +struct Inner { + client: T, + server: T, +} + +impl Endpoint { + /// Creates an instance pair with two `T`s on the two endpoints. 
+ /// + /// For the first instance, `this` is on the local endpoint and `peer` is on the remote + /// endpoint; for the second instance, `this` is on the remote endpoint and `peer` is on the + /// local endpoint. + pub fn new_pair(this: T, peer: T) -> (Endpoint, Endpoint) { + let inner = Arc::new(Inner { + client: this, + server: peer, + }); + + let client = Endpoint { + inner: inner.clone(), + endpoint: Location::Client, + }; + let server = Endpoint { + inner, + endpoint: Location::Server, + }; + + (client, server) + } + + /// Returns a reference to the `T` on the local endpoint. + pub fn this_end(&self) -> &T { + match self.endpoint { + Location::Client => &self.inner.client, + Location::Server => &self.inner.server, + } + } + + /// Returns a reference to the `T` on the remote endpoint. + pub fn peer_end(&self) -> &T { + match self.endpoint { + Location::Client => &self.inner.server, + Location::Server => &self.inner.client, + } + } +} + +/// A [`Endpoint`] state that helps end-to-end data communication. +/// +/// The state contains a [`Pollee`] that will be notified when new data or the buffer becomes +/// available. Additionally, this state tracks whether communication has been shut down, i.e., +/// whether further data writing is disallowed. +/// +/// By having [`EndpointState`] as a part of the endpoint data (i.e., `T` in [`Endpoint`] should +/// implement [`AsRef`]), methods like [`Endpoint::read_with`], +/// [`Endpoint::write_with`], and [`Endpoint::poll_with`] are available for performing data +/// transmission and registering event observers. +/// +/// The data communication can be unidirectional, such as pipes, or bidirectional, such as UNIX +/// sockets. +pub struct EndpointState { + pollee: Pollee, + is_shutdown: AtomicBool, +} + +impl EndpointState { + /// Creates with the [`Pollee`] and the shutdown status. + pub fn new(pollee: Pollee, is_shutdown: bool) -> Self { + Self { + pollee, + is_shutdown: AtomicBool::new(is_shutdown), + } + } + + /// Clones and returns the [`Pollee`]. + /// + /// Do not use this method to perform cheap operations on the [`Pollee`] (e.g., + /// [`Pollee::notify`]). Use the methods below, such as [`read_with`]/[`write_with`], instead. + /// This method is deliberately designed to force the [`Pollee`] to be cloned to avoid such + /// misuse. + /// + /// [`read_with`]: Endpoint::read_with + /// [`write_with`]: Endpoint::read_with + pub fn cloned_pollee(&self) -> Pollee { + self.pollee.clone() + } +} + +impl Default for EndpointState { + fn default() -> Self { + Self::new(Pollee::new(), false) + } +} + +impl AsRef for EndpointState { + fn as_ref(&self) -> &EndpointState { + self + } +} + +impl> Endpoint { + /// Reads from the endpoint and updates the local/remote [`Pollee`]s. + /// + /// Note that if `read` returns `Ok(0)`, it is assumed that there is no data to read and an + /// [`Errno::EAGAIN`] error will be returned instead. + /// + /// However, if the remote endpoint has shut down, this method will return `Ok(0)` to indicate + /// the end-of-file (EOF). + pub fn read_with(&self, read: F) -> Result + where + F: FnOnce() -> Result, + { + // This must be recorded before the actual operation to avoid race conditions. 
+ let is_shutdown = self.is_peer_shutdown(); + + let read_len = read()?; + + if read_len > 0 { + self.peer_end().as_ref().pollee.notify(IoEvents::OUT); + self.this_end().as_ref().pollee.invalidate(); + Ok(read_len) + } else if is_shutdown { + Ok(0) + } else { + return_errno_with_message!(Errno::EAGAIN, "the channel is empty"); + } + } + + /// Writes to the endpoint and updates the local/remote [`Pollee`]s. + /// + /// Note that if `write` returns `Ok(0)`, it is assumed that there is no space to write and an + /// [`Errno::EAGAIN`] error will be returned instead. + /// + /// If the local endpoint has shut down, this method will return an [`Errno::EPIPE`] error + /// directly without calling the `write` closure. + pub fn write_with(&self, write: F) -> Result + where + F: FnOnce() -> Result, + { + if self.is_shutdown() { + return_errno_with_message!(Errno::EPIPE, "the channel is shut down"); + } + + let written_len = write()?; + + if written_len > 0 { + self.peer_end().as_ref().pollee.notify(IoEvents::IN); + self.this_end().as_ref().pollee.invalidate(); + Ok(written_len) + } else { + return_errno_with_message!(Errno::EAGAIN, "the channel is full"); + } + } + + /// Polls the I/O events in the local [`Pollee`]. + pub fn poll_with( + &self, + mask: IoEvents, + poller: Option<&mut PollHandle>, + check: F, + ) -> IoEvents + where + F: FnOnce() -> IoEvents, + { + self.this_end() + .as_ref() + .pollee + .poll_with(mask, poller, check) + } + + /// Shuts down the local endpoint. + /// + /// After this method, data cannot be written to from the local endpoint. + pub fn shutdown(&self) { + let this_end = self.this_end().as_ref(); + let peer_end = self.peer_end().as_ref(); + + Self::shutdown_impl(this_end, peer_end); + } + + /// Shuts down the remote endpoint. + /// + /// After this method, data cannot be written to from the remote endpoint. + pub fn peer_shutdown(&self) { + let this_end = self.this_end().as_ref(); + let peer_end = self.peer_end().as_ref(); + + Self::shutdown_impl(peer_end, this_end); + } + + fn shutdown_impl(this_end: &EndpointState, peer_end: &EndpointState) { + this_end.is_shutdown.store(true, Ordering::Relaxed); + peer_end + .pollee + .notify(IoEvents::HUP | IoEvents::RDHUP | IoEvents::IN); + this_end + .pollee + .notify(IoEvents::HUP | IoEvents::ERR | IoEvents::OUT); + } + + /// Returns whether the local endpoint has shut down. + pub fn is_shutdown(&self) -> bool { + self.this_end().as_ref().is_shutdown.load(Ordering::Relaxed) + } + + /// Returns whether the remote endpoint has shut down. + pub fn is_peer_shutdown(&self) -> bool { + self.peer_end().as_ref().is_shutdown.load(Ordering::Relaxed) + } +} diff --git a/kernel/src/fs/utils/mod.rs b/kernel/src/fs/utils/mod.rs index da3d916f2..80855200c 100644 --- a/kernel/src/fs/utils/mod.rs +++ b/kernel/src/fs/utils/mod.rs @@ -3,10 +3,10 @@ //! 
VFS components pub use access_mode::AccessMode; -pub use channel::{Channel, Consumer, Producer}; pub use creation_flags::CreationFlags; pub use dirent_visitor::DirentVisitor; pub use direntry_vec::DirEntryVecExt; +pub use endpoint::{Endpoint, EndpointState}; pub use falloc_mode::FallocMode; pub use file_creation_mask::FileCreationMask; pub use flock::{FlockItem, FlockList, FlockType}; @@ -25,10 +25,10 @@ pub use xattr::{ }; mod access_mode; -mod channel; mod creation_flags; mod dirent_visitor; mod direntry_vec; +mod endpoint; mod falloc_mode; mod file_creation_mask; mod flock; diff --git a/kernel/src/net/socket/unix/stream/connected.rs b/kernel/src/net/socket/unix/stream/connected.rs index 70ce4a37f..7dca81d8c 100644 --- a/kernel/src/net/socket/unix/stream/connected.rs +++ b/kernel/src/net/socket/unix/stream/connected.rs @@ -1,65 +1,71 @@ // SPDX-License-Identifier: MPL-2.0 -use core::ops::Deref; - -use ostd::sync::PreemptDisabled; - use crate::{ events::IoEvents, - fs::utils::{Channel, Consumer, Producer}, + fs::utils::{Endpoint, EndpointState}, net::socket::{ unix::{addr::UnixSocketAddrBound, UnixSocketAddr}, util::SockShutdownCmd, }, prelude::*, - process::signal::{PollHandle, Pollee}, - util::{MultiRead, MultiWrite}, + process::signal::Pollee, + util::{ + ring_buffer::{RbConsumer, RbProducer, RingBuffer}, + MultiRead, MultiWrite, + }, }; pub(super) struct Connected { - addr: AddrView, - reader: Consumer, - writer: Producer, + inner: Endpoint, + reader: Mutex>, + writer: Mutex>, } impl Connected { pub(super) fn new_pair( addr: Option, peer_addr: Option, - reader_pollee: Option, - writer_pollee: Option, + state: EndpointState, + peer_state: EndpointState, ) -> (Connected, Connected) { - let (writer_peer, reader_this) = - Channel::with_capacity_and_pollees(DEFAULT_BUF_SIZE, None, reader_pollee).split(); - let (writer_this, reader_peer) = - Channel::with_capacity_and_pollees(DEFAULT_BUF_SIZE, writer_pollee, None).split(); + let (this_writer, peer_reader) = RingBuffer::new(DEFAULT_BUF_SIZE).split(); + let (peer_writer, this_reader) = RingBuffer::new(DEFAULT_BUF_SIZE).split(); - let (addr_this, addr_peer) = AddrView::new_pair(addr, peer_addr); + let this_inner = Inner { + addr: SpinLock::new(addr), + state, + }; + let peer_inner = Inner { + addr: SpinLock::new(peer_addr), + state: peer_state, + }; + + let (this_inner, peer_inner) = Endpoint::new_pair(this_inner, peer_inner); let this = Connected { - addr: addr_this, - reader: reader_this, - writer: writer_this, + inner: this_inner, + reader: Mutex::new(this_reader), + writer: Mutex::new(this_writer), }; let peer = Connected { - addr: addr_peer, - reader: reader_peer, - writer: writer_peer, + inner: peer_inner, + reader: Mutex::new(peer_reader), + writer: Mutex::new(peer_writer), }; (this, peer) } pub(super) fn addr(&self) -> Option { - self.addr.addr().deref().as_ref().cloned() + self.inner.this_end().addr.lock().clone() } pub(super) fn peer_addr(&self) -> Option { - self.addr.peer_addr() + self.inner.peer_end().addr.lock().clone() } pub(super) fn bind(&self, addr_to_bind: UnixSocketAddr) -> Result<()> { - let mut addr = self.addr.addr(); + let mut addr = self.inner.this_end().addr.lock(); if addr.is_some() { return addr_to_bind.bind_unnamed(); @@ -73,100 +79,88 @@ impl Connected { pub(super) fn try_read(&self, writer: &mut dyn MultiWrite) -> Result { if writer.is_empty() { - if self.reader.is_empty() { + if self.reader.lock().is_empty() { return_errno_with_message!(Errno::EAGAIN, "the channel is empty"); } return Ok(0); } - 
self.reader.try_read(writer) + let read = || { + let mut reader = self.reader.lock(); + reader.read_fallible(writer) + }; + + self.inner.read_with(read) } pub(super) fn try_write(&self, reader: &mut dyn MultiRead) -> Result { if reader.is_empty() { - if self.writer.is_shutdown() { + if self.inner.is_shutdown() { return_errno_with_message!(Errno::EPIPE, "the channel is shut down"); } return Ok(0); } - self.writer.try_write(reader) + let write = || { + let mut writer = self.writer.lock(); + writer.write_fallible(reader) + }; + + self.inner.write_with(write) } pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { if cmd.shut_read() { - self.reader.shutdown(); + self.inner.peer_shutdown(); } if cmd.shut_write() { - self.writer.shutdown(); + self.inner.shutdown(); } } - pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents { - // Note that `mask | IoEvents::ALWAYS_POLL` contains all the events we care about. - let reader_events = self.reader.poll(mask, poller.as_deref_mut()); - let writer_events = self.writer.poll(mask, poller); - - combine_io_events(mask, reader_events, writer_events) + pub(super) fn is_read_shutdown(&self) -> bool { + self.inner.is_peer_shutdown() } -} -pub(super) fn combine_io_events( - mask: IoEvents, - reader_events: IoEvents, - writer_events: IoEvents, -) -> IoEvents { - let mut events = IoEvents::empty(); + pub(super) fn is_write_shutdown(&self) -> bool { + self.inner.is_shutdown() + } - if reader_events.contains(IoEvents::HUP) { - // The socket is shut down in one direction: the remote socket has shut down for - // writing or the local socket has shut down for reading. - events |= IoEvents::RDHUP | IoEvents::IN; + pub(super) fn check_io_events(&self) -> IoEvents { + let mut events = IoEvents::empty(); - if writer_events.contains(IoEvents::ERR) { - // The socket is shut down in both directions. Neither reading nor writing is - // possible. 
- events |= IoEvents::HUP; + if !self.reader.lock().is_empty() { + events |= IoEvents::IN; } + + if !self.writer.lock().is_full() { + events |= IoEvents::OUT; + } + + events } - events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); - - events & (mask | IoEvents::ALWAYS_POLL) + pub(super) fn cloned_pollee(&self) -> Pollee { + self.inner.this_end().state.cloned_pollee() + } } -struct AddrView { - addr: Arc>>, - peer: Arc>>, +impl Drop for Connected { + fn drop(&mut self) { + self.inner.shutdown(); + self.inner.peer_shutdown(); + } } -impl AddrView { - fn new_pair( - first: Option, - second: Option, - ) -> (AddrView, AddrView) { - let first = Arc::new(SpinLock::new(first)); - let second = Arc::new(SpinLock::new(second)); +struct Inner { + addr: SpinLock>, + state: EndpointState, +} - let view1 = AddrView { - addr: first.clone(), - peer: second.clone(), - }; - let view2 = AddrView { - addr: second, - peer: first, - }; - - (view1, view2) - } - - fn addr(&self) -> SpinLockGuard, PreemptDisabled> { - self.addr.lock() - } - - fn peer_addr(&self) -> Option { - self.peer.lock().as_ref().cloned() +impl AsRef for Inner { + fn as_ref(&self) -> &EndpointState { + &self.state } } diff --git a/kernel/src/net/socket/unix/stream/init.rs b/kernel/src/net/socket/unix/stream/init.rs index 2fc6b6609..dff846946 100644 --- a/kernel/src/net/socket/unix/stream/init.rs +++ b/kernel/src/net/socket/unix/stream/init.rs @@ -3,23 +3,23 @@ use core::sync::atomic::{AtomicBool, Ordering}; use super::{ - connected::{combine_io_events, Connected}, + connected::Connected, listener::Listener, + socket::{SHUT_READ_EVENTS, SHUT_WRITE_EVENTS}, }; use crate::{ events::IoEvents, + fs::utils::EndpointState, net::socket::{ unix::addr::{UnixSocketAddr, UnixSocketAddrBound}, util::SockShutdownCmd, }, prelude::*, - process::signal::{PollHandle, Pollee}, + process::signal::Pollee, }; pub(super) struct Init { addr: Option, - reader_pollee: Pollee, - writer_pollee: Pollee, is_read_shutdown: AtomicBool, is_write_shutdown: AtomicBool, } @@ -28,8 +28,6 @@ impl Init { pub(super) fn new() -> Self { Self { addr: None, - reader_pollee: Pollee::new(), - writer_pollee: Pollee::new(), is_read_shutdown: AtomicBool::new(false), is_write_shutdown: AtomicBool::new(false), } @@ -46,34 +44,33 @@ impl Init { Ok(()) } - pub(super) fn into_connected(self, peer_addr: UnixSocketAddrBound) -> (Connected, Connected) { + pub(super) fn into_connected( + self, + peer_addr: UnixSocketAddrBound, + pollee: Pollee, + ) -> (Connected, Connected) { let Init { addr, - reader_pollee, - writer_pollee, is_read_shutdown, is_write_shutdown, } = self; + pollee.invalidate(); let (this_conn, peer_conn) = Connected::new_pair( addr, Some(peer_addr), - Some(reader_pollee), - Some(writer_pollee), + EndpointState::new(pollee, is_read_shutdown.into_inner()), + EndpointState::new(Pollee::new(), is_write_shutdown.into_inner()), ); - if is_read_shutdown.into_inner() { - this_conn.shutdown(SockShutdownCmd::SHUT_RD); - } - - if is_write_shutdown.into_inner() { - this_conn.shutdown(SockShutdownCmd::SHUT_WR) - } - (this_conn, peer_conn) } - pub(super) fn listen(self, backlog: usize) -> core::result::Result { + pub(super) fn listen( + self, + backlog: usize, + pollee: Pollee, + ) -> core::result::Result { let Some(addr) = self.addr else { return Err(( Error::with_message(Errno::EINVAL, "the socket is not bound"), @@ -81,31 +78,25 @@ impl Init { )); }; + pollee.invalidate(); Ok(Listener::new( addr, - self.reader_pollee, - self.writer_pollee, backlog, 
self.is_read_shutdown.into_inner(), self.is_write_shutdown.into_inner(), + pollee, )) } - pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { - match cmd { - SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => { - self.is_write_shutdown.store(true, Ordering::Relaxed); - self.writer_pollee.notify(IoEvents::ERR); - } - SockShutdownCmd::SHUT_RD => (), + pub(super) fn shutdown(&self, cmd: SockShutdownCmd, pollee: &Pollee) { + if cmd.shut_read() { + self.is_read_shutdown.store(true, Ordering::Relaxed); + pollee.notify(SHUT_READ_EVENTS); } - match cmd { - SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => { - self.is_read_shutdown.store(true, Ordering::Relaxed); - self.reader_pollee.notify(IoEvents::HUP); - } - SockShutdownCmd::SHUT_WR => (), + if cmd.shut_write() { + self.is_write_shutdown.store(true, Ordering::Relaxed); + pollee.notify(SHUT_WRITE_EVENTS); } } @@ -113,28 +104,17 @@ impl Init { self.addr.as_ref() } - pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents { - // To avoid loss of events, this must be compatible with - // `Connected::poll`/`Listener::poll`. - let reader_events = self - .reader_pollee - .poll_with(mask, poller.as_deref_mut(), || { - if self.is_read_shutdown.load(Ordering::Relaxed) { - IoEvents::HUP - } else { - IoEvents::empty() - } - }); - let writer_events = self.writer_pollee.poll_with(mask, poller, || { - if self.is_write_shutdown.load(Ordering::Relaxed) { - IoEvents::OUT | IoEvents::ERR - } else { - IoEvents::OUT - } - }); + pub(super) fn is_read_shutdown(&self) -> bool { + self.is_read_shutdown.load(Ordering::Relaxed) + } - // According to the Linux implementation, we always have `IoEvents::HUP` in this state. - // Meanwhile, it is in `IoEvents::ALWAYS_POLL`, so we always return it. - combine_io_events(mask, reader_events, writer_events) | IoEvents::HUP + pub(super) fn is_write_shutdown(&self) -> bool { + self.is_write_shutdown.load(Ordering::Relaxed) + } + + pub(super) fn check_io_events(&self) -> IoEvents { + // According to the Linux implementation, we always have `IoEvents::OUT` and + // `IoEvents::HUP` in this state.
+ IoEvents::OUT | IoEvents::HUP } } diff --git a/kernel/src/net/socket/unix/stream/listener.rs b/kernel/src/net/socket/unix/stream/listener.rs index 0c41659e5..77c2b4c95 100644 --- a/kernel/src/net/socket/unix/stream/listener.rs +++ b/kernel/src/net/socket/unix/stream/listener.rs @@ -5,8 +5,9 @@ use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use ostd::sync::WaitQueue; use super::{ - connected::{combine_io_events, Connected}, + connected::Connected, init::Init, + socket::{SHUT_READ_EVENTS, SHUT_WRITE_EVENTS}, UnixStreamSocket, }; use crate::{ @@ -17,33 +18,29 @@ use crate::{ util::{SockShutdownCmd, SocketAddr}, }, prelude::*, - process::signal::{PollHandle, Pollee}, + process::signal::Pollee, }; pub(super) struct Listener { backlog: Arc, is_write_shutdown: AtomicBool, - writer_pollee: Pollee, } impl Listener { pub(super) fn new( addr: UnixSocketAddrBound, - reader_pollee: Pollee, - writer_pollee: Pollee, backlog: usize, is_read_shutdown: bool, is_write_shutdown: bool, + pollee: Pollee, ) -> Self { let backlog = BACKLOG_TABLE - .add_backlog(addr, reader_pollee, backlog, is_read_shutdown) + .add_backlog(addr, pollee, backlog, is_read_shutdown) .unwrap(); - writer_pollee.invalidate(); Self { backlog, is_write_shutdown: AtomicBool::new(is_write_shutdown), - writer_pollee, } } @@ -63,35 +60,27 @@ impl Listener { self.backlog.set_backlog(backlog); } - pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { - match cmd { - SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => { - self.is_write_shutdown.store(true, Ordering::Relaxed); - self.writer_pollee.notify(IoEvents::ERR); - } - SockShutdownCmd::SHUT_RD => (), + pub(super) fn shutdown(&self, cmd: SockShutdownCmd, pollee: &Pollee) { + if cmd.shut_read() { + self.backlog.shutdown(); } - match cmd { - SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => { - self.backlog.shutdown(); - } - SockShutdownCmd::SHUT_WR => (), + if cmd.shut_write() { + self.is_write_shutdown.store(true, Ordering::Relaxed); + pollee.notify(SHUT_WRITE_EVENTS); } } - pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents { - let reader_events = self.backlog.poll(mask, poller.as_deref_mut()); + pub(super) fn is_read_shutdown(&self) -> bool { + self.backlog.is_shutdown() + } - let writer_events = self.writer_pollee.poll_with(mask, poller, || { - if self.is_write_shutdown.load(Ordering::Relaxed) { - IoEvents::ERR - } else { - IoEvents::empty() - } - }); + pub(super) fn is_write_shutdown(&self) -> bool { + self.is_write_shutdown.load(Ordering::Relaxed) + } - combine_io_events(mask, reader_events, writer_events) + pub(super) fn check_io_events(&self) -> IoEvents { + self.backlog.check_io_events() } } @@ -131,8 +120,6 @@ impl BacklogTable { return None; } - // Note that the cached events can be correctly inherited from `Init`, so there is no need - // to explicitly call `Pollee::invalidate`. 
let new_backlog = Arc::new(Backlog::new(addr, pollee, backlog, is_shutdown)); backlog_sockets.insert(addr_key, new_backlog.clone()); @@ -206,26 +193,24 @@ impl Backlog { fn shutdown(&self) { *self.incoming_conns.lock() = None; - self.pollee.notify(IoEvents::HUP); + self.pollee.notify(SHUT_READ_EVENTS); self.wait_queue.wake_all(); } - fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee - .poll_with(mask, poller, || self.check_io_events()) + fn is_shutdown(&self) -> bool { + self.incoming_conns.lock().is_none() } fn check_io_events(&self) -> IoEvents { - let incoming_conns = self.incoming_conns.lock(); - - if let Some(conns) = &*incoming_conns { - if !conns.is_empty() { - IoEvents::IN - } else { - IoEvents::empty() - } + if self + .incoming_conns + .lock() + .as_ref() + .is_some_and(|conns| !conns.is_empty()) + { + IoEvents::IN } else { - IoEvents::HUP + IoEvents::empty() } } } @@ -234,6 +219,7 @@ impl Backlog { pub(super) fn push_incoming( &self, init: Init, + pollee: Pollee, ) -> core::result::Result { let mut locked_incoming_conns = self.incoming_conns.lock(); @@ -257,7 +243,7 @@ impl Backlog { )); } - let (client_conn, server_conn) = init.into_connected(self.addr.clone()); + let (client_conn, server_conn) = init.into_connected(self.addr.clone(), pollee); incoming_conns.push_back(server_conn); self.pollee.notify(IoEvents::IN); diff --git a/kernel/src/net/socket/unix/stream/socket.rs b/kernel/src/net/socket/unix/stream/socket.rs index bf332edd1..595fedd7f 100644 --- a/kernel/src/net/socket/unix/stream/socket.rs +++ b/kernel/src/net/socket/unix/stream/socket.rs @@ -11,7 +11,7 @@ use super::{ }; use crate::{ events::IoEvents, - fs::file_handle::FileLike, + fs::{file_handle::FileLike, utils::EndpointState}, net::socket::{ private::SocketPrivate, unix::UnixSocketAddr, @@ -19,12 +19,14 @@ use crate::{ Socket, }, prelude::*, - process::signal::{PollHandle, Pollable}, + process::signal::{PollHandle, Pollable, Pollee}, util::{MultiRead, MultiWrite}, }; pub struct UnixStreamSocket { state: RwMutex>, + + pollee: Pollee, is_nonblocking: AtomicBool, } @@ -32,13 +34,16 @@ impl UnixStreamSocket { pub(super) fn new_init(init: Init, is_nonblocking: bool) -> Arc { Arc::new(Self { state: RwMutex::new(Takeable::new(State::Init(init))), + pollee: Pollee::new(), is_nonblocking: AtomicBool::new(is_nonblocking), }) } pub(super) fn new_connected(connected: Connected, is_nonblocking: bool) -> Arc { + let cloned_pollee = connected.cloned_pollee(); Arc::new(Self { state: RwMutex::new(Takeable::new(State::Connected(connected))), + pollee: cloned_pollee, is_nonblocking: AtomicBool::new(is_nonblocking), }) } @@ -50,13 +55,69 @@ enum State { Connected(Connected), } +impl State { + pub(self) fn check_io_events(&self) -> IoEvents { + let mut events = IoEvents::empty(); + + let is_read_shutdown = self.is_read_shutdown(); + let is_write_shutdown = self.is_write_shutdown(); + + if is_read_shutdown { + // The socket is shut down in one direction: the remote socket has shut down for + // writing or the local socket has shut down for reading. + events |= IoEvents::RDHUP | IoEvents::IN; + + if is_write_shutdown { + // The socket is shut down in both directions. Neither reading nor writing is + // possible. + events |= IoEvents::HUP; + } + } + + if is_write_shutdown && !matches!(self, State::Listen(_)) { + // The socket is shut down in another direction: The remote socket has shut down for + // reading or the local socket has shut down for writing. 
+ events |= IoEvents::OUT; + } + + events |= match self { + State::Init(init) => init.check_io_events(), + State::Listen(listener) => listener.check_io_events(), + State::Connected(connected) => connected.check_io_events(), + }; + + events + } + + fn is_read_shutdown(&self) -> bool { + match self { + State::Init(init) => init.is_read_shutdown(), + State::Listen(listener) => listener.is_read_shutdown(), + State::Connected(connected) => connected.is_read_shutdown(), + } + } + + fn is_write_shutdown(&self) -> bool { + match self { + State::Init(init) => init.is_write_shutdown(), + State::Listen(listener) => listener.is_write_shutdown(), + State::Connected(connected) => connected.is_write_shutdown(), + } + } +} + impl UnixStreamSocket { pub fn new(is_nonblocking: bool) -> Arc { Self::new_init(Init::new(), is_nonblocking) } pub fn new_pair(is_nonblocking: bool) -> (Arc, Arc) { - let (conn_a, conn_b) = Connected::new_pair(None, None, None, None); + let (conn_a, conn_b) = Connected::new_pair( + None, + None, + EndpointState::default(), + EndpointState::default(), + ); ( Self::new_connected(conn_a, is_nonblocking), Self::new_connected(conn_b, is_nonblocking), @@ -107,7 +168,7 @@ impl UnixStreamSocket { } }; - let connected = match backlog.push_incoming(init) { + let connected = match backlog.push_incoming(init, self.pollee.clone()) { Ok(connected) => connected, Err((err, init)) => return (State::Init(init), Err(err)), }; @@ -126,14 +187,14 @@ impl UnixStreamSocket { } } +pub(super) const SHUT_READ_EVENTS: IoEvents = + IoEvents::RDHUP.union(IoEvents::IN).union(IoEvents::HUP); +pub(super) const SHUT_WRITE_EVENTS: IoEvents = IoEvents::OUT.union(IoEvents::HUP); + impl Pollable for UnixStreamSocket { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - let inner = self.state.read(); - match inner.as_ref() { - State::Init(init) => init.poll(mask, poller), - State::Listen(listen) => listen.poll(mask, poller), - State::Connected(connected) => connected.poll(mask, poller), - } + self.pollee + .poll_with(mask, poller, || self.state.read().check_io_events()) } } @@ -200,7 +261,7 @@ impl Socket for UnixStreamSocket { } }; - let listener = match init.listen(backlog) { + let listener = match init.listen(backlog, self.pollee.clone()) { Ok(listener) => listener, Err((err, init)) => { return (State::Init(init), Err(err)); @@ -217,8 +278,8 @@ impl Socket for UnixStreamSocket { fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { match self.state.read().as_ref() { - State::Init(init) => init.shutdown(cmd), - State::Listen(listen) => listen.shutdown(cmd), + State::Init(init) => init.shutdown(cmd, &self.pollee), + State::Listen(listen) => listen.shutdown(cmd, &self.pollee), State::Connected(connected) => connected.shutdown(cmd), }