// SPDX-License-Identifier: MPL-2.0
|
|
|
|
use core::sync::atomic::{AtomicBool, Ordering};
|
|
|
|
use aster_rights::ReadDupOp;
|
|
use takeable::Takeable;
|
|
|
|
use super::{
|
|
connected::Connected,
|
|
init::Init,
|
|
listener::{get_backlog, Backlog, Listener},
|
|
};
|
|
use crate::{
|
|
events::IoEvents,
|
|
fs::{
|
|
file_handle::FileLike,
|
|
utils::{EndpointState, Inode},
|
|
},
|
|
net::socket::{
|
|
new_pseudo_inode,
|
|
options::{
|
|
macros::sock_option_mut, Error as SocketError, PeerCred, PeerGroups, SocketOption,
|
|
},
|
|
private::SocketPrivate,
|
|
unix::{cred::SocketCred, ctrl_msg::AuxiliaryData, CUserCred, UnixSocketAddr},
|
|
util::{
|
|
options::{GetSocketLevelOption, SetSocketLevelOption, SocketOptionSet},
|
|
ControlMessage, MessageHeader, SendRecvFlags, SockShutdownCmd, SocketAddr,
|
|
},
|
|
Socket,
|
|
},
|
|
prelude::*,
|
|
process::{
|
|
signal::{PollHandle, Pollable, Pollee},
|
|
Gid,
|
|
},
|
|
util::{MultiRead, MultiWrite},
|
|
};
|
|
|
|
pub struct UnixStreamSocket {
|
|
// Lock order: `state` first, `options` second
|
|
state: RwMutex<Takeable<State>>,
|
|
options: RwLock<OptionSet>,
|
|
|
|
pollee: Pollee,
|
|
is_nonblocking: AtomicBool,
|
|
|
|
is_seqpacket: bool,
|
|
pseudo_inode: Arc<dyn Inode>,
|
|
}
|
|
|
|
enum State {
|
|
Init(Init),
|
|
Listen(Listener),
|
|
Connected(Connected),
|
|
}
|
|
|
|
impl State {
|
|
pub(self) fn check_io_events(&self) -> IoEvents {
|
|
let mut events = IoEvents::empty();
|
|
|
|
let is_read_shutdown = self.is_read_shutdown();
|
|
let is_write_shutdown = self.is_write_shutdown();
|
|
|
|
if is_read_shutdown {
|
|
// The socket is shut down in one direction: the remote socket has shut down for
|
|
// writing or the local socket has shut down for reading.
|
|
events |= IoEvents::RDHUP | IoEvents::IN;
|
|
|
|
if is_write_shutdown {
|
|
// The socket is shut down in both directions. Neither reading nor writing is
|
|
// possible.
|
|
events |= IoEvents::HUP;
|
|
}
|
|
}
|
|
|
|
if is_write_shutdown && !matches!(self, State::Listen(_)) {
|
|
// The socket is shut down in another direction: The remote socket has shut down for
|
|
// reading or the local socket has shut down for writing.
|
|
events |= IoEvents::OUT;
|
|
}
|
|
|
|
events |= match self {
|
|
State::Init(init) => init.check_io_events(),
|
|
State::Listen(listener) => listener.check_io_events(),
|
|
State::Connected(connected) => connected.check_io_events(),
|
|
};
|
|
|
|
events
|
|
}
|
|
|
|
fn is_read_shutdown(&self) -> bool {
|
|
match self {
|
|
State::Init(init) => init.is_read_shutdown(),
|
|
State::Listen(listener) => listener.is_read_shutdown(),
|
|
State::Connected(connected) => connected.is_read_shutdown(),
|
|
}
|
|
}
|
|
|
|
fn is_write_shutdown(&self) -> bool {
|
|
match self {
|
|
State::Init(init) => init.is_write_shutdown(),
|
|
State::Listen(listener) => listener.is_write_shutdown(),
|
|
State::Connected(connected) => connected.is_write_shutdown(),
|
|
}
|
|
}
|
|
|
|
pub(self) fn peer_cred(&self) -> Option<CUserCred> {
|
|
match self {
|
|
Self::Init(_) => None,
|
|
Self::Listen(listener) => Some(listener.cred().to_effective_c_cred()),
|
|
Self::Connected(connected) => Some(connected.peer_cred().to_effective_c_cred()),
|
|
}
|
|
}
|
|
|
|
pub(self) fn peer_groups(&self) -> Result<Arc<[Gid]>> {
|
|
match self {
|
|
State::Init(_) => {
|
|
return_errno_with_message!(Errno::ENODATA, "the socket does not have peer groups")
|
|
}
|
|
State::Listen(listener) => Ok(listener.cred().groups()),
|
|
State::Connected(connected) => Ok(connected.peer_cred().groups()),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub(super) struct OptionSet {
|
|
socket: SocketOptionSet,
|
|
}
|
|
|
|
impl OptionSet {
|
|
pub(super) fn new() -> Self {
|
|
Self {
|
|
socket: SocketOptionSet::new_unix_stream(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl UnixStreamSocket {
|
|
pub fn new(is_nonblocking: bool, is_seqpacket: bool) -> Arc<Self> {
|
|
Self::new_init(Init::new(), is_nonblocking, is_seqpacket)
|
|
}
|
|
|
|
fn new_init(init: Init, is_nonblocking: bool, is_seqpacket: bool) -> Arc<Self> {
|
|
Arc::new(Self {
|
|
state: RwMutex::new(Takeable::new(State::Init(init))),
|
|
options: RwLock::new(OptionSet::new()),
|
|
pollee: Pollee::new(),
|
|
is_nonblocking: AtomicBool::new(is_nonblocking),
|
|
is_seqpacket,
|
|
pseudo_inode: new_pseudo_inode(),
|
|
})
|
|
}
|
|
|
|
pub fn new_pair(is_nonblocking: bool, is_seqpacket: bool) -> (Arc<Self>, Arc<Self>) {
|
|
let cred = SocketCred::<ReadDupOp>::new_current();
|
|
let options = OptionSet::new();
|
|
|
|
let (conn_a, conn_b) = Connected::new_pair(
|
|
None,
|
|
None,
|
|
EndpointState::default(),
|
|
EndpointState::default(),
|
|
cred.dup().restrict(),
|
|
cred.restrict(),
|
|
&options.socket,
|
|
);
|
|
(
|
|
Self::new_connected(conn_a, options, is_nonblocking, is_seqpacket),
|
|
Self::new_connected(conn_b, OptionSet::new(), is_nonblocking, is_seqpacket),
|
|
)
|
|
}
|
|
|
|
pub(super) fn new_connected(
|
|
connected: Connected,
|
|
options: OptionSet,
|
|
is_nonblocking: bool,
|
|
is_seqpacket: bool,
|
|
) -> Arc<Self> {
|
|
let cloned_pollee = connected.cloned_pollee();
|
|
Arc::new(Self {
|
|
state: RwMutex::new(Takeable::new(State::Connected(connected))),
|
|
options: RwLock::new(options),
|
|
pollee: cloned_pollee,
|
|
is_nonblocking: AtomicBool::new(is_nonblocking),
|
|
is_seqpacket,
|
|
pseudo_inode: new_pseudo_inode(),
|
|
})
|
|
}
|
|
|
|
fn try_send(
|
|
&self,
|
|
buf: &mut dyn MultiRead,
|
|
aux_data: &mut AuxiliaryData,
|
|
_flags: SendRecvFlags,
|
|
) -> Result<usize> {
|
|
match self.state.read().as_ref() {
|
|
State::Connected(connected) => connected.try_write(buf, aux_data, self.is_seqpacket),
|
|
State::Init(_) | State::Listen(_) => {
|
|
return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected")
|
|
}
|
|
}
|
|
}
|
|
|
|
fn try_recv(
|
|
&self,
|
|
buf: &mut dyn MultiWrite,
|
|
_flags: SendRecvFlags,
|
|
) -> Result<(usize, Vec<ControlMessage>)> {
|
|
match self.state.read().as_ref() {
|
|
State::Connected(connected) => connected.try_read(buf, self.is_seqpacket),
|
|
State::Init(_) | State::Listen(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is not connected")
|
|
}
|
|
}
|
|
}
|
|
|
|
fn try_connect(&self, backlog: &Arc<Backlog>) -> Result<()> {
|
|
let mut state = self.state.write();
|
|
let options = self.options.read();
|
|
|
|
state.borrow_result(|owned_state| {
|
|
let init = match owned_state {
|
|
State::Init(init) => init,
|
|
State::Listen(listener) => {
|
|
return (
|
|
State::Listen(listener),
|
|
Err(Error::with_message(
|
|
Errno::EINVAL,
|
|
"the socket is listening",
|
|
)),
|
|
);
|
|
}
|
|
State::Connected(connected) => {
|
|
return (
|
|
State::Connected(connected),
|
|
Err(Error::with_message(
|
|
Errno::EISCONN,
|
|
"the socket is connected",
|
|
)),
|
|
);
|
|
}
|
|
};
|
|
|
|
let connected = match backlog.push_incoming(
|
|
init,
|
|
self.pollee.clone(),
|
|
&options.socket,
|
|
self.is_seqpacket,
|
|
) {
|
|
Ok(connected) => connected,
|
|
Err((err, init)) => return (State::Init(init), Err(err)),
|
|
};
|
|
|
|
(State::Connected(connected), Ok(()))
|
|
})
|
|
}
|
|
|
|
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
|
match self.state.read().as_ref() {
|
|
State::Listen(listen) => listen.try_accept(self.is_seqpacket) as _,
|
|
State::Init(_) | State::Connected(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is not listening")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(super) const SHUT_READ_EVENTS: IoEvents =
|
|
IoEvents::RDHUP.union(IoEvents::IN).union(IoEvents::HUP);
|
|
pub(super) const SHUT_WRITE_EVENTS: IoEvents = IoEvents::OUT.union(IoEvents::HUP);
|
|
|
|
impl Pollable for UnixStreamSocket {
|
|
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
|
|
self.pollee
|
|
.poll_with(mask, poller, || self.state.read().check_io_events())
|
|
}
|
|
}
|
|
|
|
impl SocketPrivate for UnixStreamSocket {
    /// Returns whether the socket is currently in non-blocking mode.
    fn is_nonblocking(&self) -> bool {
        let flag = &self.is_nonblocking;
        flag.load(Ordering::Relaxed)
    }

    /// Switches the socket between blocking and non-blocking mode.
    fn set_nonblocking(&self, nonblocking: bool) {
        let flag = &self.is_nonblocking;
        flag.store(nonblocking, Ordering::Relaxed);
    }
}
|
|
|
|
impl Socket for UnixStreamSocket {
|
|
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
|
|
let addr = UnixSocketAddr::try_from(socket_addr)?;
|
|
|
|
match self.state.write().as_mut() {
|
|
State::Init(init) => init.bind(addr),
|
|
State::Connected(connected) => connected.bind(addr),
|
|
State::Listen(_) => {
|
|
// Listening sockets are always already bound.
|
|
addr.bind_unnamed()
|
|
}
|
|
}
|
|
}
|
|
|
|
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
|
|
let remote_addr = UnixSocketAddr::try_from(socket_addr)?.connect()?;
|
|
let backlog = get_backlog(&remote_addr)?;
|
|
|
|
if self.is_nonblocking() {
|
|
self.try_connect(&backlog)
|
|
} else {
|
|
backlog.block_connect(|| self.try_connect(&backlog))
|
|
}
|
|
}
|
|
|
|
fn listen(&self, backlog: usize) -> Result<()> {
|
|
const SOMAXCONN: usize = 4096;
|
|
|
|
// Linux allows a maximum of `backlog + 1` sockets in the backlog queue. Although this
|
|
// seems to be mostly an implementation detail, we follow the exact Linux behavior to
|
|
// ensure that our regression tests pass with the Linux kernel.
|
|
let backlog = backlog.saturating_add(1).min(SOMAXCONN);
|
|
|
|
let mut state = self.state.write();
|
|
|
|
state.borrow_result(|owned_state| {
|
|
let init = match owned_state {
|
|
State::Init(init) => init,
|
|
State::Listen(listener) => {
|
|
listener.listen(backlog);
|
|
return (State::Listen(listener), Ok(()));
|
|
}
|
|
State::Connected(connected) => {
|
|
return (
|
|
State::Connected(connected),
|
|
Err(Error::with_message(
|
|
Errno::EINVAL,
|
|
"the socket is connected",
|
|
)),
|
|
);
|
|
}
|
|
};
|
|
|
|
let listener = match init.listen(backlog, self.pollee.clone(), self.is_seqpacket) {
|
|
Ok(listener) => listener,
|
|
Err((err, init)) => {
|
|
return (State::Init(init), Err(err));
|
|
}
|
|
};
|
|
|
|
(State::Listen(listener), Ok(()))
|
|
})
|
|
}
|
|
|
|
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
|
self.block_on(IoEvents::IN, || self.try_accept())
|
|
}
|
|
|
|
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
|
match self.state.read().as_ref() {
|
|
State::Init(init) => init.shutdown(cmd, &self.pollee),
|
|
State::Listen(listen) => listen.shutdown(cmd, &self.pollee),
|
|
State::Connected(connected) => connected.shutdown(cmd),
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn addr(&self) -> Result<SocketAddr> {
|
|
let addr = match self.state.read().as_ref() {
|
|
State::Init(init) => init.addr().cloned(),
|
|
State::Listen(listen) => Some(listen.addr().clone()),
|
|
State::Connected(connected) => connected.addr().cloned(),
|
|
};
|
|
|
|
Ok(addr.into())
|
|
}
|
|
|
|
fn peer_addr(&self) -> Result<SocketAddr> {
|
|
let peer_addr = match self.state.read().as_ref() {
|
|
State::Connected(connected) => connected.peer_addr(),
|
|
State::Init(_) | State::Listen(_) => {
|
|
return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected")
|
|
}
|
|
};
|
|
|
|
Ok(peer_addr.into())
|
|
}
|
|
|
|
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
|
|
sock_option_mut!(match option {
|
|
socket_errors @ SocketError => {
|
|
// TODO: Support socket errors for UNIX sockets
|
|
socket_errors.set(None);
|
|
return Ok(());
|
|
}
|
|
_ => (),
|
|
});
|
|
|
|
let state = self.state.read();
|
|
let options = self.options.read();
|
|
|
|
// Deal with UNIX-socket-specific socket-level options
|
|
match do_unix_getsockopt(option, state.as_ref()) {
|
|
Err(err) if err.error() == Errno::ENOPROTOOPT => (),
|
|
res => return res,
|
|
}
|
|
|
|
// Deal with socket-level options
|
|
match options.socket.get_option(option, state.as_ref()) {
|
|
Err(err) if err.error() == Errno::ENOPROTOOPT => (),
|
|
res => return res,
|
|
}
|
|
|
|
// TODO: Deal with socket options from other levels
|
|
warn!("only socket-level options are supported");
|
|
|
|
return_errno_with_message!(Errno::ENOPROTOOPT, "the socket option to get is unknown")
|
|
}
|
|
|
|
fn set_option(&self, option: &dyn SocketOption) -> Result<()> {
|
|
let state = self.state.read();
|
|
let mut options = self.options.write();
|
|
|
|
match options.socket.set_option(option, state.as_ref()) {
|
|
Err(err) if err.error() == Errno::ENOPROTOOPT => {
|
|
// TODO: Deal with socket options from other levels
|
|
warn!("only socket-level options are supported");
|
|
return_errno_with_message!(
|
|
Errno::ENOPROTOOPT,
|
|
"the socket option to get is unknown"
|
|
)
|
|
}
|
|
res => res.map(|_need_iface_poll| ()),
|
|
}
|
|
}
|
|
|
|
fn sendmsg(
|
|
&self,
|
|
reader: &mut dyn MultiRead,
|
|
message_header: MessageHeader,
|
|
flags: SendRecvFlags,
|
|
) -> Result<usize> {
|
|
// TODO: Deal with flags
|
|
if !flags.is_all_supported() {
|
|
warn!("unsupported flags: {:?}", flags);
|
|
}
|
|
|
|
let MessageHeader {
|
|
control_messages,
|
|
addr,
|
|
} = message_header;
|
|
|
|
// According to the Linux man pages, `EISCONN` _may_ be returned when the destination
|
|
// address is specified for a connection-mode socket. In practice, `sendmsg` on UNIX stream
|
|
// sockets will fail due to that. We follow the same behavior as the Linux implementation.
|
|
if !self.is_seqpacket && addr.is_some() {
|
|
match self.state.read().as_ref() {
|
|
State::Init(_) | State::Listen(_) => return_errno_with_message!(
|
|
Errno::EOPNOTSUPP,
|
|
"sending to a specific address is not allowed on UNIX stream sockets"
|
|
),
|
|
State::Connected(_) => return_errno_with_message!(
|
|
Errno::EISCONN,
|
|
"sending to a specific address is not allowed on UNIX stream sockets"
|
|
),
|
|
}
|
|
}
|
|
let mut auxiliary_data = AuxiliaryData::from_control(control_messages)?;
|
|
|
|
self.block_on(IoEvents::OUT, || {
|
|
self.try_send(reader, &mut auxiliary_data, flags)
|
|
})
|
|
}
|
|
|
|
fn recvmsg(
|
|
&self,
|
|
writer: &mut dyn MultiWrite,
|
|
flags: SendRecvFlags,
|
|
) -> Result<(usize, MessageHeader)> {
|
|
// TODO: Deal with flags
|
|
if !flags.is_all_supported() {
|
|
warn!("unsupported flags: {:?}", flags);
|
|
}
|
|
|
|
let (received_bytes, control_messages) =
|
|
self.block_on(IoEvents::IN, || self.try_recv(writer, flags))?;
|
|
|
|
let message_header = MessageHeader::new(None, control_messages);
|
|
|
|
Ok((received_bytes, message_header))
|
|
}
|
|
|
|
fn pseudo_inode(&self) -> &Arc<dyn Inode> {
|
|
&self.pseudo_inode
|
|
}
|
|
}
|
|
|
|
/// Handles the socket options that are specific to UNIX sockets.
///
/// `PeerCred` is filled with the peer credentials of `state`, or with invalid credentials if
/// there are none. `PeerGroups` is filled with the peer groups of `state`.
///
/// # Errors
///
/// Returns `ENOPROTOOPT` for any other option, and propagates the error of
/// `State::peer_groups` for `PeerGroups`.
fn do_unix_getsockopt(option: &mut dyn SocketOption, state: &State) -> Result<()> {
    sock_option_mut!(match option {
        socket_peer_cred @ PeerCred => {
            socket_peer_cred.set(state.peer_cred().unwrap_or_else(CUserCred::new_invalid));
        }
        socket_peer_groups @ PeerGroups => {
            socket_peer_groups.set(state.peer_groups()?);
        }
        _ => return_errno_with_message!(
            Errno::ENOPROTOOPT,
            "the socket option to get is not UNIX-socket-specific"
        ),
    });

    Ok(())
}
|
|
|
|
impl GetSocketLevelOption for State {
    /// Returns `true` only in the listening state.
    fn is_listening(&self) -> bool {
        match self {
            Self::Listen(_) => true,
            Self::Init(_) | Self::Connected(_) => false,
        }
    }
}
|
|
|
|
impl SetSocketLevelOption for State {
    /// Forwards the new pass-credentials setting to the connection, if there is one.
    ///
    /// In the initial and listening states this does nothing.
    fn set_pass_cred(&self, pass_cred: bool) {
        match self {
            Self::Connected(connected) => connected.set_pass_cred(pass_cred),
            Self::Listen(_) => {}
            Self::Init(_) => {
                // TODO: According to the Linux man pages, "When this option is set and the socket
                // is not yet connected, a unique name in the abstract namespace will be generated
                // automatically." See <https://man7.org/linux/man-pages/man7/unix.7.html> for
                // details.
            }
        }
    }
}
|