Truncate netlink messages

This commit is contained in:
Ruihan Li 2025-06-11 12:32:20 +08:00 committed by Jianfeng Jiang
parent 0e8106abfa
commit deab9b6f72
9 changed files with 89 additions and 74 deletions

View File

@ -46,11 +46,9 @@ impl UeventMessage {
/// Writes the uevent to the given `writer`.
pub(super) fn write_to(&self, writer: &mut dyn MultiWrite) -> Result<()> {
// FIXME: If the message can be truncated, we should avoid returning an error.
if self.uevent.len() > writer.sum_lens() {
return_errno_with_message!(Errno::EFAULT, "the writer length is too small");
}
writer.write(&mut VmReader::from(self.uevent.as_bytes()))?;
let _nbytes = writer.write(&mut VmReader::from(self.uevent.as_bytes()))?;
// `_nbytes` may be smaller than the message size. We ignore it to truncate the message.
Ok(())
}
}

View File

@ -107,7 +107,7 @@ pub trait Attribute: Debug + Send + Sync {
total_len -= attr.total_len();
let padding_len = attr.padding_len().min(total_len);
reader.skip(padding_len);
reader.skip_some(padding_len);
total_len -= padding_len;
res.push(attr);
@ -123,11 +123,11 @@ pub trait Attribute: Debug + Send + Sync {
len: self.total_len() as u16,
};
writer.write_val(&header)?;
writer.write_val_trunc(&header)?;
writer.write(&mut VmReader::from(self.payload_as_bytes()))?;
let padding_len = self.padding_len();
writer.skip(padding_len);
writer.skip_some(padding_len);
Ok(())
}

View File

