343 lines
12 KiB
Rust
343 lines
12 KiB
Rust
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec};
|
|
use core::{fmt::Debug, hint::spin_loop};
|
|
|
|
use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame};
|
|
use aster_util::{field_ptr, slot_vec::SlotVec};
|
|
use log::debug;
|
|
use pod::Pod;
|
|
|
|
use super::{
|
|
buffer::RxBuffer,
|
|
config::{VirtioVsockConfig, VsockFeatures},
|
|
connect::{ConnectionInfo, VsockEvent},
|
|
error::SocketError,
|
|
header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN},
|
|
VsockDeviceIrqHandler,
|
|
};
|
|
use crate::{
|
|
device::{
|
|
socket::{handle_recv_irq, register_device},
|
|
VirtioDeviceError,
|
|
},
|
|
queue::{QueueError, VirtQueue},
|
|
transport::VirtioTransport,
|
|
};
|
|
|
|
/// Size passed to `VirtQueue::new` for each of the device's virtqueues.
const QUEUE_SIZE: u16 = 64;
/// Index of the virtqueue used to receive packets.
const QUEUE_RECV: u16 = 0;
/// Index of the virtqueue used to send packets.
const QUEUE_SEND: u16 = 1;
/// Index of the virtqueue created for device events.
const QUEUE_EVENT: u16 = 2;

