// SPDX-License-Identifier: MPL-2.0 use alloc::{ boxed::Box, sync::{Arc, Weak}, }; use core::{ ops::{Deref, DerefMut}, sync::atomic::{AtomicBool, AtomicU64, Ordering}, }; use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard}; use smoltcp::{ iface::Context, socket::{tcp::State, udp::UdpMetadata, PollAt}, time::Instant, wire::{IpAddress, IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr}, }; use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket}; use crate::iface::Iface; pub struct BoundSocket(Arc>); /// [`TcpSocket`] or [`UdpSocket`]. pub trait AnySocket { type RawSocket; /// Called by [`BoundSocket::new`]. fn new(socket: Box) -> Self; /// Called by [`BoundSocket::drop`]. fn on_drop(this: &Arc>) where Self: Sized; } pub type BoundTcpSocket = BoundSocket; pub type BoundUdpSocket = BoundSocket; /// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`]. pub struct BoundSocketInner { iface: Arc>, port: u16, socket: T, observer: RwLock>, next_poll_at_ms: AtomicU64, has_new_events: AtomicBool, } /// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`]. pub struct TcpSocket { socket: SpinLock, is_dead: AtomicBool, } struct RawTcpSocketExt { socket: Box, /// Whether the socket is in the background. /// /// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means /// that no more user events (like `send`/`recv`) can reach the socket, but it can be in a /// state of waiting for certain network events (e.g., remote FIN/ACK packets), so /// [`BoundSocketInner`] may still be alive for a while. in_background: bool, } impl Deref for RawTcpSocketExt { type Target = RawTcpSocket; fn deref(&self) -> &Self::Target { &self.socket } } impl DerefMut for RawTcpSocketExt { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.socket } } impl TcpSocket { fn lock(&self) -> SpinLockGuard { self.socket.lock() } /// Returns whether the TCP socket is dead. /// /// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. fn is_dead(&self) -> bool { self.is_dead.load(Ordering::Relaxed) } /// Updates whether the TCP socket is dead. /// /// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. /// /// This method must be called after handling network events. However, it is not necessary to /// call this method after handling non-closing user events, because the socket can never be /// dead if user events can reach the socket. fn update_dead(&self, socket: &RawTcpSocketExt) { if socket.in_background && socket.state() == smoltcp::socket::tcp::State::Closed { self.is_dead.store(true, Ordering::Relaxed); } } /// Sets the TCP socket in [`TimeWait`] state as dead. /// /// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. /// /// [`TimeWait`]: smoltcp::socket::tcp::State::TimeWait fn set_dead_timewait(&self, socket: &RawTcpSocketExt) { debug_assert!( socket.in_background && socket.state() == smoltcp::socket::tcp::State::TimeWait ); self.is_dead.store(true, Ordering::Relaxed); } } impl AnySocket for TcpSocket { type RawSocket = RawTcpSocket; fn new(socket: Box) -> Self { let socket_ext = RawTcpSocketExt { socket, in_background: false, }; Self { socket: SpinLock::new(socket_ext), is_dead: AtomicBool::new(false), } } fn on_drop(this: &Arc>) { let mut socket = this.socket.lock(); socket.in_background = true; socket.close(); // A TCP socket may not be appropriate for immediate removal. We leave the removal decision // to the polling logic. this.update_next_poll_at_ms(PollAt::Now); this.socket.update_dead(&socket); } } /// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`]. type UdpSocket = SpinLock, LocalIrqDisabled>; impl AnySocket for UdpSocket { type RawSocket = RawUdpSocket; fn new(socket: Box) -> Self { Self::new(socket) } fn on_drop(this: &Arc>) { this.socket.lock().close(); // A UDP socket can be removed immediately. this.iface.common().remove_udp_socket(this); } } impl Drop for BoundSocket { fn drop(&mut self) { T::on_drop(&self.0); } } pub(crate) type BoundTcpSocketInner = BoundSocketInner; pub(crate) type BoundUdpSocketInner = BoundSocketInner; impl BoundSocket { pub(crate) fn new( iface: Arc>, port: u16, socket: Box, observer: Weak, ) -> Self { Self(Arc::new(BoundSocketInner { iface, port, socket: T::new(socket), observer: RwLock::new(observer), next_poll_at_ms: AtomicU64::new(u64::MAX), has_new_events: AtomicBool::new(false), })) } pub(crate) fn inner(&self) -> &Arc> { &self.0 } } impl BoundSocket { /// Sets the observer whose `on_events` will be called when certain iface events happen. After /// setting, the new observer will fire once immediately to avoid missing any events. /// /// If there is an existing observer, due to race conditions, this function does not guarantee /// that the old observer will never be called after the setting. Users should be aware of this /// and proactively handle the race conditions if necessary. pub fn set_observer(&self, new_observer: Weak) { *self.0.observer.write_irq_disabled() = new_observer; self.0.on_iface_events(); } /// Returns the observer. /// /// See also [`Self::set_observer`]. pub fn observer(&self) -> Weak { // We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we // get the read lock. self.0.observer.read().clone() } pub fn local_endpoint(&self) -> Option { let ip_addr = { let ipv4_addr = self.0.iface.ipv4_addr()?; IpAddress::Ipv4(ipv4_addr) }; Some(IpEndpoint::new(ip_addr, self.0.port)) } pub fn iface(&self) -> &Arc> { &self.0.iface } } impl BoundTcpSocket { /// Connects to a remote endpoint. pub fn connect( &self, remote_endpoint: IpEndpoint, ) -> Result<(), smoltcp::socket::tcp::ConnectError> { let common = self.iface().common(); let mut iface = common.interface(); let mut socket = self.0.socket.lock(); let result = socket.connect(iface.context(), remote_endpoint, self.0.port); self.0 .update_next_poll_at_ms(socket.poll_at(iface.context())); result } /// Listens at a specified endpoint. pub fn listen( &self, local_endpoint: IpEndpoint, ) -> Result<(), smoltcp::socket::tcp::ListenError> { let mut socket = self.0.socket.lock(); socket.listen(local_endpoint) } pub fn send(&self, f: F) -> Result where F: FnOnce(&mut [u8]) -> (usize, R), { let mut socket = self.0.socket.lock(); let result = socket.send(f); self.0.update_next_poll_at_ms(PollAt::Now); result } pub fn recv(&self, f: F) -> Result where F: FnOnce(&mut [u8]) -> (usize, R), { let mut socket = self.0.socket.lock(); let result = socket.recv(f); self.0.update_next_poll_at_ms(PollAt::Now); result } pub fn close(&self) { let mut socket = self.0.socket.lock(); socket.close(); self.0.update_next_poll_at_ms(PollAt::Now); } /// Calls `f` with an immutable reference to the associated [`RawTcpSocket`]. // // NOTE: If a mutable reference is required, add a method above that correctly updates the next // polling time. pub fn raw_with(&self, f: F) -> R where F: FnOnce(&RawTcpSocket) -> R, { let socket = self.0.socket.lock(); f(&socket) } } impl BoundUdpSocket { /// Binds to a specified endpoint. pub fn bind(&self, local_endpoint: IpEndpoint) -> Result<(), smoltcp::socket::udp::BindError> { let mut socket = self.0.socket.lock(); socket.bind(local_endpoint) } pub fn send( &self, size: usize, meta: impl Into, f: F, ) -> Result where F: FnOnce(&mut [u8]) -> R, { use smoltcp::socket::udp::SendError as SendErrorInner; use crate::errors::udp::SendError; let mut socket = self.0.socket.lock(); if size > socket.packet_send_capacity() { return Err(SendError::TooLarge); } let buffer = match socket.send(size, meta) { Ok(data) => data, Err(SendErrorInner::Unaddressable) => return Err(SendError::Unaddressable), Err(SendErrorInner::BufferFull) => return Err(SendError::BufferFull), }; let result = f(buffer); self.0.update_next_poll_at_ms(PollAt::Now); Ok(result) } pub fn recv(&self, f: F) -> Result where F: FnOnce(&[u8], UdpMetadata) -> R, { let mut socket = self.0.socket.lock(); let (data, meta) = socket.recv()?; let result = f(data, meta); self.0.update_next_poll_at_ms(PollAt::Now); Ok(result) } /// Calls `f` with an immutable reference to the associated [`RawUdpSocket`]. // // NOTE: If a mutable reference is required, add a method above that correctly updates the next // polling time. pub fn raw_with(&self, f: F) -> R where F: FnOnce(&RawUdpSocket) -> R, { let socket = self.0.socket.lock(); f(&socket) } } impl BoundSocketInner { pub(crate) fn has_new_events(&self) -> bool { self.has_new_events.load(Ordering::Relaxed) } pub(crate) fn on_iface_events(&self) { self.has_new_events.store(false, Ordering::Relaxed); // We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we // get the read lock. let observer = Weak::upgrade(&*self.observer.read()); if let Some(inner) = observer { inner.on_events(); } } /// Returns the next polling time. /// /// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required /// before new network or user events. pub(crate) fn next_poll_at_ms(&self) -> u64 { self.next_poll_at_ms.load(Ordering::Relaxed) } /// Updates the next polling time according to `poll_at`. /// /// The update is typically needed after new network or user events have been handled, so this /// method also marks that there may be new events, so that the event observer provided by /// [`BoundSocket::set_observer`] can be notified later. fn update_next_poll_at_ms(&self, poll_at: PollAt) { self.has_new_events.store(true, Ordering::Relaxed); match poll_at { PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed), PollAt::Time(instant) => self .next_poll_at_ms .store(instant.total_millis() as u64, Ordering::Relaxed), PollAt::Ingress => self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed), } } } impl BoundSocketInner { pub(crate) fn port(&self) -> u16 { self.port } } impl BoundTcpSocketInner { /// Returns whether the TCP socket is dead. /// /// A TCP socket is considered dead if and only if the following two conditions are met: /// 1. The TCP connection is closed, so this socket cannot process any network events. /// 2. The socket handle [`BoundTcpSocket`] is dropped, which means that this /// [`BoundSocketInner`] is in background and no more user events can reach it. pub(crate) fn is_dead(&self) -> bool { self.socket.is_dead() } } impl BoundSocketInner { /// Returns whether an incoming packet _may_ be processed by the socket. /// /// The check is intended to be lock-free and fast, but may have false positives. pub(crate) fn can_process(&self, dst_port: u16) -> bool { self.port == dst_port } /// Returns whether the socket _may_ generate an outgoing packet. /// /// The check is intended to be lock-free and fast, but may have false positives. pub(crate) fn need_dispatch(&self, now: Instant) -> bool { now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed) } } #[derive(Debug, PartialEq, Eq, Clone)] pub(crate) enum TcpProcessResult { NotProcessed, Processed, ProcessedWithReply(IpRepr, TcpRepr<'static>), } impl BoundTcpSocketInner { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( &self, cx: &mut Context, ip_repr: &IpRepr, tcp_repr: &TcpRepr, ) -> TcpProcessResult { let mut socket = self.socket.lock(); if !socket.accepts(cx, ip_repr, tcp_repr) { return TcpProcessResult::NotProcessed; } // If the socket is in the TimeWait state and a new packet arrives that is a SYN packet // without ack number, the TimeWait socket will be marked as dead, // and the packet will be passed on to any other listening sockets for processing. // // FIXME: Directly marking the TimeWait socket dead is not the correct approach. // In Linux, a TimeWait socket remains alive to handle "old duplicate segments". // If a TimeWait socket receives a new SYN packet, Linux will select a suitable // listening socket from the socket table to respond to that SYN request. // (https://elixir.bootlin.com/linux/v6.0.9/source/net/ipv4/tcp_ipv4.c#L2137) // Moreover, the Initial Sequence Number (ISN) will be set to prevent the TimeWait socket // from erroneously handling packets from the new connection. // (https://elixir.bootlin.com/linux/v6.0.9/source/net/ipv4/tcp_minisocks.c#L194) // Implementing such behavior is challenging with the current smoltcp APIs. if socket.state() == State::TimeWait && tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() { self.socket.set_dead_timewait(&socket); return TcpProcessResult::NotProcessed; } let result = match socket.process(cx, ip_repr, tcp_repr) { None => TcpProcessResult::Processed, Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), }; self.update_next_poll_at_ms(socket.poll_at(cx)); self.socket.update_dead(&socket); result } /// Tries to generate an outgoing packet and dispatches the generated packet. pub(crate) fn dispatch( &self, cx: &mut Context, dispatch: D, ) -> Option<(IpRepr, TcpRepr<'static>)> where D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>, { let mut socket = self.socket.lock(); let mut reply = None; socket .dispatch(cx, |cx, (ip_repr, tcp_repr)| { reply = dispatch(cx, &ip_repr, &tcp_repr); Ok::<(), ()>(()) }) .unwrap(); // `dispatch` can return a packet in response to the generated packet. If the socket // accepts the packet, we can process it directly. while let Some((ref ip_repr, ref tcp_repr)) = reply { if !socket.accepts(cx, ip_repr, tcp_repr) { break; } reply = socket.process(cx, ip_repr, tcp_repr); } self.update_next_poll_at_ms(socket.poll_at(cx)); self.socket.update_dead(&socket); reply } } impl BoundUdpSocketInner { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( &self, cx: &mut Context, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8], ) -> bool { let mut socket = self.socket.lock(); if !socket.accepts(cx, ip_repr, udp_repr) { return false; } socket.process( cx, smoltcp::phy::PacketMeta::default(), ip_repr, udp_repr, udp_payload, ); self.update_next_poll_at_ms(socket.poll_at(cx)); true } /// Tries to generate an outgoing packet and dispatches the generated packet. pub(crate) fn dispatch(&self, cx: &mut Context, dispatch: D) where D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]), { let mut socket = self.socket.lock(); socket .dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| { dispatch(cx, &ip_repr, &udp_repr, udp_payload); Ok::<(), ()>(()) }) .unwrap(); self.update_next_poll_at_ms(socket.poll_at(cx)); } }