From c289f96d2322ccbf2742cebf9dd416d6a44d1e9d Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Wed, 3 Sep 2025 22:59:35 +0800 Subject: [PATCH] Report `ENOBUFS` if netlink messages overrun --- kernel/src/net/socket/netlink/common/bound.rs | 3 + .../socket/netlink/kobject_uevent/bound.rs | 23 ++-- .../netlink/kobject_uevent/message/mod.rs | 15 ++- kernel/src/net/socket/netlink/message/mod.rs | 11 +- kernel/src/net/socket/netlink/receiver.rs | 108 ++++++++++++++---- kernel/src/net/socket/netlink/route/bound.rs | 25 ++-- kernel/src/net/socket/netlink/table/mod.rs | 22 ++-- .../src/net/socket/netlink/table/multicast.rs | 7 +- test/src/apps/network/netlink_route.c | 93 +++++++++++---- 9 files changed, 210 insertions(+), 97 deletions(-) diff --git a/kernel/src/net/socket/netlink/common/bound.rs b/kernel/src/net/socket/netlink/common/bound.rs index 575750afd..24175c017 100644 --- a/kernel/src/net/socket/netlink/common/bound.rs +++ b/kernel/src/net/socket/netlink/common/bound.rs @@ -50,6 +50,9 @@ impl BoundNetlink { if !receive_queue.is_empty() { events |= IoEvents::IN; } + if receive_queue.has_errors() { + events |= IoEvents::ERR; + } events } diff --git a/kernel/src/net/socket/netlink/kobject_uevent/bound.rs b/kernel/src/net/socket/netlink/kobject_uevent/bound.rs index cc7ba18bf..92d6109c4 100644 --- a/kernel/src/net/socket/netlink/kobject_uevent/bound.rs +++ b/kernel/src/net/socket/netlink/kobject_uevent/bound.rs @@ -69,24 +69,15 @@ impl datagram_common::Bound for BoundNetlinkUevent { let mut receive_queue = self.receive_queue.lock(); - let Some(response) = receive_queue.peek() else { - return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty"); - }; + receive_queue.dequeue_if(|response, response_len| { + let len = response_len.min(writer.sum_lens()); + response.write_to(writer)?; - let len = { - let max_len = writer.sum_lens(); - response.total_len().min(max_len) - }; + let remote = *response.src_addr(); - response.write_to(writer)?; - - let remote = *response.src_addr(); - - if !flags.contains(SendRecvFlags::MSG_PEEK) { - receive_queue.dequeue().unwrap(); - } - - Ok((len, remote)) + let should_dequeue = !flags.contains(SendRecvFlags::MSG_PEEK); + Ok((should_dequeue, (len, remote))) + }) } fn check_io_events(&self) -> IoEvents { diff --git a/kernel/src/net/socket/netlink/kobject_uevent/message/mod.rs b/kernel/src/net/socket/netlink/kobject_uevent/message/mod.rs index fc79ba739..7f94a77b3 100644 --- a/kernel/src/net/socket/netlink/kobject_uevent/message/mod.rs +++ b/kernel/src/net/socket/netlink/kobject_uevent/message/mod.rs @@ -5,7 +5,9 @@ use uevent::Uevent; use crate::{ - net::socket::netlink::{table::MulticastMessage, NetlinkSocketAddr}, + net::socket::netlink::{ + receiver::QueueableMessage, table::MulticastMessage, NetlinkSocketAddr, + }, prelude::*, util::MultiWrite, }; @@ -39,11 +41,6 @@ impl UeventMessage { &self.src_addr } - /// Returns the total length of the uevent. - pub(super) fn total_len(&self) -> usize { - self.uevent.len() - } - /// Writes the uevent to the given `writer`. pub(super) fn write_to(&self, writer: &mut dyn MultiWrite) -> Result<()> { let _nbytes = writer.write(&mut VmReader::from(self.uevent.as_bytes()))?; @@ -53,4 +50,10 @@ impl UeventMessage { } } +impl QueueableMessage for UeventMessage { + fn total_len(&self) -> usize { + self.uevent.len() + } +} + impl MulticastMessage for UeventMessage {} diff --git a/kernel/src/net/socket/netlink/message/mod.rs b/kernel/src/net/socket/netlink/message/mod.rs index dd82e9a2c..dce653c75 100644 --- a/kernel/src/net/socket/netlink/message/mod.rs +++ b/kernel/src/net/socket/netlink/message/mod.rs @@ -16,6 +16,7 @@ pub(super) use segment::{ CSegmentType, SegmentBody, }; +use super::receiver::QueueableMessage; use crate::{ prelude::*, util::{MultiRead, MultiWrite}, @@ -26,11 +27,11 @@ use crate::{ /// A netlink message can be transmitted to and from user space using a single send/receive syscall. /// It consists of one or more [`ProtocolSegment`]s. #[derive(Debug)] -pub struct Message { +pub struct Message { segments: Vec, } -impl Message { +impl Message { pub(super) const fn new(segments: Vec) -> Self { Self { segments } } @@ -42,7 +43,9 @@ impl Message { pub(super) fn segments_mut(&mut self) -> &mut [T] { &mut self.segments } +} +impl Message { pub(super) fn read_from(reader: &mut dyn MultiRead) -> Result { // FIXME: Does a request contain only one segment? We need to investigate further. let segments = { @@ -60,8 +63,10 @@ impl Message { Ok(()) } +} - pub(super) fn total_len(&self) -> usize { +impl QueueableMessage for Message { + fn total_len(&self) -> usize { self.segments .iter() .map(|segment| segment.header().len as usize) diff --git a/kernel/src/net/socket/netlink/receiver.rs b/kernel/src/net/socket/netlink/receiver.rs index 85e6688b3..1c63b95cd 100644 --- a/kernel/src/net/socket/netlink/receiver.rs +++ b/kernel/src/net/socket/netlink/receiver.rs @@ -7,11 +7,20 @@ pub struct MessageReceiver { pollee: Pollee, } -pub(super) struct MessageQueue(VecDeque); +pub(super) struct MessageQueue { + messages: VecDeque, + total_length: usize, + error: Option, +} impl MessageQueue { + /// Creates a pair of a [`MessageQueue`] and a [`MessageReceiver`]. pub(super) fn new_pair(pollee: Pollee) -> (Arc>, MessageReceiver) { - let queue = Arc::new(Mutex::new(Self(VecDeque::new()))); + let queue = Arc::new(Mutex::new(Self { + messages: VecDeque::new(), + total_length: 0, + error: None, + })); let receiver = MessageReceiver { message_queue: queue.clone(), pollee, @@ -19,31 +28,88 @@ impl MessageQueue { (queue, receiver) } + /// Returns whether the message queue is empty. pub(super) fn is_empty(&self) -> bool { - self.0.is_empty() + self.messages.is_empty() } - pub(super) fn peek(&self) -> Option<&Message> { - self.0.front() - } - - pub(super) fn dequeue(&mut self) -> Option { - self.0.pop_front() - } - - pub(self) fn enqueue(&mut self, message: Message) -> Result<()> { - // FIXME: We should verify the socket buffer length to ensure - // that adding the message doesn't exceed the buffer capacity. - self.0.push_back(message); - Ok(()) + /// Returns whether the message queue contains errors. + /// + /// Currently, the message queue contains errors only if the queue is full but the kernel still + /// wants to enqueue new messages. + pub(super) fn has_errors(&self) -> bool { + self.error.is_some() } } -impl MessageReceiver { - pub(super) fn enqueue_message(&self, message: Message) -> Result<()> { - self.message_queue.lock().enqueue(message)?; - self.pollee.notify(IoEvents::IN); +/// Messages that fit into the [`MessageQueue`]. +pub trait QueueableMessage { + /// Counts and returns the length of the message. + fn total_len(&self) -> usize; +} - Ok(()) +impl MessageQueue { + /// Dequeues a message if executing the closure returns `Ok((true, _))`. + /// + /// The closure will be executed with a reference to the message that is ready to be dequeued + /// and the length of the message. + /// + /// If the queue contains errors (see [`Self::has_errors`]), the error will be cleared and + /// returned. In this case, the closure will not be executed. + pub(super) fn dequeue_if(&mut self, f: F) -> Result + where + F: FnOnce(&Message, usize) -> Result<(bool, R)>, + { + if let Some(error) = self.error.take() { + return Err(error); + } + + let Some(message) = self.messages.front() else { + return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty"); + }; + + let length = message.total_len(); + let (should_pop, result) = f(message, length)?; + if should_pop { + self.messages.pop_front().unwrap(); + self.total_length -= length; + } + + Ok(result) + } + + /// Tries to enqueue a new message. Returns `false` if the buffer is full. + #[must_use] + pub(self) fn enqueue(&mut self, message: Message) -> bool { + let length = message.total_len(); + + // Currently, we don't support sending netlink messages between user spaces, so only the + // kernel can enqueue new messages. If the kernel fails to enqueue a new message, `ENOBUFS` + // will be returned when userspace calls `recv`. + if NETLINK_DEFAULT_BUF_SIZE - self.total_length < length { + self.error = Some(Error::with_message( + Errno::ENOBUFS, + "the receive buffer is full", + )); + return false; + } + + self.messages.push_back(message); + self.total_length += length; + + true } } + +impl MessageReceiver { + pub(super) fn enqueue_message(&self, message: Message) { + let is_ok = self.message_queue.lock().enqueue(message); + if is_ok { + self.pollee.notify(IoEvents::IN); + } else { + self.pollee.notify(IoEvents::ERR); + } + } +} + +const NETLINK_DEFAULT_BUF_SIZE: usize = 65536; diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs index 5a1d7d764..40c96d111 100644 --- a/kernel/src/net/socket/netlink/route/bound.rs +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -101,25 +101,16 @@ impl datagram_common::Bound for BoundNetlinkRoute { let mut receive_queue = self.receive_queue.lock(); - let Some(response) = receive_queue.peek() else { - return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty"); - }; + receive_queue.dequeue_if(|response, response_len| { + let len = response_len.min(writer.sum_lens()); + response.write_to(writer)?; - let len = { - let max_len = writer.sum_lens(); - response.total_len().min(max_len) - }; + // TODO: The message can only come from kernel socket currently. + let remote = NetlinkSocketAddr::new_unspecified(); - response.write_to(writer)?; - - if !flags.contains(SendRecvFlags::MSG_PEEK) { - receive_queue.dequeue().unwrap(); - } - - // TODO: The message can only come from kernel socket currently. - let remote = NetlinkSocketAddr::new_unspecified(); - - Ok((len, remote)) + let should_dequeue = !flags.contains(SendRecvFlags::MSG_PEEK); + Ok((should_dequeue, (len, remote))) + }) } fn check_io_events(&self) -> IoEvents { diff --git a/kernel/src/net/socket/netlink/table/mod.rs b/kernel/src/net/socket/netlink/table/mod.rs index 767978482..ac17cb918 100644 --- a/kernel/src/net/socket/netlink/table/mod.rs +++ b/kernel/src/net/socket/netlink/table/mod.rs @@ -4,7 +4,10 @@ use multicast::MulticastGroup; pub(super) use multicast::MulticastMessage; use spin::Once; -use super::addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS}; +use super::{ + addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS}, + receiver::QueueableMessage, +}; use crate::{ net::socket::netlink::{ addr::UNSPECIFIED_PORT, kobject_uevent::UeventMessage, receiver::MessageReceiver, @@ -46,7 +49,10 @@ pub trait SupportedNetlinkProtocol { socket_table.bind(Self::socket_table(), addr, receiver) } - fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()> { + fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()> + where + Self::Message: QueueableMessage, + { let socket_table = Self::socket_table().read(); socket_table.unicast(dst_port, message) } @@ -141,13 +147,17 @@ impl ProtocolSocketTable { Ok(BoundHandle::new(socket_table, port, addr.groups())) } - fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()> { + fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()> + where + Message: QueueableMessage, + { let Some(receiver) = self.unicast_sockets.get(&dst_port) else { // FIXME: Should we return error here? return Ok(()); }; + receiver.enqueue_message(message); - receiver.enqueue_message(message) + Ok(()) } fn multicast(&self, dst_groups: GroupIdSet, message: Message) -> Result<()> @@ -163,9 +173,7 @@ impl ProtocolSocketTable { let Some(receiver) = self.unicast_sockets.get(port_num) else { continue; }; - - // FIXME: Should we slightly ignore the error if the socket's buffer has no enough space? - receiver.enqueue_message(message.clone())?; + receiver.enqueue_message(message.clone()); } } diff --git a/kernel/src/net/socket/netlink/table/multicast.rs b/kernel/src/net/socket/netlink/table/multicast.rs index 3036e86d2..b09e8ec3c 100644 --- a/kernel/src/net/socket/netlink/table/multicast.rs +++ b/kernel/src/net/socket/netlink/table/multicast.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: MPL-2.0 -use crate::{net::socket::netlink::addr::PortNum, prelude::*}; +use crate::{ + net::socket::netlink::{addr::PortNum, receiver::QueueableMessage}, + prelude::*, +}; /// A netlink multicast group. /// @@ -34,4 +37,4 @@ impl MulticastGroup { } } -pub trait MulticastMessage: Clone {} +pub trait MulticastMessage: QueueableMessage + Clone {} diff --git a/test/src/apps/network/netlink_route.c b/test/src/apps/network/netlink_route.c index 66265a9d4..9e911382a 100644 --- a/test/src/apps/network/netlink_route.c +++ b/test/src/apps/network/netlink_route.c @@ -156,6 +156,14 @@ struct nl_req { char abuf[4]; }; +#define INIT_REQ(req) \ + memset(&req, 0, sizeof(req)); \ + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); \ + req.hdr.nlmsg_type = RTM_GETADDR; \ + req.hdr.nlmsg_flags = NLM_F_REQUEST; \ + req.hdr.nlmsg_seq = 1; \ + req.ifa.ifa_family = AF_UNSPEC; + FN_TEST(get_addr_error) { int sock_fd; @@ -170,12 +178,7 @@ FN_TEST(get_addr_error) // 1. Without NLM_F_DUMP flag struct nl_req req; - memset(&req, 0, sizeof(req)); - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); - req.hdr.nlmsg_type = RTM_GETADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST; - req.hdr.nlmsg_seq = 1; - req.ifa.ifa_family = AF_UNSPEC; + INIT_REQ(req); struct iovec iov = { &req, req.hdr.nlmsg_len }; struct msghdr msg = { &sa, sizeof(sa), &iov, 1, NULL, 0, 0 }; @@ -187,18 +190,20 @@ FN_TEST(get_addr_error) -EOPNOTSUPP); int found_new_addr; -#define TEST_KERNEL_RESPONSE \ - found_new_addr = 0; \ - while (1) { \ - size_t recv_len = \ - TEST_SUCC(recv(sock_fd, buffer, BUFFER_SIZE, 0)); \ - \ - int found_done = TEST_SUCC(find_new_addr_until_done( \ - buffer, recv_len, &found_new_addr)); \ - \ - if (found_done != 0) { \ - break; \ - } \ +#define TEST_KERNEL_RESPONSE \ + found_new_addr = 0; \ + while (1) { \ + size_t recv_len = \ + TEST_SUCC(recv(sock_fd, buffer, BUFFER_SIZE, 0)); \ + \ + int found_done = \ + TEST_RES(find_new_addr_until_done(buffer, recv_len, \ + &found_new_addr), \ + _ret >= 0); \ + \ + if (found_done != 0) { \ + break; \ + } \ } // 2. Invalid required index @@ -232,13 +237,7 @@ FN_TEST(bufsize_msgsize) sock_fd = TEST_SUCC( socket(AF_NETLINK, SOCK_RAW | SOCK_NONBLOCK, NETLINK_ROUTE)); - - memset(&req, 0, sizeof(req)); - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); - req.hdr.nlmsg_type = RTM_GETADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST; - req.hdr.nlmsg_seq = 1; - req.ifa.ifa_family = AF_UNSPEC; + INIT_REQ(req); // Send the request TEST_RES(send(sock_fd, &req, sizeof(req), 0), _ret == sizeof(req)); @@ -252,3 +251,47 @@ FN_TEST(bufsize_msgsize) TEST_SUCC(close(sock_fd)); } END_TEST() + +int fill_receive_buffer(int sock_fd, const struct nl_req *req) +{ + struct pollfd pfd = { .fd = sock_fd, .events = POLLIN | POLLOUT }; + int i; + + for (i = 0; i < 4096; ++i) { + if (send(sock_fd, req, sizeof(*req), 0) != sizeof(*req)) + return -1; + if (poll(&pfd, 1, 0) < 0) + return -1; + switch (pfd.revents) { + case POLLIN | POLLOUT: + continue; + case POLLIN | POLLOUT | POLLERR: + return 0; + default: + return -1; + } + } + + return -1; +} + +FN_TEST(enobufs) +{ + int sock_fd; + struct nl_req req; + + sock_fd = TEST_SUCC( + socket(AF_NETLINK, SOCK_RAW | SOCK_NONBLOCK, NETLINK_ROUTE)); + INIT_REQ(req); + + TEST_RES(fill_receive_buffer(sock_fd, &req), _ret >= 0); + + // Now the receive buffer is full. We can still send a new message, + // but the first `recv` should fail with `ENOBUFS`. + TEST_RES(send(sock_fd, &req, sizeof(req), 0), _ret == sizeof(req)); + TEST_ERRNO(recv(sock_fd, buffer, 1, 0), ENOBUFS); + TEST_SUCC(recv(sock_fd, buffer, 1, 0)); + + TEST_SUCC(close(sock_fd)); +} +END_TEST()