/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than
/// `size_of::<VirtioVsockHdr>()`.
const RX_BUFFER_SIZE: usize = 512;
|
|
|
|
/// Vsock device driver
|
|
pub struct SocketDevice {
|
|
config: VirtioVsockConfig,
|
|
guest_cid: u64,
|
|
|
|
/// Virtqueue to receive packets.
|
|
send_queue: VirtQueue,
|
|
recv_queue: VirtQueue,
|
|
event_queue: VirtQueue,
|
|
|
|
rx_buffers: SlotVec<RxBuffer>,
|
|
transport: Box<dyn VirtioTransport>,
|
|
callbacks: Vec<Box<dyn VsockDeviceIrqHandler>>,
|
|
}
|
|
|
|
impl SocketDevice {
|
|
/// Create a new vsock device
|
|
pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> {
|
|
let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut());
|
|
debug!("virtio_vsock_config = {:?}", virtio_vsock_config);
|
|
let guest_cid = field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low)
|
|
.read()
|
|
.unwrap() as u64
|
|
| (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high)
|
|
.read()
|
|
.unwrap() as u64)
|
|
<< 32;
|
|
|
|
let mut recv_queue = VirtQueue::new(QUEUE_RECV, QUEUE_SIZE, transport.as_mut())
|
|
.expect("createing recv queue fails");
|
|
let send_queue = VirtQueue::new(QUEUE_SEND, QUEUE_SIZE, transport.as_mut())
|
|
.expect("creating send queue fails");
|
|
let event_queue = VirtQueue::new(QUEUE_EVENT, QUEUE_SIZE, transport.as_mut())
|
|
.expect("creating event queue fails");
|
|
|
|
// Allocate and add buffers for the RX queue.
|
|
let mut rx_buffers = SlotVec::new();
|
|
for i in 0..QUEUE_SIZE {
|
|
let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE);
|
|
let token = recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?;
|
|
assert_eq!(i, token);
|
|
assert_eq!(rx_buffers.put(rx_buffer) as u16, i);
|
|
}
|
|
|
|
if recv_queue.should_notify() {
|
|
debug!("notify receive queue");
|
|
recv_queue.notify();
|
|
}
|
|
|
|
let mut device = Self {
|
|
config: virtio_vsock_config.read().unwrap(),
|
|
guest_cid,
|
|
send_queue,
|
|
recv_queue,
|
|
event_queue,
|
|
rx_buffers,
|
|
transport,
|
|
callbacks: Vec::new(),
|
|
};
|
|
|
|
// Interrupt handler if vsock device config space changes
|
|
fn config_space_change(_: &TrapFrame) {
|
|
debug!("vsock device config space change");
|
|
}
|
|
|
|
// Interrupt handler if vsock device receives some packet.
|
|
fn handle_vsock_event(_: &TrapFrame) {
|
|
handle_recv_irq(super::DEVICE_NAME);
|
|
}
|
|
|
|
device
|
|
.transport
|
|
.register_cfg_callback(Box::new(config_space_change))
|
|
.unwrap();
|
|
device
|
|
.transport
|
|
.register_queue_callback(QUEUE_RECV, Box::new(handle_vsock_event), false)
|
|
.unwrap();
|
|
|
|
device.transport.finish_init();
|
|
|
|
register_device(
|
|
super::DEVICE_NAME.to_string(),
|
|
Arc::new(SpinLock::new(device)),
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Return the CID which has been assigned to this guest.
|
|
pub fn guest_cid(&self) -> u64 {
|
|
self.guest_cid
|
|
}
|
|
|
|
/// Send a connection request
|
|
pub fn request(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
let header = VirtioVsockHdr {
|
|
op: VirtioVsockOp::Request as u16,
|
|
..connection_info.new_header(self.guest_cid)
|
|
};
|
|
|
|
self.send_packet_to_tx_queue(&header, &[])
|
|
}
|
|
|
|
/// Send a response to peer, if peer start a sending request
|
|
pub fn response(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
let header = VirtioVsockHdr {
|
|
op: VirtioVsockOp::Response as u16,
|
|
..connection_info.new_header(self.guest_cid)
|
|
};
|
|
self.send_packet_to_tx_queue(&header, &[])
|
|
}
|
|
|
|
/// Send a shutdown request
|
|
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
let header = VirtioVsockHdr {
|
|
op: VirtioVsockOp::Shutdown as u16,
|
|
..connection_info.new_header(self.guest_cid)
|
|
};
|
|
self.send_packet_to_tx_queue(&header, &[])
|
|
}
|
|
|
|
/// Send a reset request to peer
|
|
pub fn reset(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
let header = VirtioVsockHdr {
|
|
op: VirtioVsockOp::Rst as u16,
|
|
..connection_info.new_header(self.guest_cid)
|
|
};
|
|
self.send_packet_to_tx_queue(&header, &[])
|
|
}
|
|
|
|
/// Request the peer to send the credit info to us
|
|
pub fn credit_request(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
let header = VirtioVsockHdr {
|
|
op: VirtioVsockOp::CreditRequest as u16,
|
|
..connection_info.new_header(self.guest_cid)
|
|
};
|
|
self.send_packet_to_tx_queue(&header, &[])
|
|
}
|
|
|
|
/// Tell the peer our credit info
|
|
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
let header = VirtioVsockHdr {
|
|
op: VirtioVsockOp::CreditUpdate as u16,
|
|
..connection_info.new_header(self.guest_cid)
|
|
};
|
|
self.send_packet_to_tx_queue(&header, &[])
|
|
}
|
|
|
|
fn send_packet_to_tx_queue(
|
|
&mut self,
|
|
header: &VirtioVsockHdr,
|
|
buffer: &[u8],
|
|
) -> Result<(), SocketError> {
|
|
let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?;
|
|
|
|
if self.send_queue.should_notify() {
|
|
self.send_queue.notify();
|
|
}
|
|
|
|
// Wait until the buffer is used
|
|
while !self.send_queue.can_pop() {
|
|
spin_loop();
|
|
}
|
|
|
|
self.send_queue.pop_used()?;
|
|
|
|
debug!("buffer in send_packet_to_tx_queue: {:?}", buffer);
|
|
Ok(())
|
|
}
|
|
|
|
fn check_peer_buffer_is_sufficient(
|
|
&mut self,
|
|
connection_info: &mut ConnectionInfo,
|
|
buffer_len: usize,
|
|
) -> Result<(), SocketError> {
|
|
debug!("connection info {:?}", connection_info);
|
|
debug!(
|
|
"peer free from peer: {:?}, buffer len : {:?}",
|
|
connection_info.peer_free(),
|
|
buffer_len
|
|
);
|
|
if connection_info.peer_free() as usize >= buffer_len {
|
|
Ok(())
|
|
} else {
|
|
// Request an update of the cached peer credit, if we haven't already done so, and tell
|
|
// the caller to try again later.
|
|
if !connection_info.has_pending_credit_request {
|
|
self.credit_request(connection_info)?;
|
|
connection_info.has_pending_credit_request = true;
|
|
}
|
|
Err(SocketError::InsufficientBufferSpaceInPeer)
|
|
}
|
|
}
|
|
|
|
/// Sends the buffer to the destination.
|
|
pub fn send(
|
|
&mut self,
|
|
buffer: &[u8],
|
|
connection_info: &mut ConnectionInfo,
|
|
) -> Result<(), SocketError> {
|
|
self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
|
|
|
|
let len = buffer.len() as u32;
|
|
let header = VirtioVsockHdr {
|
|
op: VirtioVsockOp::Rw as u16,
|
|
len,
|
|
..connection_info.new_header(self.guest_cid)
|
|
};
|
|
connection_info.tx_cnt += len;
|
|
self.send_packet_to_tx_queue(&header, buffer)
|
|
}
|
|
|
|
/// Receive bytes from peer, returns the header
|
|
pub fn receive(
|
|
&mut self,
|
|
buffer: &mut [u8],
|
|
// connection_info: &mut ConnectionInfo,
|
|
) -> Result<VirtioVsockHdr, SocketError> {
|
|
let (token, len) = self.recv_queue.pop_used()?;
|
|
debug!(
|
|
"receive packet in rx_queue: token = {}, len = {}",
|
|
token, len
|
|
);
|
|
let mut rx_buffer = self
|
|
.rx_buffers
|
|
.remove(token as usize)
|
|
.ok_or(QueueError::WrongToken)?;
|
|
rx_buffer.set_packet_len(RX_BUFFER_SIZE);
|
|
|
|
let (header, payload) = read_header_and_body(rx_buffer.buf())?;
|
|
// The length written should be equal to len(header)+len(packet)
|
|
assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32);
|
|
debug!("Received packet {:?}. Op {:?}", header, header.op());
|
|
debug!("body is {:?}", payload);
|
|
|
|
assert!(buffer.len() >= payload.len());
|
|
buffer[..payload.len()].copy_from_slice(payload);
|
|
|
|
self.add_rx_buffer(rx_buffer, token)?;
|
|
Ok(header)
|
|
}
|
|
|
|
/// Polls the RX virtqueue for the next event, and calls the given handler function to handle it.
|
|
pub fn poll(
|
|
&mut self,
|
|
handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>, SocketError>,
|
|
) -> Result<Option<VsockEvent>, SocketError> {
|
|
// Return None if there is no pending packet.
|
|
if !self.recv_queue.can_pop() {
|
|
return Ok(None);
|
|
}
|
|
let mut body = RxBuffer::new(RX_BUFFER_SIZE);
|
|
let header = self.receive(body.buf_mut())?;
|
|
|
|
VsockEvent::from_header(&header).and_then(|event| handler(event, body.buf()))
|
|
}
|
|
|
|
/// Add a used rx buffer to recv queue,@index is only to check the correctness
|
|
fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> {
|
|
let token = self.recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?;
|
|
assert_eq!(index, token);
|
|
assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none());
|
|
if self.recv_queue.should_notify() {
|
|
self.recv_queue.notify();
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Negotiate features for the device specified bits 0~23
|
|
pub(crate) fn negotiate_features(features: u64) -> u64 {
|
|
let device_features = VsockFeatures::from_bits_truncate(features);
|
|
let supported_features = VsockFeatures::supported_features();
|
|
let vsock_features = device_features & supported_features;
|
|
debug!("features negotiated: {:?}", vsock_features);
|
|
vsock_features.bits()
|
|
}
|
|
}
|
|
|
|
impl Debug for SocketDevice {
|
|
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
|
f.debug_struct("SocketDevice")
|
|
.field("config", &self.config)
|
|
.field("guest_cid", &self.guest_cid)
|
|
.field("send_queue", &self.send_queue)
|
|
.field("recv_queue", &self.recv_queue)
|
|
.field("event_queue", &self.event_queue)
|
|
.field("transport", &self.transport)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]), SocketError> {
|
|
// Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
|
|
let header = VirtioVsockHdr::from_bytes(&buffer[..VIRTIO_VSOCK_HDR_LEN]);
|
|
let body_length = header.len() as usize;
|
|
|
|
// This could fail if the device returns an unreasonably long body length.
|
|
let data_end = VIRTIO_VSOCK_HDR_LEN
|
|
.checked_add(body_length)
|
|
.ok_or(SocketError::InvalidNumber)?;
|
|
// This could fail if the device returns a body length longer than the buffer we gave it.
|
|
let data = buffer
|
|
.get(VIRTIO_VSOCK_HDR_LEN..data_end)
|
|
.ok_or(SocketError::BufferTooShort)?;
|
|
Ok((header, data))
|
|
}
|