asterinas/kernel/src/net/socket/netlink/table/mod.rs

329 lines
9.3 KiB
Rust

// SPDX-License-Identifier: MPL-2.0
use multicast::MulticastGroup;
pub(super) use multicast::MulticastMessage;
use spin::Once;
use super::{
addr::{GroupIdSet, MAX_GROUPS, NetlinkProtocolId, NetlinkSocketAddr, PortNum},
receiver::QueueableMessage,
};
use crate::{
net::socket::netlink::{
addr::UNSPECIFIED_PORT, kobject_uevent::UeventMessage, receiver::MessageReceiver,
route::RtnlMessage,
},
prelude::*,
util::random::getrandom,
};
mod multicast;
static NETLINK_SOCKET_TABLE: Once<NetlinkSocketTable> = Once::new();
/// All bound netlink sockets.
struct NetlinkSocketTable {
route: RwMutex<ProtocolSocketTable<RtnlMessage>>,
uevent: RwMutex<ProtocolSocketTable<UeventMessage>>,
}
impl NetlinkSocketTable {
fn new() -> Self {
Self {
route: RwMutex::new(ProtocolSocketTable::new()),
uevent: RwMutex::new(ProtocolSocketTable::new()),
}
}
}
pub trait SupportedNetlinkProtocol {
type Message: 'static + Send;
fn socket_table() -> &'static RwMutex<ProtocolSocketTable<Self::Message>>;
fn bind(
addr: &NetlinkSocketAddr,
receiver: MessageReceiver<Self::Message>,
) -> Result<BoundHandle<Self::Message>> {
let mut socket_table = Self::socket_table().write();
socket_table.bind(Self::socket_table(), addr, receiver)
}
fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()>
where
Self::Message: QueueableMessage,
{
let socket_table = Self::socket_table().read();
socket_table.unicast(dst_port, message)
}
#[cfg_attr(not(ktest), expect(dead_code))]
fn multicast(dst_groups: GroupIdSet, message: Self::Message) -> Result<()>
where
Self::Message: MulticastMessage,
{
let socket_table = Self::socket_table().read();
socket_table.multicast(dst_groups, message)
}
}
pub enum NetlinkRouteProtocol {}
impl SupportedNetlinkProtocol for NetlinkRouteProtocol {
type Message = RtnlMessage;
fn socket_table() -> &'static RwMutex<ProtocolSocketTable<Self::Message>> {
&NETLINK_SOCKET_TABLE.get().unwrap().route
}
}
pub enum NetlinkUeventProtocol {}
impl SupportedNetlinkProtocol for NetlinkUeventProtocol {
type Message = UeventMessage;
fn socket_table() -> &'static RwMutex<ProtocolSocketTable<Self::Message>> {
&NETLINK_SOCKET_TABLE.get().unwrap().uevent
}
}
/// Bound socket table of a single netlink protocol.
///
/// Each table can have bound sockets for unicast
/// and at most 32 groups for multicast.
pub struct ProtocolSocketTable<Message> {
unicast_sockets: BTreeMap<PortNum, MessageReceiver<Message>>,
multicast_groups: Box<[MulticastGroup]>,
}
impl<Message: 'static> ProtocolSocketTable<Message> {
/// Creates a new table.
fn new() -> Self {
let multicast_groups = (0u32..MAX_GROUPS).map(|_| MulticastGroup::new()).collect();
Self {
unicast_sockets: BTreeMap::new(),
multicast_groups,
}
}
/// Binds a socket to the table.
/// Returns the bound handle.
///
/// The socket will be bound to a port specified by `addr.port()`.
/// If `addr.port()` is zero, the kernel will assign a port,
/// typically corresponding to the process ID of the current process.
/// If the assigned port is already in use,
/// this function will try to allocate a random unused port.
///
/// Additionally, this socket can join one or more multicast groups,
/// as specified in `addr.groups()`.
fn bind(
&mut self,
socket_table: &'static RwMutex<ProtocolSocketTable<Message>>,
addr: &NetlinkSocketAddr,
receiver: MessageReceiver<Message>,
) -> Result<BoundHandle<Message>> {
let port = if addr.port() != UNSPECIFIED_PORT {
addr.port()
} else {
let mut random_port = current!().pid();
while random_port == UNSPECIFIED_PORT || self.unicast_sockets.contains_key(&random_port)
{
getrandom(random_port.as_mut_bytes());
}
random_port
};
if self.unicast_sockets.contains_key(&port) {
return_errno_with_message!(Errno::EADDRINUSE, "the netlink port is already in use");
}
self.unicast_sockets.insert(port, receiver);
for group_id in addr.groups().ids_iter() {
let group = &mut self.multicast_groups[group_id as usize];
group.add_member(port);
}
Ok(BoundHandle::new(socket_table, port, addr.groups()))
}
fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()>
where
Message: QueueableMessage,
{
let Some(receiver) = self.unicast_sockets.get(&dst_port) else {
// FIXME: Should we return error here?
return Ok(());
};
receiver.enqueue_message(message);
Ok(())
}
fn multicast(&self, dst_groups: GroupIdSet, message: Message) -> Result<()>
where
Message: MulticastMessage,
{
for group in dst_groups.ids_iter() {
let Some(group) = self.multicast_groups.get(group as usize) else {
continue;
};
for port_num in group.members() {
let Some(receiver) = self.unicast_sockets.get(port_num) else {
continue;
};
receiver.enqueue_message(message.clone());
}
}
Ok(())
}
}
/// A bound netlink socket address.
///
/// When dropping a `BoundHandle`,
/// the port will be automatically released.
pub struct BoundHandle<Message: 'static> {
socket_table: &'static RwMutex<ProtocolSocketTable<Message>>,
port: PortNum,
groups: GroupIdSet,
}
impl<Message: 'static> BoundHandle<Message> {
fn new(
socket_table: &'static RwMutex<ProtocolSocketTable<Message>>,
port: PortNum,
groups: GroupIdSet,
) -> Self {
debug_assert_ne!(port, UNSPECIFIED_PORT);
Self {
socket_table,
port,
groups,
}
}
pub(super) const fn port(&self) -> PortNum {
self.port
}
pub(super) const fn addr(&self) -> NetlinkSocketAddr {
NetlinkSocketAddr::new(self.port, self.groups)
}
pub(super) fn add_groups(&mut self, groups: GroupIdSet) {
let mut protocol_sockets = self.socket_table.write();
for group_id in groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.add_member(self.port);
}
self.groups.add_groups(groups);
}
pub(super) fn drop_groups(&mut self, groups: GroupIdSet) {
let mut protocol_sockets = self.socket_table.write();
for group_id in groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.remove_member(self.port);
}
self.groups.drop_groups(groups);
}
pub(super) fn bind_groups(&mut self, groups: GroupIdSet) {
let mut protocol_sockets = self.socket_table.write();
for group_id in self.groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.remove_member(self.port);
}
for group_id in groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.add_member(self.port);
}
self.groups = groups;
}
}
impl<Message: 'static> Drop for BoundHandle<Message> {
fn drop(&mut self) {
let mut protocol_sockets = self.socket_table.write();
protocol_sockets.unicast_sockets.remove(&self.port);
for group_id in self.groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.remove_member(self.port);
}
}
}
pub(super) fn init() {
NETLINK_SOCKET_TABLE.call_once(NetlinkSocketTable::new);
}
/// Returns whether the `protocol` is valid.
pub fn is_valid_protocol(protocol: NetlinkProtocolId) -> bool {
protocol < MAX_ALLOWED_PROTOCOL_ID
}
/// Netlink protocols that are assigned for specific usage.
///
/// Reference: <https://elixir.bootlin.com/linux/v6.0.9/source/include/uapi/linux/netlink.h#L9>.
#[expect(non_camel_case_types)]
#[expect(clippy::upper_case_acronyms)]
#[repr(u32)]
#[derive(Debug, Clone, Copy, TryFromInt)]
pub enum StandardNetlinkProtocol {
/// Routing/device hook
ROUTE = 0,
/// Unused number
UNUSED = 1,
/// Reserved for user mode socket protocols
USERSOCK = 2,
/// Unused number, formerly ip_queue
FIREWALL = 3,
/// Socket monitoring
SOCK_DIAG = 4,
/// Netfilter/iptables ULOG
NFLOG = 5,
/// IPsec
XFRM = 6,
/// SELinux event notifications
SELINUX = 7,
/// Open-iSCSI
ISCSI = 8,
/// Auditing
AUDIT = 9,
FIB_LOOKUP = 10,
CONNECTOR = 11,
/// Netfilter subsystem
NETFILTER = 12,
IP6_FW = 13,
/// DECnet routing messages
DNRTMSG = 14,
/// Kernel messages to userspace
KOBJECT_UEVENT = 15,
GENERIC = 16,
/// Leave room for NETLINK_DM (DM Events)
/// SCSI Transports
SCSITRANSPORT = 18,
ECRYPTFS = 19,
RDMA = 20,
/// Crypto layer
CRYPTO = 21,
/// SMC monitoring
SMC = 22,
}
const MAX_ALLOWED_PROTOCOL_ID: NetlinkProtocolId = 32;