@ -1,10 +1,8 @@
// SPDX-License-Identifier: MPL-2.0
use align_ext::AlignExt;
use super::{header::CMsgSegHdr, SegmentBody};
use crate::{
net::socket::netlink::message::{attr::Attribute, NLMSG_ALIGN},
net::socket::netlink::message::attr::Attribute,
prelude::*,
util::{MultiRead, MultiWrite},
};
@ -54,7 +52,6 @@ impl<Body: SegmentBody, Attr: Attribute> SegmentCommon<Body, Attr> {
Error: From<<Body::CType as TryInto<Body>>::Error>,
{
let (body, remain_len) = Body::read_from(&header, reader).unwrap();
let attrs = Attr::read_all_from(reader, remain_len)?;
Ok(Self {
@ -65,14 +62,8 @@ impl<Body: SegmentBody, Attr: Attribute> SegmentCommon<Body, Attr> {
}
pub fn write_to(&self, writer: &mut dyn MultiWrite) -> Result<()> {
// FIXME: If the message can be truncated, we should avoid returning an error.
// Furthermore, we need to check the Linux behavior to determine whether to return an error
// if the writer is not large enough to accommodate the final padding bytes.
if writer.sum_lens() < (self.header.len as usize).align_up(NLMSG_ALIGN) {
return_errno_with_message!(Errno::EFAULT, "the writer length is too small");
}
writer.write_val_trunc(&self.header)?;
writer.write_val(&self.header)?;
self.body.write_to(writer)?;
for attr in self.attrs.iter() {
attr.write_to(writer)?;

View File

@ -43,12 +43,12 @@ pub trait SegmentBody: Sized + Clone + Copy {
// Read the body.
let (c_type, padding_len) = if remaining_len >= size_of::<Self::CType>() {
let c_type = reader.read_val::<Self::CType>()?;
let c_type = reader.read_val_opt::<Self::CType>()?.unwrap();
remaining_len -= size_of_val(&c_type);
(c_type, Self::padding_len())
} else if remaining_len >= size_of::<Self::CLegacyType>() {
let legacy = reader.read_val::<Self::CLegacyType>()?;
let legacy = reader.read_val_opt::<Self::CLegacyType>()?.unwrap();
remaining_len -= size_of_val(&legacy);
(Self::CType::from(legacy), Self::lecacy_padding_len())
@ -58,7 +58,7 @@ pub trait SegmentBody: Sized + Clone + Copy {
// Skip the padding bytes.
let padding_len = padding_len.min(remaining_len);
reader.skip(padding_len);
reader.skip_some(padding_len);
remaining_len -= padding_len;
let body = c_type.try_into()?;
@ -68,11 +68,12 @@ pub trait SegmentBody: Sized + Clone + Copy {
fn write_to(&self, writer: &mut dyn MultiWrite) -> Result<()> {
// Write the body.
let c_body = Self::CType::from(*self);
writer.write_val(&c_body)?;
writer.write_val_trunc(&c_body)?;
// Skip the padding bytes.
let padding_len = Self::padding_len();
writer.skip(padding_len);
writer.skip_some(padding_len);
Ok(())
}

View File

@ -62,11 +62,12 @@ impl Attribute for AddrAttr {
where
Self: Sized,
{
let header = reader.read_val::<CAttrHeader>()?;
let header = reader.read_val_opt::<CAttrHeader>()?.unwrap();
// TODO: Currently, `IS_NET_BYTEORDER_MASK` and `IS_NESTED_MASK` are ignored.
let res = match AddrAttrClass::try_from(header.type_())? {
AddrAttrClass::ADDRESS => Self::Address(reader.read_val()?),
AddrAttrClass::LOCAL => Self::Local(reader.read_val()?),
AddrAttrClass::ADDRESS => Self::Address(reader.read_val_opt()?.unwrap()),
AddrAttrClass::LOCAL => Self::Local(reader.read_val_opt()?.unwrap()),
AddrAttrClass::LABEL => Self::Label(reader.read_cstring_with_max_len(IFNAME_SIZE)?),
class => {
// FIXME: Netlink should ignore all unknown attributes.

View File

@ -122,14 +122,15 @@ impl Attribute for LinkAttr {
where
Self: Sized,
{
let header = reader.read_val::<CAttrHeader>()?;
let header = reader.read_val_opt::<CAttrHeader>()?.unwrap();
// TODO: Currently, `IS_NET_BYTEORDER_MASK` and `IS_NESTED_MASK` are ignored.
let res = match LinkAttrClass::try_from(header.type_())? {
LinkAttrClass::IFNAME => Self::Name(reader.read_cstring_with_max_len(IFNAME_SIZE)?),
LinkAttrClass::MTU => Self::Mtu(reader.read_val()?),
LinkAttrClass::TXQLEN => Self::TxqLen(reader.read_val()?),
LinkAttrClass::LINKMODE => Self::LinkMode(reader.read_val()?),
LinkAttrClass::EXT_MASK => Self::ExtMask(reader.read_val()?),
LinkAttrClass::MTU => Self::Mtu(reader.read_val_opt()?.unwrap()),
LinkAttrClass::TXQLEN => Self::TxqLen(reader.read_val_opt()?.unwrap()),
LinkAttrClass::LINKMODE => Self::LinkMode(reader.read_val_opt()?.unwrap()),
LinkAttrClass::EXT_MASK => Self::ExtMask(reader.read_val_opt()?.unwrap()),
class => {
// FIXME: Netlink should ignore all unknown attributes.
// But how to decide the payload type if the class is unknown?

View File

@ -84,7 +84,9 @@ impl ProtocolSegment for RtnlSegment {
}
fn read_from(reader: &mut dyn MultiRead) -> Result<Self> {
let header = reader.read_val::<CMsgSegHdr>()?;
let header = reader
.read_val_opt::<CMsgSegHdr>()?
.ok_or_else(|| Error::with_message(Errno::EINVAL, "the reader length is too small"))?;
let segment = match CSegmentType::try_from(header.type_)? {
CSegmentType::GETLINK => RtnlSegment::GetLink(LinkSegment::read_from(header, reader)?),

View File

@ -155,12 +155,9 @@ pub trait MultiRead: ReadCString {
self.sum_lens() == 0
}
/// Skips the first `nbytes` bytes of data.
///
/// # Panics
///
/// If `nbytes` is greater that [`MultiRead::sum_lens`], this method will panic.
fn skip(&mut self, nbytes: usize);
/// Skips the first `nbytes` bytes of data, or skips to the end if the readers have
/// insufficient bytes.
fn skip_some(&mut self, nbytes: usize);
}
/// Trait defining the write behavior for a collection of [`VmWriter`]s.
@ -185,12 +182,9 @@ pub trait MultiWrite {
self.sum_lens() == 0
}
/// Skips the first `nbytes` bytes of space.
///
/// # Panics
///
/// If `nbytes` is greater that [`MultiWrite::sum_lens`], this method will panic.
fn skip(&mut self, nbytes: usize);
/// Skips the first `nbytes` bytes of data, or skips to the end if the writers have
/// insufficient bytes.
fn skip_some(&mut self, nbytes: usize);
}
impl MultiRead for VmReaderArray<'_> {
@ -211,7 +205,7 @@ impl MultiRead for VmReaderArray<'_> {
self.0.iter().map(|vm_reader| vm_reader.remain()).sum()
}
fn skip(&mut self, mut nbytes: usize) {
fn skip_some(&mut self, mut nbytes: usize) {
for reader in &mut self.0 {
let bytes_to_skip = reader.remain().min(nbytes);
reader.skip(bytes_to_skip);
@ -221,11 +215,6 @@ impl MultiRead for VmReaderArray<'_> {
return;
}
}
panic!(
"the readers are exhausted but there are {} bytes remaining to skip",
nbytes
);
}
}
@ -238,16 +227,22 @@ impl MultiRead for VmReader<'_> {
self.remain()
}
fn skip(&mut self, nbytes: usize) {
VmReader::skip(self, nbytes);
fn skip_some(&mut self, nbytes: usize) {
self.skip(self.remain().min(nbytes));
}
}
impl dyn MultiRead + '_ {
pub fn read_val<T: Pod>(&mut self) -> Result<T> {
/// Reads a `T` value, returning a `None` if the readers have insufficient bytes.
pub fn read_val_opt<T: Pod>(&mut self) -> Result<Option<T>> {
let mut val = T::new_zeroed();
self.read(&mut VmWriter::from(val.as_bytes_mut()))?;
Ok(val)
let nbytes = self.read(&mut VmWriter::from(val.as_bytes_mut()))?;
if nbytes == size_of::<T>() {
Ok(Some(val))
} else {
Ok(None)
}
}
}
@ -269,7 +264,7 @@ impl MultiWrite for VmWriterArray<'_> {
self.0.iter().map(|vm_writer| vm_writer.avail()).sum()
}
fn skip(&mut self, mut nbytes: usize) {
fn skip_some(&mut self, mut nbytes: usize) {
for writer in &mut self.0 {
let bytes_to_skip = writer.avail().min(nbytes);
writer.skip(bytes_to_skip);
@ -279,11 +274,6 @@ impl MultiWrite for VmWriterArray<'_> {
return;
}
}
panic!(
"the writers are exhausted but there are {} bytes remaining to skip",
nbytes
);
}
}
@ -296,14 +286,17 @@ impl MultiWrite for VmWriter<'_> {
self.avail()
}
fn skip(&mut self, nbytes: usize) {
VmWriter::skip(self, nbytes);
fn skip_some(&mut self, nbytes: usize) {
self.skip(self.avail().min(nbytes));
}
}
impl dyn MultiWrite + '_ {
pub fn write_val<T: Pod>(&mut self, val: &T) -> Result<()> {
self.write(&mut VmReader::from(val.as_bytes()))?;
/// Writes a `T` value, truncating the value if the writers have insufficient bytes.
pub fn write_val_trunc<T: Pod>(&mut self, val: &T) -> Result<()> {
let _nbytes = self.write(&mut VmReader::from(val.as_bytes()))?;
// `_nbytes` may be smaller than the value size. We ignore it to truncate the value.
Ok(())
}
}

View File

@ -146,18 +146,18 @@ int find_new_addr_until_done(char *buffer, size_t len, int *found_new_addr)
return 0;
}
#define BUFFER_SIZE 8192
char buffer[BUFFER_SIZE];
struct nl_req {
struct nlmsghdr hdr;
struct ifaddrmsg ifa;
};
FN_TEST(get_addr_error)
{
#define BUFFER_SIZE 8192
struct nl_req {
struct nlmsghdr hdr;
struct ifaddrmsg ifa;
};
int sock_fd;
struct sockaddr_nl sa;
char buffer[BUFFER_SIZE];
sock_fd = TEST_SUCC(socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE));
@ -223,3 +223,31 @@ FN_TEST(get_addr_error)
TEST_SUCC(close(sock_fd));
}
END_TEST()
FN_TEST(bufsize_msgsize)
{
int sock_fd;
struct nl_req req;
sock_fd = TEST_SUCC(
socket(AF_NETLINK, SOCK_RAW | SOCK_NONBLOCK, NETLINK_ROUTE));
memset(&req, 0, sizeof(req));
req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg));
req.hdr.nlmsg_type = RTM_GETADDR;
req.hdr.nlmsg_flags = NLM_F_REQUEST;
req.hdr.nlmsg_seq = 1;
req.ifa.ifa_family = AF_UNSPEC;
// Send the request
TEST_RES(send(sock_fd, &req, sizeof(req), 0), _ret == sizeof(req));
// The buffer size is too short, but it still succeeds
TEST_SUCC(recv(sock_fd, buffer, 1, 0));
// The truncated message is now lost
TEST_ERRNO(recv(sock_fd, buffer, BUFFER_SIZE, 0), EAGAIN);
TEST_SUCC(close(sock_fd));
}
END_TEST()