Add mlsdisk as a component

Co-authored-by: Shaowei Song <songshaowei.ssw@antgroup.com>
Qingsong Chen 2024-12-27 11:49:46 +00:00 committed by Tate, Hongliang Tian
parent 6e691d5838
commit 56a137dc56
45 changed files with 13832 additions and 182 deletions

Cargo.lock (generated): 602 lines changed. File diff suppressed because it is too large.


@@ -18,6 +18,7 @@ members = [
     "kernel/comps/network",
     "kernel/comps/softirq",
     "kernel/comps/logger",
+    "kernel/comps/mlsdisk",
     "kernel/comps/time",
     "kernel/comps/virtio",
     "kernel/libs/cpio-decoder",


@ -10,6 +10,7 @@ logger = { name = "aster-logger" }
time = { name = "aster-time" }
framebuffer = { name = "aster-framebuffer" }
network = { name = "aster-network" }
mlsdisk = { name = "aster-mlsdisk" }
[whitelist]
[whitelist.nix.main]


@@ -149,6 +149,7 @@ OSDK_CRATES := \
 kernel/comps/network \
 kernel/comps/softirq \
 kernel/comps/logger \
+kernel/comps/mlsdisk \
 kernel/comps/time \
 kernel/comps/virtio \
 kernel/libs/aster-util \


@ -14,6 +14,7 @@ aster-console = { path = "comps/console" }
aster-framebuffer = { path = "comps/framebuffer" }
aster-softirq = { path = "comps/softirq" }
aster-logger = { path = "comps/logger" }
aster-mlsdisk = { path = "comps/mlsdisk" }
aster-time = { path = "comps/time" }
aster-virtio = { path = "comps/virtio" }
aster-rights = { path = "libs/aster-rights" }


@@ -0,0 +1,22 @@
[package]
name = "aster-mlsdisk"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
inherit-methods-macro = { git = "https://github.com/asterinas/inherit-methods-macro", rev = "98f7e3e" }
ostd-pod = { git = "https://github.com/asterinas/ostd-pod", rev = "c4644be", version = "0.1.1" }
aster-block = { path = "../block" }
ostd = { path = "../../../ostd" }
aes-gcm = { version = "0.9.4", features = ["force-soft"] }
bittle = "0.5.6"
ctr = "0.8.0"
hashbrown = { version = "0.14.3", features = ["serde"] }
lending-iterator = "0.1.7"
log = "0.4"
lru = "0.12.3"
postcard = "1.0.6"
serde = { version = "1.0.192", default-features = false, features = ["alloc", "derive"] }
static_assertions = "1.1.0"


@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MPL-2.0
use core::fmt;
/// The error types used in this crate.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Errno {
/// Transaction aborted.
TxAborted,
/// Not found.
NotFound,
/// Invalid arguments.
InvalidArgs,
/// Out of memory.
OutOfMemory,
/// Out of disk space.
OutOfDisk,
/// IO error.
IoFailed,
/// Permission denied.
PermissionDenied,
/// Unsupported.
Unsupported,
/// OS-specific unknown error.
OsSpecUnknown,
/// Encryption operation failed.
EncryptFailed,
/// Decryption operation failed.
DecryptFailed,
/// MAC (Message Authentication Code) mismatched.
MacMismatched,
/// Not aligned to `BLOCK_SIZE`.
NotBlockSizeAligned,
/// Try lock failed.
TryLockFailed,
}
/// The error used in this crate, with an error type and an optional error message.
#[derive(Clone, Debug)]
pub struct Error {
errno: Errno,
msg: Option<&'static str>,
}
impl Error {
/// Creates a new error with the given error type and no error message.
pub const fn new(errno: Errno) -> Self {
Error { errno, msg: None }
}
/// Creates a new error with the given error type and the error message.
pub const fn with_msg(errno: Errno, msg: &'static str) -> Self {
Error {
errno,
msg: Some(msg),
}
}
/// Returns the error type.
pub fn errno(&self) -> Errno {
self.errno
}
}
impl From<Errno> for Error {
fn from(errno: Errno) -> Self {
Error::new(errno)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl fmt::Display for Errno {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
#[macro_export]
macro_rules! return_errno {
($errno: expr) => {
return core::result::Result::Err($crate::Error::new($errno))
};
}
#[macro_export]
macro_rules! return_errno_with_msg {
($errno: expr, $msg: expr) => {
return core::result::Result::Err($crate::Error::with_msg($errno, $msg))
};
}
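// Illustrative usage sketch (not part of the original commit): a fallible
// function built on `Errno` and the `return_errno_with_msg!` macro above.
#[allow(dead_code)]
fn check_nblocks(nblocks: usize) -> core::result::Result<(), Error> {
    if nblocks == 0 {
        return_errno_with_msg!(Errno::InvalidArgs, "`nblocks` must be non-zero");
    }
    Ok(())
}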


@@ -0,0 +1,243 @@
// SPDX-License-Identifier: MPL-2.0
//! This module provides APIs to represent buffers whose
//! sizes are block aligned. The advantage of using the
//! APIs provided by this module over Rust std's counterparts
//! is to ensure the invariant of block-aligned length
//! at the type level, eliminating the need for runtime checks.
//!
//! There are three main types:
//! * `Buf`: An owned buffer backed by `Pages`, whose length is
//! a multiple of the block size.
//! * `BufRef`: An immutably-borrowed buffer whose length
//! is a multiple of the block size.
//! * `BufMut`: A mutably-borrowed buffer whose length is
//! a multiple of the block size.
//!
//! The basic usage is simple: replace the usage of `Box<[u8]>`
//! with `Buf`, `&[u8]` with `BufRef<'_>`,
//! and `&mut [u8]` with `BufMut<'_>`.
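//!
//! # Examples
//!
//! A minimal sketch (illustrative, not part of the original commit) of
//! wrapping a plain byte slice in the block-aligned buffer types:
//!
//! ```ignore
//! let mut raw = vec![0u8; 2 * BLOCK_SIZE];
//! let r = BufRef::try_from(raw.as_slice()).unwrap();
//! assert_eq!(r.nblocks(), 2);
//! let m = BufMut::try_from(raw.as_mut_slice()).unwrap();
//! assert_eq!(m.nblocks(), 2);
//! ```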
use alloc::vec;
use core::convert::TryFrom;
use lending_iterator::prelude::*;
use super::BLOCK_SIZE;
use crate::prelude::*;
/// An owned buffer whose length is a multiple of the block size.
pub struct Buf(Vec<u8>);
impl Buf {
/// Allocates a specific number of blocks as an in-memory buffer.
pub fn alloc(num_blocks: usize) -> Result<Self> {
if num_blocks == 0 {
return_errno_with_msg!(
InvalidArgs,
"num_blocks must be greater than 0 for allocation"
)
}
let buffer = vec![0; num_blocks * BLOCK_SIZE];
Ok(Self(buffer))
}
/// Returns the number of blocks of the owned buffer.
pub fn nblocks(&self) -> usize {
self.0.len() / BLOCK_SIZE
}
/// Returns an immutable slice of the owned buffer.
pub fn as_slice(&self) -> &[u8] {
self.0.as_slice()
}
/// Returns a mutable slice of the owned buffer.
pub fn as_mut_slice(&mut self) -> &mut [u8] {
self.0.as_mut_slice()
}
/// Converts to an immutably-borrowed buffer `BufRef`.
pub fn as_ref(&self) -> BufRef<'_> {
BufRef(self.as_slice())
}
/// Converts to a mutably-borrowed buffer `BufMut`.
pub fn as_mut(&mut self) -> BufMut<'_> {
BufMut(self.as_mut_slice())
}
}
/// An immutably-borrowed buffer whose length is a multiple of the block size.
#[derive(Clone, Copy)]
pub struct BufRef<'a>(&'a [u8]);
impl BufRef<'_> {
/// Returns an immutable slice of the borrowed buffer.
pub fn as_slice(&self) -> &[u8] {
self.0
}
/// Returns the number of blocks of the borrowed buffer.
pub fn nblocks(&self) -> usize {
self.0.len() / BLOCK_SIZE
}
/// Returns an iterator for immutable buffers of `BLOCK_SIZE`.
pub fn iter(&self) -> BufIter<'_> {
BufIter {
buf: BufRef(self.as_slice()),
offset: 0,
}
}
}
impl<'a> TryFrom<&'a [u8]> for BufRef<'a> {
type Error = crate::error::Error;
fn try_from(buf: &'a [u8]) -> Result<Self> {
if buf.is_empty() {
return_errno_with_msg!(InvalidArgs, "empty buf in `BufRef::try_from`");
}
if buf.len() % BLOCK_SIZE != 0 {
return_errno_with_msg!(
NotBlockSizeAligned,
"buf not block size aligned `BufRef::try_from`"
);
}
let new_self = Self(buf);
Ok(new_self)
}
}
/// A mutably-borrowed buffer whose length is a multiple of the block size.
pub struct BufMut<'a>(&'a mut [u8]);
impl BufMut<'_> {
/// Returns an immutable slice of the borrowed buffer.
pub fn as_slice(&self) -> &[u8] {
self.0
}
/// Returns a mutable slice of the borrowed buffer.
pub fn as_mut_slice(&mut self) -> &mut [u8] {
self.0
}
/// Returns the number of blocks of the borrowed buffer.
pub fn nblocks(&self) -> usize {
self.0.len() / BLOCK_SIZE
}
/// Returns an iterator for immutable buffers of `BLOCK_SIZE`.
pub fn iter(&self) -> BufIter<'_> {
BufIter {
buf: BufRef(self.as_slice()),
offset: 0,
}
}
/// Returns an iterator for mutable buffers of `BLOCK_SIZE`.
pub fn iter_mut(&mut self) -> BufIterMut<'_> {
BufIterMut {
buf: BufMut(self.as_mut_slice()),
offset: 0,
}
}
}
impl<'a> TryFrom<&'a mut [u8]> for BufMut<'a> {
type Error = crate::error::Error;
fn try_from(buf: &'a mut [u8]) -> Result<Self> {
if buf.is_empty() {
return_errno_with_msg!(InvalidArgs, "empty buf in `BufMut::try_from`");
}
if buf.len() % BLOCK_SIZE != 0 {
return_errno_with_msg!(
NotBlockSizeAligned,
"buf not block size aligned `BufMut::try_from`"
);
}
let new_self = Self(buf);
Ok(new_self)
}
}
/// Iterator for immutable buffers of `BLOCK_SIZE`.
pub struct BufIter<'a> {
buf: BufRef<'a>,
offset: usize,
}
impl<'a> Iterator for BufIter<'a> {
type Item = BufRef<'a>;
fn next(&mut self) -> Option<Self::Item> {
if self.offset >= self.buf.0.len() {
return None;
}
let offset = self.offset;
self.offset += BLOCK_SIZE;
BufRef::try_from(&self.buf.0[offset..offset + BLOCK_SIZE]).ok()
}
}
/// Iterator for mutable buffers of `BLOCK_SIZE`.
pub struct BufIterMut<'a> {
buf: BufMut<'a>,
offset: usize,
}
#[gat]
impl LendingIterator for BufIterMut<'_> {
type Item<'next> = BufMut<'next>;
fn next(&mut self) -> Option<Self::Item<'_>> {
if self.offset >= self.buf.0.len() {
return None;
}
let offset = self.offset;
self.offset += BLOCK_SIZE;
BufMut::try_from(&mut self.buf.0[offset..offset + BLOCK_SIZE]).ok()
}
}
#[cfg(test)]
mod tests {
use lending_iterator::LendingIterator;
use super::{Buf, BufMut, BufRef, BLOCK_SIZE};
fn iterate_buf_ref<'a>(buf: BufRef<'a>) {
for block in buf.iter() {
assert_eq!(block.as_slice().len(), BLOCK_SIZE);
assert_eq!(block.nblocks(), 1);
}
}
fn iterate_buf_mut<'a>(mut buf: BufMut<'a>) {
let mut iter_mut = buf.iter_mut();
while let Some(mut block) = iter_mut.next() {
assert_eq!(block.as_mut_slice().len(), BLOCK_SIZE);
assert_eq!(block.nblocks(), 1);
}
}
#[test]
fn buf() {
let mut buf = Buf::alloc(10).unwrap();
assert_eq!(buf.nblocks(), 10);
assert_eq!(buf.as_slice().len(), 10 * BLOCK_SIZE);
iterate_buf_ref(buf.as_ref());
iterate_buf_mut(buf.as_mut());
let mut buf = [0u8; BLOCK_SIZE];
iterate_buf_ref(BufRef::try_from(buf.as_slice()).unwrap());
iterate_buf_mut(BufMut::try_from(buf.as_mut_slice()).unwrap());
}
}


@@ -0,0 +1,133 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicUsize, Ordering};
use inherit_methods_macro::inherit_methods;
use super::{Buf, BufMut, BufRef};
use crate::{os::Mutex, prelude::*};
/// A log of data blocks that can support random reads and append-only
/// writes.
///
/// # Thread safety
///
/// `BlockLog` is a data structure of interior mutability.
/// It is ok to perform I/O on a `BlockLog` concurrently in multiple threads.
/// `BlockLog` promises the serialization of the append operations, i.e.,
/// concurrent appends are carried out as if they are done one by one.
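///
/// # Examples
///
/// An illustrative sketch (not part of the original commit) of a helper
/// generic over any `BlockLog`, appending one block and reading it back:
///
/// ```ignore
/// fn roundtrip<L: BlockLog>(log: &L) -> Result<()> {
///     let mut buf = Buf::alloc(1)?;
///     buf.as_mut_slice().fill(0xAB);
///     let pos = log.append(buf.as_ref())?; // ID of the newly-appended block
///     let mut out = Buf::alloc(1)?;
///     log.read(pos, out.as_mut())?;
///     assert_eq!(out.as_slice(), buf.as_slice());
///     Ok(())
/// }
/// ```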
pub trait BlockLog: Sync + Send {
/// Read one or multiple blocks at a specified position.
fn read(&self, pos: BlockId, buf: BufMut) -> Result<()>;
/// Append one or multiple blocks at the end,
/// returning the ID of the first newly-appended block.
fn append(&self, buf: BufRef) -> Result<BlockId>;
/// Ensure that blocks are persisted to the disk.
fn flush(&self) -> Result<()>;
/// Returns the number of blocks.
fn nblocks(&self) -> usize;
}
macro_rules! impl_blocklog_for {
($typ:ty,$from:tt) => {
#[inherit_methods(from = $from)]
impl<T: BlockLog> BlockLog for $typ {
fn read(&self, pos: BlockId, buf: BufMut) -> Result<()>;
fn append(&self, buf: BufRef) -> Result<BlockId>;
fn flush(&self) -> Result<()>;
fn nblocks(&self) -> usize;
}
};
}
impl_blocklog_for!(&T, "(**self)");
impl_blocklog_for!(&mut T, "(**self)");
impl_blocklog_for!(Box<T>, "(**self)");
impl_blocklog_for!(Arc<T>, "(**self)");
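// Illustrative note (an assumption about `inherit_methods_macro`, not part of
// the original commit): each `impl_blocklog_for!` above expands to a
// delegating impl, roughly like:
//
//     impl<T: BlockLog> BlockLog for Arc<T> {
//         fn read(&self, pos: BlockId, buf: BufMut) -> Result<()> {
//             (**self).read(pos, buf)
//         }
//         // ...and likewise for `append`, `flush`, and `nblocks`.
//     }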
/// An in-memory log that implements `BlockLog`.
pub struct MemLog {
log: Mutex<Buf>,
append_pos: AtomicUsize,
}
impl BlockLog for MemLog {
fn read(&self, pos: BlockId, mut buf: BufMut) -> Result<()> {
let nblocks = buf.nblocks();
if pos + nblocks > self.nblocks() {
return_errno_with_msg!(InvalidArgs, "read range out of bound");
}
let log = self.log.lock();
let read_buf = &log.as_slice()[Self::offset(pos)..Self::offset(pos) + nblocks * BLOCK_SIZE];
buf.as_mut_slice().copy_from_slice(read_buf);
Ok(())
}
fn append(&self, buf: BufRef) -> Result<BlockId> {
let nblocks = buf.nblocks();
let mut log = self.log.lock();
let pos = self.append_pos.load(Ordering::Acquire);
if pos + nblocks > log.nblocks() {
return_errno_with_msg!(InvalidArgs, "append range out of bound");
}
let write_buf =
&mut log.as_mut_slice()[Self::offset(pos)..Self::offset(pos) + nblocks * BLOCK_SIZE];
write_buf.copy_from_slice(buf.as_slice());
self.append_pos.fetch_add(nblocks, Ordering::Release);
Ok(pos)
}
fn flush(&self) -> Result<()> {
Ok(())
}
fn nblocks(&self) -> usize {
self.append_pos.load(Ordering::Acquire)
}
}
impl MemLog {
pub fn create(num_blocks: usize) -> Result<Self> {
Ok(Self {
log: Mutex::new(Buf::alloc(num_blocks)?),
append_pos: AtomicUsize::new(0),
})
}
fn offset(pos: BlockId) -> usize {
pos * BLOCK_SIZE
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mem_log() -> Result<()> {
let total_blocks = 64;
let append_nblocks = 8;
let mem_log = MemLog::create(total_blocks)?;
assert_eq!(mem_log.nblocks(), 0);
let mut append_buf = Buf::alloc(append_nblocks)?;
let content = 5_u8;
append_buf.as_mut_slice().fill(content);
let append_pos = mem_log.append(append_buf.as_ref())?;
assert_eq!(append_pos, 0);
assert_eq!(mem_log.nblocks(), append_nblocks);
mem_log.flush()?;
let mut read_buf = Buf::alloc(1)?;
let read_pos = 7 as BlockId;
mem_log.read(read_pos, read_buf.as_mut())?;
assert_eq!(
read_buf.as_slice(),
&append_buf.as_slice()[read_pos * BLOCK_SIZE..(read_pos + 1) * BLOCK_SIZE]
);
Ok(())
}
}


@@ -0,0 +1,114 @@
// SPDX-License-Identifier: MPL-2.0
use super::{BlockLog, BlockSet, BufMut, BufRef};
use crate::{os::Mutex, prelude::*};
/// `BlockRing<S>` emulates a block log (`BlockLog`) with infinite
/// storage capacity by using a block set (`S: BlockSet`) of finite storage
/// capacity.
///
/// `BlockRing<S>` uses the entire storage space provided by the underlying
/// block set (`S`) for user data, maintaining no extra metadata.
/// Having no metadata, `BlockRing<S>` places three responsibilities on
/// its user:
///
/// 1. Tracking the valid block range for reads.
/// `BlockRing<S>` accepts reads at any position regardless of whether the
/// position refers to a valid block. It blindly redirects the read request to
/// the underlying block set after taking the target position modulo the
/// size of the block set.
///
/// 2. Setting the cursor for appending new blocks.
/// `BlockRing<S>` won't remember the progress of writing blocks after reboot.
/// Thus, after a `BlockRing<S>` is instantiated, the user must specify the
/// append cursor (using the `set_cursor` method) before appending new blocks.
///
/// 3. Avoiding mistakenly overwriting valid data blocks.
/// As the underlying storage is used in a ring-buffer style, old
/// blocks must be overwritten to accommodate new blocks. The user must ensure
/// that the underlying storage is big enough to avoid overwriting any useful
/// data.
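///
/// # Examples
///
/// A usage sketch (illustrative, not part of the original commit; `disk` is
/// any `BlockSet` and `buf` a one-block `Buf`):
///
/// ```ignore
/// let ring = BlockRing::new(disk);
/// ring.set_cursor(0);                   // must be set before appending
/// let pos = ring.append(buf.as_ref())?; // returns the logical position
/// ring.read(pos, buf.as_mut())?;        // wraps around the storage size
/// ```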
pub struct BlockRing<S> {
storage: S,
// The cursor for appending new blocks
cursor: Mutex<Option<BlockId>>,
}
impl<S: BlockSet> BlockRing<S> {
/// Creates a new instance.
pub fn new(storage: S) -> Self {
Self {
storage,
cursor: Mutex::new(None),
}
}
/// Sets the cursor for appending new blocks.
///
/// # Panics
///
/// Calling the `append` method without first setting the append cursor
/// via this `set_cursor` method causes a panic.
pub fn set_cursor(&self, new_cursor: BlockId) {
*self.cursor.lock() = Some(new_cursor);
}
/// Returns a reference to the underlying storage.
pub fn storage(&self) -> &S {
&self.storage
}
}
impl<S: BlockSet> BlockLog for BlockRing<S> {
fn read(&self, pos: BlockId, buf: BufMut) -> Result<()> {
let pos = pos % self.storage.nblocks();
self.storage.read(pos, buf)
}
fn append(&self, buf: BufRef) -> Result<BlockId> {
let cursor = self
.cursor
.lock()
.expect("cursor must be set before appending new blocks");
let pos = cursor % self.storage.nblocks();
let new_cursor = cursor + buf.nblocks();
self.storage.write(pos, buf)?;
self.set_cursor(new_cursor);
Ok(cursor)
}
fn flush(&self) -> Result<()> {
self.storage.flush()
}
fn nblocks(&self) -> usize {
self.cursor.lock().unwrap_or(0)
}
}
#[cfg(test)]
mod tests {
use super::BlockRing;
use crate::layers::bio::{BlockLog, Buf, MemDisk};
#[test]
fn block_ring() {
let num_blocks = 16;
let disk = MemDisk::create(num_blocks).unwrap();
let block_ring = BlockRing::new(disk);
block_ring.set_cursor(num_blocks);
assert_eq!(block_ring.nblocks(), num_blocks);
let mut append_buf = Buf::alloc(1).unwrap();
append_buf.as_mut_slice().fill(1);
let pos = block_ring.append(append_buf.as_ref()).unwrap();
assert_eq!(pos, num_blocks);
assert_eq!(block_ring.nblocks(), num_blocks + 1);
let mut read_buf = Buf::alloc(1).unwrap();
block_ring
.read(pos % num_blocks, read_buf.as_mut())
.unwrap();
assert_eq!(read_buf.as_slice(), append_buf.as_slice());
}
}


@@ -0,0 +1,227 @@
// SPDX-License-Identifier: MPL-2.0
use core::ops::Range;
use inherit_methods_macro::inherit_methods;
use super::{Buf, BufMut, BufRef};
use crate::{error::Errno, os::Mutex, prelude::*};
/// A fixed set of data blocks that can support random reads and writes.
///
/// # Thread safety
///
/// `BlockSet` is a data structure of interior mutability.
/// It is ok to perform I/O on a `BlockSet` concurrently in multiple threads.
/// `BlockSet` promises the atomicity of reading and writing individual blocks.
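///
/// # Examples
///
/// A byte-granular round trip sketch (illustrative, not part of the original
/// commit; `disk` is any `BlockSet` with at least two blocks):
///
/// ```ignore
/// // Write 8 bytes straddling the first block boundary, then read them back.
/// disk.write_slice(BLOCK_SIZE - 4, &[0xFFu8; 8])?;
/// let mut out = [0u8; 8];
/// disk.read_slice(BLOCK_SIZE - 4, &mut out)?;
/// assert_eq!(out, [0xFFu8; 8]);
/// ```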
pub trait BlockSet: Sync + Send {
/// Read one or multiple blocks at a specified position.
fn read(&self, pos: BlockId, buf: BufMut) -> Result<()>;
/// Read a slice of bytes at a specified byte offset.
fn read_slice(&self, offset: usize, buf: &mut [u8]) -> Result<()> {
let start_pos = offset / BLOCK_SIZE;
let end_pos = (offset + buf.len()).div_ceil(BLOCK_SIZE);
if end_pos > self.nblocks() {
return_errno_with_msg!(Errno::InvalidArgs, "read_slice position is out of range");
}
let nblocks = end_pos - start_pos;
let mut blocks = Buf::alloc(nblocks)?;
self.read(start_pos, blocks.as_mut())?;
let offset = offset % BLOCK_SIZE;
buf.copy_from_slice(&blocks.as_slice()[offset..offset + buf.len()]);
Ok(())
}
/// Write one or multiple blocks at a specified position.
fn write(&self, pos: BlockId, buf: BufRef) -> Result<()>;
/// Write a slice of bytes at a specified byte offset.
fn write_slice(&self, offset: usize, buf: &[u8]) -> Result<()> {
let start_pos = offset / BLOCK_SIZE;
let end_pos = (offset + buf.len()).div_ceil(BLOCK_SIZE);
if end_pos > self.nblocks() {
return_errno_with_msg!(Errno::InvalidArgs, "write_slice position is out of range");
}
let nblocks = end_pos - start_pos;
let mut blocks = Buf::alloc(nblocks)?;
// The first block may need to be read partially.
let start_offset = offset % BLOCK_SIZE;
if start_offset != 0 {
let mut start_block = Buf::alloc(1)?;
self.read(start_pos, start_block.as_mut())?;
blocks.as_mut_slice()[..start_offset]
.copy_from_slice(&start_block.as_slice()[..start_offset]);
}
// Copy the input buffer to the write buffer.
let end_offset = start_offset + buf.len();
blocks.as_mut_slice()[start_offset..end_offset].copy_from_slice(buf);
// The last block may need to be read partially.
if end_offset % BLOCK_SIZE != 0 {
let mut end_block = Buf::alloc(1)?;
self.read(end_pos - 1, end_block.as_mut())?;
blocks.as_mut_slice()[end_offset..]
.copy_from_slice(&end_block.as_slice()[end_offset % BLOCK_SIZE..]);
}
// Write blocks.
self.write(start_pos, blocks.as_ref())?;
Ok(())
}
/// Get a subset of the blocks in the block set.
fn subset(&self, range: Range<BlockId>) -> Result<Self>
where
Self: Sized;
/// Ensure that blocks are persisted to the disk.
fn flush(&self) -> Result<()>;
/// Returns the number of blocks.
fn nblocks(&self) -> usize;
}
macro_rules! impl_blockset_for {
($typ:ty,$from:tt,$subset_fn:expr) => {
#[inherit_methods(from = $from)]
impl<T: BlockSet> BlockSet for $typ {
fn read(&self, pos: BlockId, buf: BufMut) -> Result<()>;
fn read_slice(&self, offset: usize, buf: &mut [u8]) -> Result<()>;
fn write(&self, pos: BlockId, buf: BufRef) -> Result<()>;
fn write_slice(&self, offset: usize, buf: &[u8]) -> Result<()>;
fn flush(&self) -> Result<()>;
fn nblocks(&self) -> usize;
fn subset(&self, range: Range<BlockId>) -> Result<Self> {
let closure = $subset_fn;
closure(self, range)
}
}
};
}
impl_blockset_for!(&T, "(**self)", |_this, _range| {
return_errno_with_msg!(Errno::NotFound, "cannot return `Self` by `subset` of `&T`");
});
impl_blockset_for!(&mut T, "(**self)", |_this, _range| {
return_errno_with_msg!(
Errno::NotFound,
"cannot return `Self` by `subset` of `&mut T`"
);
});
impl_blockset_for!(Box<T>, "(**self)", |this: &T, range| {
this.subset(range).map(|v| Box::new(v))
});
impl_blockset_for!(Arc<T>, "(**self)", |this: &Arc<T>, range| {
(**this).subset(range).map(|v| Arc::new(v))
});
/// A disk that implements `BlockSet`.
///
/// The `region` is the accessible subset.
#[derive(Clone)]
pub struct MemDisk {
disk: Arc<Mutex<Buf>>,
region: Range<BlockId>,
}
impl MemDisk {
/// Creates a `MemDisk` with the given number of blocks.
pub fn create(num_blocks: usize) -> Result<Self> {
let blocks = Buf::alloc(num_blocks)?;
Ok(Self {
disk: Arc::new(Mutex::new(blocks)),
region: Range {
start: 0,
end: num_blocks,
},
})
}
}
impl BlockSet for MemDisk {
fn read(&self, pos: BlockId, mut buf: BufMut) -> Result<()> {
if pos + buf.nblocks() > self.region.end {
return_errno_with_msg!(Errno::InvalidArgs, "read position is out of range");
}
let offset = (self.region.start + pos) * BLOCK_SIZE;
let buf_len = buf.as_slice().len();
let disk = self.disk.lock();
buf.as_mut_slice()
.copy_from_slice(&disk.as_slice()[offset..offset + buf_len]);
Ok(())
}
fn write(&self, pos: BlockId, buf: BufRef) -> Result<()> {
if pos + buf.nblocks() > self.region.end {
return_errno_with_msg!(Errno::InvalidArgs, "write position is out of range");
}
let offset = (self.region.start + pos) * BLOCK_SIZE;
let buf_len = buf.as_slice().len();
let mut disk = self.disk.lock();
disk.as_mut_slice()[offset..offset + buf_len].copy_from_slice(buf.as_slice());
Ok(())
}
fn subset(&self, range: Range<BlockId>) -> Result<Self> {
if self.region.start + range.end > self.region.end {
return_errno_with_msg!(Errno::InvalidArgs, "subset is out of range");
}
Ok(MemDisk {
disk: self.disk.clone(),
region: Range {
start: self.region.start + range.start,
end: self.region.start + range.end,
},
})
}
fn flush(&self) -> Result<()> {
Ok(())
}
fn nblocks(&self) -> usize {
self.region.len()
}
}
#[cfg(test)]
mod tests {
use core::ops::Range;
use crate::layers::bio::{BlockSet, Buf, MemDisk};
#[test]
fn mem_disk() {
let num_blocks = 64;
let disk = MemDisk::create(num_blocks).unwrap();
assert_eq!(disk.nblocks(), 64);
let mut buf = Buf::alloc(1).unwrap();
buf.as_mut_slice().fill(1);
disk.write(32, buf.as_ref()).unwrap();
let range = Range { start: 32, end: 64 };
let subset = disk.subset(range).unwrap();
assert_eq!(subset.nblocks(), 32);
buf.as_mut_slice().fill(0);
subset.read(0, buf.as_mut()).unwrap();
assert_eq!(buf.as_ref().as_slice(), [1u8; 4096]);
subset.write_slice(4096 - 4, &[2u8; 8]).unwrap();
let mut buf = [0u8; 16];
subset.read_slice(4096 - 8, &mut buf).unwrap();
assert_eq!(buf, [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0]);
}
}


@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MPL-2.0
//! The layer of untrusted block I/O.
use static_assertions::assert_eq_size;
mod block_buf;
mod block_log;
mod block_ring;
mod block_set;
pub use self::{
block_buf::{Buf, BufMut, BufRef},
block_log::{BlockLog, MemLog},
block_ring::BlockRing,
block_set::{BlockSet, MemDisk},
};
pub type BlockId = usize;
pub const BLOCK_SIZE: usize = 0x1000;
pub const BID_SIZE: usize = core::mem::size_of::<BlockId>();
// This definition of BlockId assumes the target architecture is 64-bit
assert_eq_size!(usize, u64);


@@ -0,0 +1,358 @@
// SPDX-License-Identifier: MPL-2.0
use ostd_pod::Pod;
use super::{Iv, Key, Mac, VersionId};
use crate::{
layers::bio::{BlockSet, Buf, BLOCK_SIZE},
os::{Aead, Mutex},
prelude::*,
};
/// A cryptographically-protected blob of user data.
///
/// `CryptoBlob<B>` allows a variable length of user data to be securely
/// written to and read from a fixed, pre-allocated block set
/// (represented by `B: BlockSet`) on disk. Obviously, the length of user data
/// must be smaller than that of the block set.
///
/// # On-disk format
///
/// The on-disk format of `CryptoBlob` is shown below.
///
/// ```
/// ┌─────────┬─────────┬─────────┬──────────────────────────────┐
/// │VersionId│ MAC │ Length │ Encrypted Payload │
/// │ (8B) │ (16B) │ (8B) │ (Length bytes) │
/// └─────────┴─────────┴─────────┴──────────────────────────────┘
/// ```
///
/// The version ID increments by one each time the `CryptoBlob` is updated.
/// The MAC protects the integrity of the length and the encrypted payload.
///
/// # Security
///
/// To ensure the confidentiality and integrity of user data, `CryptoBlob`
/// takes several measures:
///
/// 1. Each instance of `CryptoBlob` is associated with a randomly-generated,
/// unique encryption key.
/// 2. Each instance of `CryptoBlob` maintains a version ID, which is
/// automatically incremented by one upon each write.
/// 3. The user data written to a `CryptoBlob` is protected with authenticated
/// encryption before being persisted to the disk.
/// The encryption takes the current version ID as the IV and generates a MAC
/// as the output.
/// 4. To read user data from a `CryptoBlob`, the untrusted on-disk data is
/// first decrypted with the encryption key associated with this object and
/// its integrity is validated. Optionally, the user can check the version ID
/// of the decrypted user data and see if the version ID is up-to-date.
///
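/// # Examples
///
/// A create-write-read sketch (illustrative, not part of the original commit;
/// `disk` is any `BlockSet`):
///
/// ```ignore
/// let mut blob = CryptoBlob::create(disk, b"hello")?; // version 0
/// let v1 = blob.write(b"world")?;                     // version 1
/// assert_eq!(v1, 1);
/// let mut buf = [0u8; 16];
/// let n = blob.read(&mut buf)?;                       // authenticated read
/// assert_eq!(&buf[..n], b"world");
/// ```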
pub struct CryptoBlob<B> {
block_set: B,
key: Key,
header: Mutex<Option<Header>>,
}
#[repr(C)]
#[derive(Copy, Clone, Pod)]
struct Header {
version: VersionId,
mac: Mac,
payload_len: usize,
}
impl<B: BlockSet> CryptoBlob<B> {
/// The size of the header of a crypto blob in bytes.
pub const HEADER_NBYTES: usize = core::mem::size_of::<Header>();
/// Opens an existing `CryptoBlob`.
///
/// The capacity of this `CryptoBlob` object is determined by the size
/// of `block_set: B`.
pub fn open(key: Key, block_set: B) -> Self {
Self {
block_set,
key,
header: Mutex::new(None),
}
}
/// Creates a new `CryptoBlob`.
///
/// The encryption key of a `CryptoBlob` is generated randomly so that
/// no two `CryptoBlob` instances shall ever use the same key.
pub fn create(block_set: B, init_data: &[u8]) -> Result<Self> {
let capacity = block_set.nblocks() * BLOCK_SIZE - Self::HEADER_NBYTES;
if init_data.len() > capacity {
return_errno_with_msg!(OutOfDisk, "init_data is too large");
}
let nblocks = (Self::HEADER_NBYTES + init_data.len()).div_ceil(BLOCK_SIZE);
let mut block_buf = Buf::alloc(nblocks)?;
// Encrypt init_data.
let aead = Aead::new();
let key = Key::random();
let version: VersionId = 0;
let mut iv = Iv::new_zeroed();
iv.as_bytes_mut()[..version.as_bytes().len()].copy_from_slice(version.as_bytes());
let output = &mut block_buf.as_mut_slice()
[Self::HEADER_NBYTES..Self::HEADER_NBYTES + init_data.len()];
let mac = aead.encrypt(init_data, &key, &iv, &[], output)?;
// Store header.
let header = Header {
version,
mac,
payload_len: init_data.len(),
};
block_buf.as_mut_slice()[..Self::HEADER_NBYTES].copy_from_slice(header.as_bytes());
// Write to `BlockSet`.
block_set.write(0, block_buf.as_ref())?;
Ok(Self {
block_set,
key,
header: Mutex::new(Some(header)),
})
}
/// Write the buffer to the disk as the latest version of the content of
/// this `CryptoBlob`.
///
/// The size of the buffer must not be greater than the capacity of this
/// `CryptoBlob`.
///
/// Each successful write increments the version ID by one. If
/// there is no valid version ID, an `Error` will be returned.
/// The user can obtain a valid version ID either from a successful call
/// to `read`, or via `recover_from` with another valid `CryptoBlob`.
///
/// # Security
///
/// This content is guaranteed to be confidential as long as the key is not
/// known to an attacker.
pub fn write(&mut self, buf: &[u8]) -> Result<VersionId> {
if buf.len() > self.capacity() {
return_errno_with_msg!(OutOfDisk, "write data is too large");
}
let nblocks = (Self::HEADER_NBYTES + buf.len()).div_ceil(BLOCK_SIZE);
let mut block_buf = Buf::alloc(nblocks)?;
// Encrypt payload.
let aead = Aead::new();
let version = match self.version_id() {
Some(version) => version + 1,
None => return_errno_with_msg!(NotFound, "write with no valid version ID"),
};
let mut iv = Iv::new_zeroed();
iv.as_bytes_mut()[..version.as_bytes().len()].copy_from_slice(version.as_bytes());
let output =
&mut block_buf.as_mut_slice()[Self::HEADER_NBYTES..Self::HEADER_NBYTES + buf.len()];
let mac = aead.encrypt(buf, &self.key, &iv, &[], output)?;
// Store header.
let header = Header {
version,
mac,
payload_len: buf.len(),
};
block_buf.as_mut_slice()[..Self::HEADER_NBYTES].copy_from_slice(header.as_bytes());
// Write to `BlockSet`.
self.block_set.write(0, block_buf.as_ref())?;
*self.header.lock() = Some(header);
Ok(version)
}
/// Read the content of the `CryptoBlob` from the disk into the buffer.
///
/// The given buffer must have a length that is no less than the size of
/// the plaintext content of this `CryptoBlob`.
///
/// # Security
///
/// This content, including its length, is guaranteed to be authentic.
pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
let header = match *self.header.lock() {
Some(header) => header,
None => {
let mut header = Header::new_zeroed();
self.block_set.read_slice(0, header.as_bytes_mut())?;
header
}
};
if header.payload_len > self.capacity() {
return_errno_with_msg!(OutOfDisk, "payload_len is greater than the capacity");
}
if header.payload_len > buf.len() {
return_errno_with_msg!(OutOfDisk, "read_buf is too small");
}
let nblock = (Self::HEADER_NBYTES + header.payload_len).div_ceil(BLOCK_SIZE);
let mut block_buf = Buf::alloc(nblock)?;
self.block_set.read(0, block_buf.as_mut())?;
// Decrypt payload.
let aead = Aead::new();
let version = header.version;
let mut iv = Iv::new_zeroed();
iv.as_bytes_mut()[..version.as_bytes().len()].copy_from_slice(version.as_bytes());
let input =
&block_buf.as_slice()[Self::HEADER_NBYTES..Self::HEADER_NBYTES + header.payload_len];
let output = &mut buf[..header.payload_len];
aead.decrypt(input, &self.key, &iv, &[], &header.mac, output)?;
*self.header.lock() = Some(header);
Ok(header.payload_len)
}
/// Returns the key associated with this `CryptoBlob`.
pub fn key(&self) -> &Key {
&self.key
}
/// Returns the current version ID.
///
/// # Security
///
/// It is valid after a successful call to `create`, `read` or `write`.
/// The user can also get a version ID from another valid `CryptoBlob`
/// (usually a backup) through the `recover_from` method.
pub fn version_id(&self) -> Option<VersionId> {
self.header.lock().map(|header| header.version)
}
/// Recover from another `CryptoBlob`.
///
/// If this `CryptoBlob` doesn't have a valid version ID, e.g., because payload
/// decryption failed during `read`, the user can call this method to recover
/// the version ID and payload from another `CryptoBlob` (usually a backup).
pub fn recover_from(&mut self, other: &CryptoBlob<B>) -> Result<()> {
if self.capacity() != other.capacity() {
return_errno_with_msg!(InvalidArgs, "capacity not aligned, recover failed");
}
if self.header.lock().is_some() {
return_errno_with_msg!(InvalidArgs, "no need to recover");
}
let nblocks = self.block_set.nblocks();
// Read version ID and payload from another `CryptoBlob`.
let mut read_buf = Buf::alloc(nblocks)?;
let payload_len = other.read(read_buf.as_mut_slice())?;
let version = other.version_id().unwrap();
// Encrypt payload.
let aead = Aead::new();
let mut iv = Iv::new_zeroed();
iv.as_bytes_mut()[..version.as_bytes().len()].copy_from_slice(version.as_bytes());
let input = &read_buf.as_slice()[..payload_len];
let mut write_buf = Buf::alloc(nblocks)?;
let output =
&mut write_buf.as_mut_slice()[Self::HEADER_NBYTES..Self::HEADER_NBYTES + payload_len];
let mac = aead.encrypt(input, self.key(), &iv, &[], output)?;
// Store header.
let header = Header {
version,
mac,
payload_len,
};
write_buf.as_mut_slice()[..Self::HEADER_NBYTES].copy_from_slice(header.as_bytes());
// Write to `BlockSet`.
self.block_set.write(0, write_buf.as_ref())?;
*self.header.lock() = Some(header);
Ok(())
}
/// Returns the current MAC of the encrypted payload.
///
/// # Security
///
/// It is valid after a successful call to `create`, `read` or `write`.
pub fn current_mac(&self) -> Option<Mac> {
self.header.lock().map(|header| header.mac)
}
/// Returns the capacity of this `CryptoBlob` in bytes.
pub fn capacity(&self) -> usize {
self.block_set.nblocks() * BLOCK_SIZE - Self::HEADER_NBYTES
}
/// Returns the number of blocks occupied by the underlying `BlockSet`.
pub fn nblocks(&self) -> usize {
self.block_set.nblocks()
}
}
#[cfg(test)]
mod tests {
use super::CryptoBlob;
use crate::layers::bio::{BlockSet, MemDisk, BLOCK_SIZE};
#[test]
fn create() {
let disk = MemDisk::create(2).unwrap();
let init_data = [1u8; BLOCK_SIZE];
let blob = CryptoBlob::create(disk, &init_data).unwrap();
println!("blob key: {:?}", blob.key());
assert_eq!(blob.version_id(), Some(0));
assert_eq!(blob.nblocks(), 2);
assert_eq!(
blob.capacity(),
2 * BLOCK_SIZE - CryptoBlob::<MemDisk>::HEADER_NBYTES
);
}
#[test]
fn open_and_read() {
let disk = MemDisk::create(4).unwrap();
let key = {
let subset = disk.subset(0..2).unwrap();
let init_data = [1u8; 1024];
let blob = CryptoBlob::create(subset, &init_data).unwrap();
blob.key
};
let subset = disk.subset(0..2).unwrap();
let blob = CryptoBlob::open(key, subset);
assert_eq!(blob.version_id(), None);
assert_eq!(blob.nblocks(), 2);
let mut buf = [0u8; BLOCK_SIZE];
let payload_len = blob.read(&mut buf).unwrap();
assert_eq!(buf[..payload_len], [1u8; 1024]);
}
#[test]
fn write() {
let disk = MemDisk::create(2).unwrap();
let init_data = [0u8; BLOCK_SIZE];
let mut blob = CryptoBlob::create(disk, &init_data).unwrap();
let write_buf = [1u8; 1024];
blob.write(&write_buf).unwrap();
let mut read_buf = [0u8; 1024];
blob.read(&mut read_buf).unwrap();
assert_eq!(read_buf, [1u8; 1024]);
assert_eq!(blob.version_id(), Some(1));
}
#[test]
fn recover_from() {
let disk = MemDisk::create(2).unwrap();
let init_data = [1u8; 1024];
let subset0 = disk.subset(0..1).unwrap();
let mut blob0 = CryptoBlob::create(subset0, &init_data).unwrap();
assert_eq!(blob0.version_id(), Some(0));
blob0.write(&init_data).unwrap();
assert_eq!(blob0.version_id(), Some(1));
let subset1 = disk.subset(1..2).unwrap();
let mut blob1 = CryptoBlob::open(blob0.key, subset1);
assert_eq!(blob1.version_id(), None);
blob1.recover_from(&blob0).unwrap();
let mut read_buf = [0u8; BLOCK_SIZE];
let payload_len = blob1.read(&mut read_buf).unwrap();
assert_eq!(read_buf[..payload_len], [1u8; 1024]);
assert_eq!(blob1.version_id(), Some(1));
}
}


@@ -0,0 +1,401 @@
// SPDX-License-Identifier: MPL-2.0
use core::ops::Range;
use lending_iterator::prelude::*;
use ostd_pod::Pod;
use super::{Iv, Key, Mac};
use crate::{
layers::bio::{BlockId, BlockLog, Buf, BLOCK_SIZE},
os::Aead,
prelude::*,
};
/// A cryptographically-protected chain of blocks.
///
/// `CryptoChain<L>` allows writing and reading a sequence of
/// consecutive blocks securely to and from an untrusted storage of data log
/// `L: BlockLog`.
/// The target use case of `CryptoChain` is to implement secure journals,
/// where old data are scanned and new data are appended.
///
/// # On-disk format
///
/// The on-disk format of each block is shown below.
///
/// ```text
/// ┌─────────────────────┬───────┬──────────┬──────────┬──────────┬─────────┐
/// │ Encrypted payload │ Gap │ Length │ PreMac │ CurrMac │ IV │
/// │(Length <= 4KB - 48B)│ │ (4B) │ (16B) │ (16B) │ (12B) │
/// └─────────────────────┴───────┴──────────┴──────────┴──────────┴─────────┘
///
/// ◄─────────────────────────── Block size (4KB) ──────────────────────────►
/// ```
///
/// Each block begins with the encrypted user payload. The size of the payload
/// must be smaller than the block size, as each block ends with a footer
/// (in plaintext).
/// The footer consists of four parts: the length of the payload (in bytes),
/// the MAC of the previous block, the MAC of the current block, and the IV
/// used for encrypting the current block.
/// The MAC of a block protects the encrypted payload, its length, and the MAC
/// of the previous block.
///
/// # Security
///
/// Each `CryptoChain` is assigned a randomly-generated encryption key.
/// Each block is encrypted using this key and a randomly-generated IV.
/// This setup ensures the confidentiality of the payload, so that even
/// identical payloads result in different ciphertexts.
///
/// `CryptoChain` is called a "chain" of blocks because each block
/// not only stores its own MAC, but also the MAC of its previous block.
/// This effectively forms a "chain" (much like a blockchain),
/// ensuring the order and consecutiveness of the sequence of blocks.
///
/// Due to this chain structure, the integrity of a `CryptoChain` can be ensured
/// by verifying the MAC of the last block. Once the integrity of the last block
/// is verified, the integrity of all previous blocks can also be verified.
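///
/// # Examples
///
/// An append-then-recover sketch (illustrative, not part of the original
/// commit; `log` is any `BlockLog` and `key` its saved encryption key):
///
/// ```ignore
/// let mut chain = CryptoChain::new(log);
/// chain.append(b"first")?;
/// chain.append(b"second")?;
/// chain.flush()?;
///
/// // Later, with a fresh handle to the same log:
/// let mut recovery = CryptoChain::recover(key, log, 0);
/// while let Some(payload) = recovery.next() { /* consume payload */ }
/// let chain = recovery.open();
/// ```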
pub struct CryptoChain<L> {
block_log: L,
key: Key,
block_range: Range<BlockId>,
block_macs: Vec<Mac>,
}
#[repr(C)]
#[derive(Copy, Clone, Pod)]
struct Footer {
len: u32,
pre_mac: Mac,
this_mac: Mac,
this_iv: Iv,
}
impl<L: BlockLog> CryptoChain<L> {
/// The available size in each chained block is smaller than
/// the block size, as each block ends with a footer.
pub const AVAIL_BLOCK_SIZE: usize = BLOCK_SIZE - core::mem::size_of::<Footer>();
/// Construct a new `CryptoChain` using `block_log: L` as the storage.
pub fn new(block_log: L) -> Self {
Self {
block_log,
block_range: 0..0,
key: Key::random(),
block_macs: Vec::new(),
}
}
/// Recover an existing `CryptoChain` backed by `block_log: L`,
/// starting from its `from` block.
pub fn recover(key: Key, block_log: L, from: BlockId) -> Recovery<L> {
Recovery::new(block_log, key, from)
}
/// Read a block at a specified position.
///
/// The length of the given buffer should not be smaller than the
/// `payload_len` stored in `Footer`.
///
/// # Security
///
/// The authenticity of the block is guaranteed.
pub fn read(&self, pos: BlockId, buf: &mut [u8]) -> Result<usize> {
if !self.block_range().contains(&pos) {
return_errno_with_msg!(NotFound, "read position is out of range");
}
// Read block and get footer.
let mut block_buf = Buf::alloc(1)?;
self.block_log.read(pos, block_buf.as_mut())?;
let footer: Footer = Pod::from_bytes(&block_buf.as_slice()[Self::AVAIL_BLOCK_SIZE..]);
let payload_len = footer.len as usize;
if payload_len > Self::AVAIL_BLOCK_SIZE || payload_len > buf.len() {
return_errno_with_msg!(OutOfDisk, "wrong payload_len or the read_buf is too small");
}
// Check the footer MAC, to ensure the order and consecutiveness of blocks.
let this_mac = self.block_macs.get(pos - self.block_range.start).unwrap();
if footer.this_mac.as_bytes() != this_mac.as_bytes() {
return_errno_with_msg!(NotFound, "check footer MAC failed");
}
// Decrypt payload.
let aead = Aead::new();
aead.decrypt(
&block_buf.as_slice()[..payload_len],
self.key(),
&footer.this_iv,
&footer.pre_mac,
&footer.this_mac,
&mut buf[..payload_len],
)?;
Ok(payload_len)
}
/// Append a block at the end.
///
/// The length of the given buffer must not be larger than `AVAIL_BLOCK_SIZE`.
///
/// # Security
///
/// The confidentiality of the block is guaranteed.
pub fn append(&mut self, buf: &[u8]) -> Result<()> {
if buf.len() > Self::AVAIL_BLOCK_SIZE {
return_errno_with_msg!(OutOfDisk, "append data is too large");
}
let mut block_buf = Buf::alloc(1)?;
// Encrypt payload.
let aead = Aead::new();
let this_iv = Iv::random();
let pre_mac = self.block_macs.last().copied().unwrap_or_default();
let output = &mut block_buf.as_mut_slice()[..buf.len()];
let this_mac = aead.encrypt(buf, self.key(), &this_iv, &pre_mac, output)?;
// Store footer.
let footer = Footer {
len: buf.len() as _,
pre_mac,
this_mac,
this_iv,
};
let buf = &mut block_buf.as_mut_slice()[Self::AVAIL_BLOCK_SIZE..];
buf.copy_from_slice(footer.as_bytes());
self.block_log.append(block_buf.as_ref())?;
self.block_range.end += 1;
self.block_macs.push(this_mac);
Ok(())
}
/// Ensures the persistence of data.
pub fn flush(&self) -> Result<()> {
self.block_log.flush()
}
/// Trim the blocks before a specified position (exclusive).
///
/// The purpose of this method is to free some memory used for keeping the
/// MACs of accessible blocks. After trimming, the range of accessible
/// blocks is shrunk accordingly.
pub fn trim(&mut self, before_block: BlockId) {
// We must ensure the invariant that there is at least one valid block
// after trimming.
debug_assert!(before_block < self.block_range.end);
if before_block <= self.block_range.start {
return;
}
let num_blocks_trimmed = before_block - self.block_range.start;
self.block_range.start = before_block;
self.block_macs.drain(..num_blocks_trimmed);
}
/// Returns the range of blocks that are accessible through the `CryptoChain`.
pub fn block_range(&self) -> &Range<BlockId> {
&self.block_range
}
/// Returns the underlying block log.
pub fn inner_log(&self) -> &L {
&self.block_log
}
/// Returns the encryption key of the `CryptoChain`.
pub fn key(&self) -> &Key {
&self.key
}
}
/// `Recovery<L>` represents an instance of `CryptoChain<L>` being recovered.
///
/// An object `Recovery<L>` attempts to recover as many valid blocks of
/// a `CryptoChain` as possible. A block is valid if and only if its real MAC
/// is equal to the MAC value recorded in its successor.
///
/// For the last block, which does not have a successor block, the user
/// can obtain its MAC from `Recovery<L>` and verify the MAC by comparing it
/// with an expected value from another trusted source.
pub struct Recovery<L> {
block_log: L,
key: Key,
block_range: Range<BlockId>,
block_macs: Vec<Mac>,
read_buf: Buf,
payload: Buf,
}
impl<L: BlockLog> Recovery<L> {
/// Construct a new `Recovery` from the `first_block` of
/// `block_log: L`, using a cryptographic `key`.
pub fn new(block_log: L, key: Key, first_block: BlockId) -> Self {
Self {
block_log,
key,
block_range: first_block..first_block,
block_macs: Vec::new(),
read_buf: Buf::alloc(1).unwrap(),
payload: Buf::alloc(1).unwrap(),
}
}
/// Returns the number of valid blocks.
///
/// Each successful call to `next` increments the number of valid blocks.
pub fn num_blocks(&self) -> usize {
self.block_range.len()
}
/// Returns the range of valid blocks.
///
/// Each successful call to `next` increments the upper bound by one.
pub fn block_range(&self) -> &Range<BlockId> {
&self.block_range
}
/// Returns the MACs of valid blocks.
///
/// Each successful call to `next` pushes the MAC of the new valid block.
pub fn block_macs(&self) -> &[Mac] {
&self.block_macs
}
/// Opens a `CryptoChain<L>` from the recovery object.
///
/// The user should call `next` to retrieve as many valid blocks as
/// possible before calling this method.
pub fn open(self) -> CryptoChain<L> {
CryptoChain {
block_log: self.block_log,
key: self.key,
block_range: self.block_range,
block_macs: self.block_macs,
}
}
}
#[gat]
impl<L: BlockLog> LendingIterator for Recovery<L> {
type Item<'a> = &'a [u8];
fn next(&mut self) -> Option<Self::Item<'_>> {
let next_block_id = self.block_range.end;
self.block_log
.read(next_block_id, self.read_buf.as_mut())
.ok()?;
// Deserialize footer.
let footer: Footer =
Pod::from_bytes(&self.read_buf.as_slice()[CryptoChain::<L>::AVAIL_BLOCK_SIZE..]);
let payload_len = footer.len as usize;
if payload_len > CryptoChain::<L>::AVAIL_BLOCK_SIZE {
return None;
}
// Decrypt payload.
let aead = Aead::new();
aead.decrypt(
&self.read_buf.as_slice()[..payload_len],
&self.key,
&footer.this_iv,
&footer.pre_mac,
&footer.this_mac,
&mut self.payload.as_mut_slice()[..payload_len],
)
.ok()?;
// Crypto blocks are chained: each block stores not only
// the MAC of its own, but also the MAC of its previous block.
// So we need to check whether the two MAC values are the same.
// There is one exception that the `pre_mac` of the first block
// is NOT checked.
if self
.block_macs()
.last()
.is_some_and(|mac| mac.as_bytes() != footer.pre_mac.as_bytes())
{
return None;
}
self.block_range.end += 1;
self.block_macs.push(footer.this_mac);
Some(&self.payload.as_slice()[..payload_len])
}
}
#[cfg(test)]
mod tests {
use lending_iterator::LendingIterator;
use super::CryptoChain;
use crate::layers::bio::{BlockLog, BlockRing, BlockSet, MemDisk};
#[test]
fn new() {
let disk = MemDisk::create(16).unwrap();
let block_ring = BlockRing::new(disk);
block_ring.set_cursor(0);
let chain = CryptoChain::new(block_ring);
assert_eq!(chain.block_log.nblocks(), 0);
assert_eq!(chain.block_range.start, 0);
assert_eq!(chain.block_range.end, 0);
assert_eq!(chain.block_macs.len(), 0);
}
#[test]
fn append_trim_and_read() {
let disk = MemDisk::create(16).unwrap();
let block_ring = BlockRing::new(disk);
block_ring.set_cursor(0);
let mut chain = CryptoChain::new(block_ring);
let data = [1u8; 1024];
chain.append(&data[..256]).unwrap();
chain.append(&data[..512]).unwrap();
assert_eq!(chain.block_range.end, 2);
assert_eq!(chain.block_macs.len(), 2);
chain.trim(1);
assert_eq!(chain.block_range.start, 1);
assert_eq!(chain.block_range.end, 2);
assert_eq!(chain.block_macs.len(), 1);
let mut buf = [0u8; 1024];
let len = chain.read(1, &mut buf).unwrap();
assert_eq!(len, 512);
assert_eq!(buf[..512], [1u8; 512]);
}
#[test]
fn recover() {
let disk = MemDisk::create(16).unwrap();
let key = {
let sub_disk = disk.subset(0..8).unwrap();
let block_ring = BlockRing::new(sub_disk);
block_ring.set_cursor(0);
let data = [1u8; 1024];
let mut chain = CryptoChain::new(block_ring);
for _ in 0..4 {
chain.append(&data).unwrap();
}
chain.flush().unwrap();
chain.key
};
let sub_disk = disk.subset(0..8).unwrap();
let block_ring = BlockRing::new(sub_disk);
let mut recover = CryptoChain::recover(key, block_ring, 2);
while let Some(payload) = recover.next() {
assert_eq!(payload.len(), 1024);
}
let chain = recover.open();
assert_eq!(chain.block_range(), &(2..4));
assert_eq!(chain.block_macs.len(), 2);
}
}

File diff suppressed because it is too large.


@@ -0,0 +1,18 @@
// SPDX-License-Identifier: MPL-2.0
//! The layer of cryptographic constructs.
mod crypto_blob;
mod crypto_chain;
mod crypto_log;
pub use self::{
crypto_blob::CryptoBlob,
crypto_chain::CryptoChain,
crypto_log::{CryptoLog, NodeCache, RootMhtMeta},
};
pub type Key = crate::os::AeadKey;
pub type Iv = crate::os::AeadIv;
pub type Mac = crate::os::AeadMac;
pub type VersionId = u64;


@@ -0,0 +1,154 @@
// SPDX-License-Identifier: MPL-2.0
use core::marker::PhantomData;
use serde::{ser::SerializeSeq, Deserialize, Serialize};
use crate::prelude::*;
/// An `Edit<S>` is an incremental change to a state of type `S`.
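///
/// # Examples
///
/// A minimal sketch (illustrative, not part of the original commit) of an
/// edit over a simple counter state:
///
/// ```ignore
/// #[derive(Serialize, Deserialize)]
/// struct AddEdit { delta: i64 }
///
/// impl Edit<i64> for AddEdit {
///     fn apply_to(&self, state: &mut i64) {
///         *state += self.delta;
///     }
/// }
/// ```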
pub trait Edit<S>: Serialize + for<'de> Deserialize<'de> {
/// Apply this edit to a state.
fn apply_to(&self, state: &mut S);
}
/// A group of edits to a state.
pub struct EditGroup<E: Edit<S>, S> {
edits: Vec<E>,
_s: PhantomData<S>,
}
impl<E: Edit<S>, S> EditGroup<E, S> {
/// Creates an empty edit group.
pub fn new() -> Self {
Self {
edits: Vec::new(),
_s: PhantomData,
}
}
/// Adds an edit to the group.
pub fn push(&mut self, edit: E) {
self.edits.push(edit);
}
/// Returns an iterator over the contained edits.
pub fn iter(&self) -> impl Iterator<Item = &E> {
self.edits.iter()
}
/// Clears the edit group by removing all contained edits.
pub fn clear(&mut self) {
self.edits.clear()
}
/// Returns whether the edit group contains no edits.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the length of the edit group.
pub fn len(&self) -> usize {
self.edits.len()
}
}
impl<E: Edit<S>, S> Edit<S> for EditGroup<E, S> {
fn apply_to(&self, state: &mut S) {
for edit in &self.edits {
edit.apply_to(state);
}
}
}
impl<E: Edit<S>, S> Serialize for EditGroup<E, S> {
fn serialize<Se>(&self, serializer: Se) -> core::result::Result<Se::Ok, Se::Error>
where
Se: serde::Serializer,
{
let mut seq = serializer.serialize_seq(Some(self.len()))?;
for edit in &self.edits {
seq.serialize_element(edit)?
}
seq.end()
}
}
impl<'de, E: Edit<S>, S> Deserialize<'de> for EditGroup<E, S> {
fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct EditsVisitor<E: Edit<S>, S> {
_p: PhantomData<(E, S)>,
}
impl<'a, E: Edit<S>, S> serde::de::Visitor<'a> for EditsVisitor<E, S> {
type Value = EditGroup<E, S>;
fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
formatter.write_str("an edit group")
}
fn visit_seq<A>(self, mut seq: A) -> core::result::Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'a>,
{
let mut edits = Vec::with_capacity(seq.size_hint().unwrap_or(0));
while let Some(e) = seq.next_element()? {
edits.push(e);
}
Ok(EditGroup {
edits,
_s: PhantomData,
})
}
}
deserializer.deserialize_seq(EditsVisitor { _p: PhantomData })
}
}
#[cfg(test)]
mod tests {
use serde::{Deserialize, Serialize};
use super::*;
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct XEdit {
x: i32,
}
struct XState {
sum: i32,
}
impl Edit<XState> for XEdit {
fn apply_to(&self, state: &mut XState) {
state.sum += self.x;
}
}
#[test]
fn serde_edit() {
let mut group = EditGroup::<XEdit, XState>::new();
let mut sum = 0;
for x in 0..10 {
sum += x;
let edit = XEdit { x };
group.push(edit);
}
let mut state = XState { sum: 0 };
group.apply_to(&mut state);
assert_eq!(state.sum, sum);
let mut buf = [0u8; 64];
let ser = postcard::to_slice(&group, buf.as_mut_slice()).unwrap();
println!("serialize len: {} data: {:?}", ser.len(), ser);
let de: EditGroup<XEdit, XState> = postcard::from_bytes(buf.as_slice()).unwrap();
println!("deserialize edits: {:?}", de.edits);
assert_eq!(de.len(), group.len());
assert_eq!(de.edits.as_slice(), group.edits.as_slice());
}
}

File diff suppressed because it is too large.


@@ -0,0 +1,13 @@
// SPDX-License-Identifier: MPL-2.0
//! The layer of edit journal.
mod edits;
mod journal;
pub use self::{
edits::{Edit, EditGroup},
journal::{
CompactPolicy, DefaultCompactPolicy, EditJournal, EditJournalMeta, NeverCompactPolicy,
},
};


@@ -0,0 +1,480 @@
// SPDX-License-Identifier: MPL-2.0
//! Chunk-based storage management.
//!
//! A chunk is a group of consecutive blocks.
//! As the size of a chunk is much greater than that of a block,
//! the number of chunks is naturally far smaller than that of blocks.
//! This makes it possible to keep all metadata for chunks in memory.
//! Thus, managing chunks is more efficient than managing blocks.
//!
//! The primary API provided by this module is the chunk allocator,
//! `ChunkAlloc`, which tracks whether chunks are free or not.
//!
//! # Examples
//!
//! Chunk allocators are used within transactions.
//!
//! ```
//! fn alloc_chunks(chunk_alloc: &ChunkAlloc, num_chunks: usize) -> Option<Vec<ChunkId>> {
//! let mut tx = chunk_alloc.new_tx();
//! let res: Option<Vec<ChunkId>> = tx.context(|| {
//! let mut chunk_ids = Vec::new();
//! for _ in 0..num_chunks {
//! chunk_ids.push(chunk_alloc.alloc()?);
//! }
//! Some(chunk_ids)
//! });
//! if res.is_some() {
//! tx.commit().ok()?;
//! } else {
//! tx.abort();
//! }
//! res
//! }
//! ```
//!
//! The above example showcases the power of transaction atomicity:
//! if anything goes wrong (e.g., allocation failures) during the transaction,
//! then the transaction can be aborted and all changes made to `chunk_alloc`
//! during the transaction will be rolled back automatically.
use serde::{Deserialize, Serialize};
use crate::{
layers::edit::Edit,
os::{HashMap, Mutex},
prelude::*,
tx::{CurrentTx, TxData, TxProvider},
util::BitMap,
};
/// The ID of a chunk.
pub type ChunkId = usize;
/// Number of blocks of a chunk.
pub const CHUNK_NBLOCKS: usize = 1024;
/// The chunk size is a multiple of the block size.
pub const CHUNK_SIZE: usize = CHUNK_NBLOCKS * BLOCK_SIZE;
/// A chunk allocator tracks which chunks are free.
#[derive(Clone)]
pub struct ChunkAlloc {
state: Arc<Mutex<ChunkAllocState>>,
tx_provider: Arc<TxProvider>,
}
impl ChunkAlloc {
/// Creates a chunk allocator that manages a specified number of
/// chunks (`capacity`). Initially, all chunks are free.
pub fn new(capacity: usize, tx_provider: Arc<TxProvider>) -> Self {
let state = ChunkAllocState::new(capacity);
Self::from_parts(state, tx_provider)
}
/// Constructs a `ChunkAlloc` from its parts.
pub(super) fn from_parts(mut state: ChunkAllocState, tx_provider: Arc<TxProvider>) -> Self {
state.in_journal = false;
let new_self = Self {
state: Arc::new(Mutex::new(state)),
tx_provider,
};
// TX data
new_self
.tx_provider
.register_data_initializer(Box::new(ChunkAllocEdit::new));
// Commit handler
new_self.tx_provider.register_commit_handler({
let state = new_self.state.clone();
move |current: CurrentTx<'_>| {
let state = state.clone();
current.data_with(move |edit: &ChunkAllocEdit| {
if edit.edit_table.is_empty() {
return;
}
let mut state = state.lock();
edit.apply_to(&mut state);
});
}
});
// Abort handler
new_self.tx_provider.register_abort_handler({
let state = new_self.state.clone();
move |current: CurrentTx<'_>| {
let state = state.clone();
current.data_with(move |edit: &ChunkAllocEdit| {
let mut state = state.lock();
for chunk_id in edit.iter_allocated_chunks() {
state.dealloc(chunk_id);
}
});
}
});
new_self
}
/// Creates a new transaction for the chunk allocator.
pub fn new_tx(&self) -> CurrentTx<'_> {
self.tx_provider.new_tx()
}
/// Allocates a chunk, returning its ID.
pub fn alloc(&self) -> Option<ChunkId> {
let chunk_id = {
let mut state = self.state.lock();
state.alloc()? // Update global state immediately
};
let mut current_tx = self.tx_provider.current();
current_tx.data_mut_with(|edit: &mut ChunkAllocEdit| {
edit.alloc(chunk_id);
});
Some(chunk_id)
}
/// Allocates `count` chunks, returning the IDs of the newly-allocated
/// chunks, or `None` if any allocation fails.
pub fn alloc_batch(&self, count: usize) -> Option<Vec<ChunkId>> {
let chunk_ids = {
let mut ids = Vec::with_capacity(count);
let mut state = self.state.lock();
for _ in 0..count {
match state.alloc() {
Some(id) => ids.push(id),
None => {
ids.iter().for_each(|id| state.dealloc(*id));
return None;
}
}
}
ids.sort_unstable();
ids
};
let mut current_tx = self.tx_provider.current();
current_tx.data_mut_with(|edit: &mut ChunkAllocEdit| {
for chunk_id in &chunk_ids {
edit.alloc(*chunk_id);
}
});
Some(chunk_ids)
}
/// Deallocates the chunk of a given ID.
///
/// # Panics
///
/// Deallocating a free chunk causes a panic.
pub fn dealloc(&self, chunk_id: ChunkId) {
let mut current_tx = self.tx_provider.current();
current_tx.data_mut_with(|edit: &mut ChunkAllocEdit| {
let should_dealloc_now = edit.dealloc(chunk_id);
if should_dealloc_now {
let mut state = self.state.lock();
state.dealloc(chunk_id);
}
});
}
/// Deallocates the set of chunks with the given IDs.
///
/// # Panics
///
/// Deallocating a free chunk causes a panic.
pub fn dealloc_batch<I>(&self, chunk_ids: I)
where
I: Iterator<Item = ChunkId>,
{
let mut current_tx = self.tx_provider.current();
current_tx.data_mut_with(|edit: &mut ChunkAllocEdit| {
let mut state = self.state.lock();
for chunk_id in chunk_ids {
let should_dealloc_now = edit.dealloc(chunk_id);
if should_dealloc_now {
state.dealloc(chunk_id);
}
}
});
}
/// Returns the capacity of the allocator, which is the number of chunks.
pub fn capacity(&self) -> usize {
self.state.lock().capacity()
}
/// Returns the number of free chunks.
pub fn free_count(&self) -> usize {
self.state.lock().free_count()
}
}
impl Debug for ChunkAlloc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let state = self.state.lock();
f.debug_struct("ChunkAlloc")
.field("bitmap_free_count", &state.free_count)
.field("bitmap_next_free", &state.next_free)
.finish()
}
}
////////////////////////////////////////////////////////////////////////////////
// Persistent State
////////////////////////////////////////////////////////////////////////////////
/// The persistent state of a chunk allocator.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChunkAllocState {
// A bitmap where each bit indicates whether a corresponding chunk
// has been allocated.
alloc_map: BitMap,
// The number of free chunks.
free_count: usize,
// The next free chunk ID. Used to narrow the scope of
// searching for free chunk IDs.
next_free: usize,
/// Whether the state is in the journal or not.
in_journal: bool,
}
// TODO: Separate persistent and volatile state of `ChunkAlloc`
impl ChunkAllocState {
/// Creates a persistent state for managing chunks of the specified number.
/// Initially, all chunks are free.
pub fn new(capacity: usize) -> Self {
Self {
alloc_map: BitMap::repeat(false, capacity),
free_count: capacity,
next_free: 0,
in_journal: false,
}
}
/// Creates a persistent state that resides in the journal. The state in
/// the journal and the state managed by `RawLogStore` behave differently
/// on allocation and when edits are applied.
pub fn new_in_journal(capacity: usize) -> Self {
Self {
alloc_map: BitMap::repeat(false, capacity),
free_count: capacity,
next_free: 0,
in_journal: true,
}
}
/// Allocates a chunk, returning its ID, or `None` if no chunk is free.
pub fn alloc(&mut self) -> Option<ChunkId> {
// Return early if there are no free chunks, so the search below
// cannot fail.
if self.free_count == 0 {
return None;
}
let mut next_free = self.next_free;
if next_free == self.alloc_map.len() {
next_free = 0;
}
// Search from `next_free` first; wrap around to the start if needed.
let free_chunk_id = {
if let Some(chunk_id) = self.alloc_map.first_zero(next_free) {
chunk_id
} else {
self.alloc_map
.first_zero(0)
.expect("there must exist a zero")
}
};
self.alloc_map.set(free_chunk_id, true);
self.free_count -= 1;
self.next_free = free_chunk_id + 1;
Some(free_chunk_id)
}
/// Deallocates the chunk of a given ID.
///
    /// # Panics
    ///
    /// Deallocating a free chunk causes a panic.
pub fn dealloc(&mut self, chunk_id: ChunkId) {
debug_assert!(self.alloc_map[chunk_id]);
self.alloc_map.set(chunk_id, false);
self.free_count += 1;
}
/// Returns the total number of chunks.
pub fn capacity(&self) -> usize {
self.alloc_map.len()
}
/// Returns the number of free chunks.
pub fn free_count(&self) -> usize {
self.free_count
}
/// Returns whether a specific chunk is allocated.
pub fn is_chunk_allocated(&self, chunk_id: ChunkId) -> bool {
self.alloc_map[chunk_id]
}
}
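// An illustrative sketch (added for exposition, not part of the original
// code): it walks through the next-fit search described above with an
// assumed toy capacity of 4.
#[cfg(test)]
mod chunk_alloc_state_sketch {
    use super::*;

    #[test]
    fn next_fit_allocation() {
        let mut state = ChunkAllocState::new(4);
        let first = state.alloc().unwrap(); // ID 0
        let _second = state.alloc().unwrap(); // ID 1
        state.dealloc(first); // ID 0 becomes free again
        // `next_free` points past ID 1, so the search continues forward
        // and returns ID 2 rather than reusing ID 0 immediately
        assert_eq!(state.alloc().unwrap(), 2);
        assert_eq!(state.free_count(), 2); // IDs 0 and 3 remain free
    }
}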
////////////////////////////////////////////////////////////////////////////////
// Persistent Edit
////////////////////////////////////////////////////////////////////////////////
/// A persistent edit to the state of a chunk allocator.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChunkAllocEdit {
edit_table: HashMap<ChunkId, ChunkEdit>,
}
/// The smallest unit of a persistent edit to the
/// state of a chunk allocator, which is
/// a chunk being either allocated or deallocated.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
enum ChunkEdit {
Alloc,
Dealloc,
}
impl ChunkAllocEdit {
/// Creates a new empty edit table.
pub fn new() -> Self {
Self {
edit_table: HashMap::new(),
}
}
/// Records a chunk allocation in the edit.
pub fn alloc(&mut self, chunk_id: ChunkId) {
let old_edit = self.edit_table.insert(chunk_id, ChunkEdit::Alloc);
        // There must be a logical error if an edit has already been recorded
        // for the chunk. If the chunk edit is `ChunkEdit::Alloc`, then this is
        // a double allocation. If the chunk edit is `ChunkEdit::Dealloc`, then
        // such a deallocation can only take effect after the edit is
        // committed. Thus, it is impossible to allocate the chunk again now.
assert!(old_edit.is_none());
}
/// Records a chunk deallocation in the edit.
///
    /// The return value indicates whether the chunk being deallocated
    /// was previously recorded in the edit as being allocated.
    /// If so, the chunk can be deallocated in the `ChunkAllocState` right away.
pub fn dealloc(&mut self, chunk_id: ChunkId) -> bool {
match self.edit_table.get(&chunk_id) {
None => {
self.edit_table.insert(chunk_id, ChunkEdit::Dealloc);
false
}
Some(&ChunkEdit::Alloc) => {
self.edit_table.remove(&chunk_id);
true
}
Some(&ChunkEdit::Dealloc) => {
panic!("a chunk must not be deallocated twice");
}
}
}
/// Returns an iterator over all allocated chunks.
pub fn iter_allocated_chunks(&self) -> impl Iterator<Item = ChunkId> + '_ {
self.edit_table.iter().filter_map(|(id, edit)| {
if *edit == ChunkEdit::Alloc {
Some(*id)
} else {
None
}
})
}
    /// Returns whether the edit records no changes.
    pub fn is_empty(&self) -> bool {
self.edit_table.is_empty()
}
}
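// An illustrative sketch (added for exposition): a chunk allocated and then
// deallocated within the same edit cancels out, so the deallocation may take
// effect in `ChunkAllocState` right away; deallocating a chunk allocated
// before this edit is merely recorded and deferred until commit.
#[cfg(test)]
mod chunk_alloc_edit_sketch {
    use super::*;

    #[test]
    fn alloc_then_dealloc_cancels_out() {
        let mut edit = ChunkAllocEdit::new();
        edit.alloc(7);
        // Allocated within this edit: `dealloc` returns `true` and the
        // edit ends up recording nothing for chunk 7
        assert!(edit.dealloc(7));
        assert!(edit.is_empty());
        // Allocated before this edit: only a `Dealloc` record is kept
        assert!(!edit.dealloc(8));
        assert!(!edit.is_empty());
    }
}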
impl Edit<ChunkAllocState> for ChunkAllocEdit {
fn apply_to(&self, state: &mut ChunkAllocState) {
let mut to_be_deallocated = Vec::new();
for (&chunk_id, chunk_edit) in &self.edit_table {
match chunk_edit {
ChunkEdit::Alloc => {
if state.in_journal {
let _allocated_id = state.alloc().unwrap();
}
                    // Except for the journal, nothing needs to be done
}
ChunkEdit::Dealloc => {
to_be_deallocated.push(chunk_id);
}
}
}
for chunk_id in to_be_deallocated {
state.dealloc(chunk_id);
}
}
}
impl TxData for ChunkAllocEdit {}
#[cfg(test)]
mod tests {
use super::*;
fn new_chunk_alloc() -> ChunkAlloc {
let cap = 1024_usize;
let tx_provider = TxProvider::new();
let chunk_alloc = ChunkAlloc::new(cap, tx_provider);
assert_eq!(chunk_alloc.capacity(), cap);
assert_eq!(chunk_alloc.free_count(), cap);
chunk_alloc
}
fn do_alloc_dealloc_tx(chunk_alloc: &ChunkAlloc, alloc_cnt: usize, dealloc_cnt: usize) -> Tx {
debug_assert!(alloc_cnt <= chunk_alloc.capacity() && dealloc_cnt <= alloc_cnt);
let mut tx = chunk_alloc.new_tx();
tx.context(|| {
let chunk_id = chunk_alloc.alloc().unwrap();
let chunk_ids = chunk_alloc.alloc_batch(alloc_cnt - 1).unwrap();
let allocated_chunk_ids: Vec<ChunkId> = core::iter::once(chunk_id)
.chain(chunk_ids.into_iter())
.collect();
chunk_alloc.dealloc(allocated_chunk_ids[0]);
chunk_alloc.dealloc_batch(
allocated_chunk_ids[alloc_cnt - dealloc_cnt + 1..alloc_cnt]
.iter()
.cloned(),
);
});
tx
}
#[test]
fn chunk_alloc_dealloc_tx_commit() -> Result<()> {
let chunk_alloc = new_chunk_alloc();
let cap = chunk_alloc.capacity();
let (alloc_cnt, dealloc_cnt) = (cap, cap);
let mut tx = do_alloc_dealloc_tx(&chunk_alloc, alloc_cnt, dealloc_cnt);
tx.commit()?;
assert_eq!(chunk_alloc.free_count(), cap - alloc_cnt + dealloc_cnt);
Ok(())
}
#[test]
fn chunk_alloc_dealloc_tx_abort() -> Result<()> {
let chunk_alloc = new_chunk_alloc();
let cap = chunk_alloc.capacity();
let (alloc_cnt, dealloc_cnt) = (cap / 2, cap / 4);
let mut tx = do_alloc_dealloc_tx(&chunk_alloc, alloc_cnt, dealloc_cnt);
tx.abort();
assert_eq!(chunk_alloc.free_count(), cap);
Ok(())
}
}

View File

@ -0,0 +1,14 @@
// SPDX-License-Identifier: MPL-2.0
//! The layer of transactional logging.
//!
//! `TxLogStore` is a transactional, log-oriented file system.
//! It supports creating, deleting, listing, reading, and writing `TxLog`s.
//! Each `TxLog` is an append-only log assigned a unique `TxLogId`.
//! All of `TxLogStore`'s APIs should be called within transactions (`TX`).
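//!
//! # Usage Example
//!
//! The following is an illustrative sketch (the API names follow their uses
//! elsewhere in this crate; the `MemDisk` backing and the exact error
//! handling are assumptions):
//!
//! ```
//! // Prepare an underlying disk (implementing `BlockSet`)
//! let mem_disk = MemDisk::create(1024)?;
//!
//! // Format a fresh `TxLogStore` on the disk
//! let tx_log_store = TxLogStore::format(mem_disk)?;
//!
//! // Create a log in a bucket and append one block to it, within a TX
//! let tx = tx_log_store.new_tx();
//! let res: Result<()> = tx.context(|| {
//!     let log = tx_log_store.create_log("my_bucket")?;
//!     let buf = Buf::alloc(1)?;
//!     log.append(BufRef::try_from(buf.as_slice()).unwrap())
//! });
//! if res.is_ok() {
//!     tx.commit()?;
//! } else {
//!     tx.abort();
//! }
//! ```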
mod chunk;
mod raw_log;
mod tx_log;
pub use self::tx_log::{TxLog, TxLogId, TxLogStore};

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,132 @@
// SPDX-License-Identifier: MPL-2.0
//! Compaction in `TxLsmTree`.
use core::marker::PhantomData;
use super::{
mem_table::ValueEx, sstable::SSTable, tx_lsm_tree::SSTABLE_CAPACITY, LsmLevel, RecordKey,
RecordValue, SyncId, TxEventListener,
};
use crate::{
layers::{bio::BlockSet, log::TxLogStore},
os::{JoinHandle, Mutex},
prelude::*,
};
/// A `Compactor` is currently used to perform asynchronous compaction
/// and implement the specific compaction algorithm of `TxLsmTree`.
pub(super) struct Compactor<K, V> {
handle: Mutex<Option<JoinHandle<Result<()>>>>,
phantom: PhantomData<(K, V)>,
}
impl<K: RecordKey<K>, V: RecordValue> Compactor<K, V> {
/// Create a new `Compactor` instance.
pub fn new() -> Self {
Self {
handle: Mutex::new(None),
phantom: PhantomData,
}
}
    /// Records the handle of the current compaction thread.
pub fn record_handle(&self, handle: JoinHandle<Result<()>>) {
let mut handle_opt = self.handle.lock();
assert!(handle_opt.is_none());
let _ = handle_opt.insert(handle);
}
/// Wait until the compaction is finished.
pub fn wait_compaction(&self) -> Result<()> {
if let Some(handle) = self.handle.lock().take() {
handle.join().unwrap()
} else {
Ok(())
}
}
/// Core function for compacting overlapped records and building new SSTs.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn compact_records_and_build_ssts<D: BlockSet + 'static>(
upper_records: impl Iterator<Item = (K, ValueEx<V>)>,
lower_records: impl Iterator<Item = (K, ValueEx<V>)>,
tx_log_store: &Arc<TxLogStore<D>>,
event_listener: &Arc<dyn TxEventListener<K, V>>,
to_level: LsmLevel,
sync_id: SyncId,
) -> Result<Vec<SSTable<K, V>>> {
let mut created_ssts = Vec::new();
let mut upper_iter = upper_records.peekable();
let mut lower_iter = lower_records.peekable();
loop {
let mut record_cnt = 0;
let records_iter = core::iter::from_fn(|| {
if record_cnt == SSTABLE_CAPACITY {
return None;
}
record_cnt += 1;
match (upper_iter.peek(), lower_iter.peek()) {
(Some((upper_k, _)), Some((lower_k, _))) => match upper_k.cmp(lower_k) {
core::cmp::Ordering::Less => upper_iter.next(),
core::cmp::Ordering::Greater => lower_iter.next(),
core::cmp::Ordering::Equal => {
let (k, new_v_ex) = upper_iter.next().unwrap();
let (_, old_v_ex) = lower_iter.next().unwrap();
let (next_v_ex, dropped_v_opt) =
Self::compact_value_ex(new_v_ex, old_v_ex);
if let Some(dropped_v) = dropped_v_opt {
event_listener.on_drop_record(&(k, dropped_v)).unwrap();
}
Some((k, next_v_ex))
}
},
(Some(_), None) => upper_iter.next(),
(None, Some(_)) => lower_iter.next(),
(None, None) => None,
}
});
let mut records_iter = records_iter.peekable();
if records_iter.peek().is_none() {
break;
}
let new_log = tx_log_store.create_log(to_level.bucket())?;
let new_sst = SSTable::build(records_iter, sync_id, &new_log, None)?;
created_ssts.push(new_sst);
}
Ok(created_ssts)
}
/// Compact two `ValueEx<V>`s with the same key, returning
/// the compacted value and the dropped value if any.
fn compact_value_ex(new: ValueEx<V>, old: ValueEx<V>) -> (ValueEx<V>, Option<V>) {
match (new, old) {
(ValueEx::Synced(new_v), ValueEx::Synced(old_v)) => {
(ValueEx::Synced(new_v), Some(old_v))
}
(ValueEx::Unsynced(new_v), ValueEx::Synced(old_v)) => {
(ValueEx::SyncedAndUnsynced(old_v, new_v), None)
}
(ValueEx::Unsynced(new_v), ValueEx::Unsynced(old_v)) => {
(ValueEx::Unsynced(new_v), Some(old_v))
}
(ValueEx::Unsynced(new_v), ValueEx::SyncedAndUnsynced(old_sv, old_usv)) => {
(ValueEx::SyncedAndUnsynced(old_sv, new_v), Some(old_usv))
}
(ValueEx::SyncedAndUnsynced(new_sv, new_usv), ValueEx::Synced(old_sv)) => {
(ValueEx::SyncedAndUnsynced(new_sv, new_usv), Some(old_sv))
}
_ => {
unreachable!()
}
}
}
}
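// An illustrative sketch (added for exposition) of `compact_value_ex`,
// assuming `usize`/`u16` satisfy `RecordKey`/`RecordValue` as in the
// `mem_table` tests of this crate.
#[cfg(test)]
mod compact_value_ex_sketch {
    use super::*;

    #[test]
    fn newer_unsynced_value_wins() {
        // Two unsynced values with the same key: the newer (upper-level)
        // value is kept and the older one is dropped
        let (kept, dropped) = Compactor::<usize, u16>::compact_value_ex(
            ValueEx::Unsynced(2),
            ValueEx::Unsynced(1),
        );
        assert!(matches!(kept, ValueEx::Unsynced(2)));
        assert_eq!(dropped, Some(1));

        // An unsynced value stacked on an older synced value: both are
        // kept, since the synced one may still be needed after recovery
        let (kept, dropped) = Compactor::<usize, u16>::compact_value_ex(
            ValueEx::Unsynced(4),
            ValueEx::Synced(3),
        );
        assert!(matches!(kept, ValueEx::SyncedAndUnsynced(3, 4)));
        assert_eq!(dropped, None);
    }
}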

View File

@ -0,0 +1,402 @@
// SPDX-License-Identifier: MPL-2.0
//! MemTable.
use core::ops::Range;
use super::{tx_lsm_tree::OnDropRecodeFn, AsKV, RangeQueryCtx, RecordKey, RecordValue, SyncId};
use crate::{
os::{BTreeMap, Condvar, CvarMutex, Mutex, RwLock, RwLockReadGuard},
prelude::*,
};
/// Manager for a mutable `MemTable` and an immutable `MemTable`
/// in a `TxLsmTree`.
pub(super) struct MemTableManager<K: RecordKey<K>, V> {
mutable: Mutex<MemTable<K, V>>,
immutable: RwLock<MemTable<K, V>>, // Read-only most of the time
cvar: Condvar,
is_full: CvarMutex<bool>,
}
/// MemTable for LSM-Tree.
///
/// Manages organized key-value records in memory with a capacity.
/// Each `MemTable` is sync-aware (tagged with current sync ID).
/// Both synced and unsynced records can co-exist.
/// Also supports user-defined callback when a record is dropped.
pub(super) struct MemTable<K: RecordKey<K>, V> {
table: BTreeMap<K, ValueEx<V>>,
size: usize,
cap: usize,
sync_id: SyncId,
unsynced_range: Option<Range<K>>,
on_drop_record: Option<Arc<OnDropRecodeFn<K, V>>>,
}
/// An extended value which is sync-aware.
/// At most one synced and one unsynced value can coexist at the same time.
#[derive(Clone, Debug)]
pub(super) enum ValueEx<V> {
Synced(V),
Unsynced(V),
SyncedAndUnsynced(V, V),
}
impl<K: RecordKey<K>, V: RecordValue> MemTableManager<K, V> {
/// Creates a new `MemTableManager` given the current master sync ID,
/// the capacity and the callback when dropping records.
pub fn new(
sync_id: SyncId,
capacity: usize,
on_drop_record_in_memtable: Option<Arc<OnDropRecodeFn<K, V>>>,
) -> Self {
let mutable = Mutex::new(MemTable::new(
capacity,
sync_id,
on_drop_record_in_memtable.clone(),
));
let immutable = RwLock::new(MemTable::new(capacity, sync_id, on_drop_record_in_memtable));
Self {
mutable,
immutable,
cvar: Condvar::new(),
is_full: CvarMutex::new(false),
}
}
/// Gets the target value of the given key from the `MemTable`s.
pub fn get(&self, key: &K) -> Option<V> {
if let Some(value) = self.mutable.lock().get(key) {
return Some(*value);
}
if let Some(value) = self.immutable.read().get(key) {
return Some(*value);
}
None
}
/// Gets the range of values from the `MemTable`s.
pub fn get_range(&self, range_query_ctx: &mut RangeQueryCtx<K, V>) -> bool {
let is_completed = self.mutable.lock().get_range(range_query_ctx);
if is_completed {
return is_completed;
}
self.immutable.read().get_range(range_query_ctx)
}
    /// Puts a key-value pair into the mutable `MemTable`, and
    /// returns whether the mutable `MemTable` is full.
pub fn put(&self, key: K, value: V) -> bool {
let mut is_full = self.is_full.lock().unwrap();
while *is_full {
is_full = self.cvar.wait(is_full).unwrap();
}
debug_assert!(!*is_full);
let mut mutable = self.mutable.lock();
let _ = mutable.put(key, value);
if mutable.at_capacity() {
*is_full = true;
}
*is_full
}
/// Sync the mutable `MemTable` with the given sync ID.
pub fn sync(&self, sync_id: SyncId) {
self.mutable.lock().sync(sync_id)
}
    /// Switches the two `MemTable`s. Should only be called when
    /// the mutable `MemTable` becomes full and the immutable `MemTable` is
    /// ready to be cleared.
pub fn switch(&self) -> Result<()> {
let mut is_full = self.is_full.lock().unwrap();
debug_assert!(*is_full);
let mut mutable = self.mutable.lock();
let sync_id = mutable.sync_id();
let mut immutable = self.immutable.write();
immutable.clear();
core::mem::swap(&mut *mutable, &mut *immutable);
debug_assert!(mutable.is_empty() && immutable.at_capacity());
// Update sync ID of the switched mutable `MemTable`
mutable.sync(sync_id);
*is_full = false;
self.cvar.notify_all();
Ok(())
}
/// Gets the immutable `MemTable` instance (read-only).
pub fn immutable_memtable(&self) -> RwLockReadGuard<MemTable<K, V>> {
self.immutable.read()
}
}
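// An illustrative sketch (added for exposition) of the put-until-full,
// then-switch protocol, using the toy `usize`/`u16` types from the tests
// below and an assumed capacity of 1.
#[cfg(test)]
mod memtable_manager_sketch {
    use super::*;

    #[test]
    fn put_until_full_then_switch() {
        let mgr = MemTableManager::<usize, u16>::new(0, 1, None);
        // Filling the mutable `MemTable` to capacity reports fullness;
        // further `put()`s would block until `switch()` is called
        assert!(mgr.put(1, 11));
        // `switch()` moves the full table to the immutable slot and
        // makes room for new puts
        mgr.switch().unwrap();
        assert_eq!(mgr.get(&1), Some(11)); // served by the immutable table
        assert!(mgr.put(2, 22)); // capacity 1, so full again at once
        assert_eq!(mgr.get(&2), Some(22)); // served by the mutable table
    }
}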
impl<K: RecordKey<K>, V: RecordValue> MemTable<K, V> {
/// Creates a new `MemTable`, given the capacity, the current sync ID,
/// and the callback of dropping record.
pub fn new(
cap: usize,
sync_id: SyncId,
on_drop_record: Option<Arc<OnDropRecodeFn<K, V>>>,
) -> Self {
Self {
table: BTreeMap::new(),
size: 0,
cap,
sync_id,
unsynced_range: None,
on_drop_record,
}
}
/// Gets the target value given the key.
pub fn get(&self, key: &K) -> Option<&V> {
let value_ex = self.table.get(key)?;
Some(value_ex.get())
}
/// Range query, returns whether the request is completed.
pub fn get_range(&self, range_query_ctx: &mut RangeQueryCtx<K, V>) -> bool {
debug_assert!(!range_query_ctx.is_completed());
let target_range = range_query_ctx.range_uncompleted().unwrap();
for (k, v_ex) in self.table.range(target_range) {
range_query_ctx.complete(*k, *v_ex.get());
}
range_query_ctx.is_completed()
}
    /// Puts a new K-V record into the table, dropping the old one if any.
pub fn put(&mut self, key: K, value: V) -> Option<V> {
let dropped_value = if let Some(value_ex) = self.table.get_mut(&key) {
if let Some(dropped) = value_ex.put(value) {
let _ = self
.on_drop_record
.as_ref()
.map(|on_drop_record| on_drop_record(&(key, dropped)));
Some(dropped)
} else {
self.size += 1;
None
}
} else {
let _ = self.table.insert(key, ValueEx::new(value));
self.size += 1;
None
};
if let Some(range) = self.unsynced_range.as_mut() {
if range.is_empty() {
*range = key..key + 1;
} else {
let start = key.min(range.start);
let end = (key + 1).max(range.end);
*range = start..end;
}
}
dropped_value
}
    /// Syncs the table: updates the sync ID and drops any replaced values.
pub fn sync(&mut self, sync_id: SyncId) {
debug_assert!(self.sync_id <= sync_id);
if self.sync_id == sync_id {
return;
}
let filter_unsynced: Box<dyn Iterator<Item = _>> = if let Some(range) = &self.unsynced_range
{
Box::new(
self.table
.range_mut(range.clone())
.filter(|(_, v_ex)| v_ex.contains_unsynced()),
)
} else {
Box::new(
self.table
.iter_mut()
.filter(|(_, v_ex)| v_ex.contains_unsynced()),
)
};
for (k, v_ex) in filter_unsynced {
if let Some(dropped) = v_ex.sync() {
let _ = self
.on_drop_record
.as_ref()
.map(|on_drop_record| on_drop_record(&(*k, dropped)));
self.size -= 1;
}
}
self.sync_id = sync_id;
// Insert an empty range upon first sync
let _ = self
.unsynced_range
.get_or_insert_with(|| K::new_uninit()..K::new_uninit());
}
/// Return the sync ID of this table.
pub fn sync_id(&self) -> SyncId {
self.sync_id
}
/// Return an iterator over the table.
pub fn iter(&self) -> impl Iterator<Item = (&K, &ValueEx<V>)> {
self.table.iter()
}
/// Return the number of records in the table.
pub fn size(&self) -> usize {
self.size
}
/// Return whether the table is empty.
pub fn is_empty(&self) -> bool {
self.size == 0
}
/// Return whether the table is full.
pub fn at_capacity(&self) -> bool {
self.size >= self.cap
}
/// Clear all records from the table.
pub fn clear(&mut self) {
self.table.clear();
self.size = 0;
self.unsynced_range = None;
}
}
impl<V: RecordValue> ValueEx<V> {
/// Creates a new unsynced value.
pub fn new(value: V) -> Self {
Self::Unsynced(value)
}
/// Gets the most recent value.
pub fn get(&self) -> &V {
match self {
Self::Synced(v) => v,
Self::Unsynced(v) => v,
Self::SyncedAndUnsynced(_, v) => v,
}
}
/// Puts a new value, return the replaced value if any.
fn put(&mut self, value: V) -> Option<V> {
let existed = core::mem::take(self);
match existed {
ValueEx::Synced(v) => {
*self = Self::SyncedAndUnsynced(v, value);
None
}
ValueEx::Unsynced(v) => {
*self = Self::Unsynced(value);
Some(v)
}
ValueEx::SyncedAndUnsynced(sv, usv) => {
*self = Self::SyncedAndUnsynced(sv, value);
Some(usv)
}
}
}
/// Sync the value, return the replaced value if any.
fn sync(&mut self) -> Option<V> {
debug_assert!(self.contains_unsynced());
let existed = core::mem::take(self);
match existed {
ValueEx::Unsynced(v) => {
*self = Self::Synced(v);
None
}
ValueEx::SyncedAndUnsynced(sv, usv) => {
*self = Self::Synced(usv);
Some(sv)
}
ValueEx::Synced(_) => unreachable!(),
}
}
/// Whether the value contains an unsynced value.
pub fn contains_unsynced(&self) -> bool {
match self {
ValueEx::Unsynced(_) | ValueEx::SyncedAndUnsynced(_, _) => true,
ValueEx::Synced(_) => false,
}
}
}
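// An illustrative sketch (added for exposition) of the `ValueEx` state
// machine above, reusing the `u16: RecordValue` impl from the tests below.
#[cfg(test)]
mod value_ex_sketch {
    use super::*;

    #[test]
    fn put_then_sync() {
        let mut v_ex = ValueEx::Synced(1u16);
        // Putting on a synced value stacks an unsynced one; nothing is dropped
        assert!(v_ex.put(2).is_none());
        assert!(v_ex.contains_unsynced());
        // Syncing promotes the unsynced value and drops the old synced one
        assert_eq!(v_ex.sync(), Some(1));
        assert!(matches!(v_ex, ValueEx::Synced(2)));
    }
}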
impl<V: RecordValue> Default for ValueEx<V> {
fn default() -> Self {
Self::Unsynced(V::new_uninit())
}
}
impl<K: RecordKey<K>, V: RecordValue> Debug for MemTableManager<K, V> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MemTableManager")
.field("mutable_memtable_size", &self.mutable.lock().size())
.field("immutable_memtable_size", &self.immutable_memtable().size())
.finish()
}
}
#[cfg(test)]
mod tests {
use core::sync::atomic::{AtomicU16, Ordering};
use super::*;
#[test]
fn memtable_fns() -> Result<()> {
impl RecordValue for u16 {}
let drop_count = Arc::new(AtomicU16::new(0));
let dc = drop_count.clone();
let drop_fn = move |_: &dyn AsKV<usize, u16>| {
dc.fetch_add(1, Ordering::Relaxed);
};
let mut table = MemTable::<usize, u16>::new(4, 0, Some(Arc::new(drop_fn)));
table.put(1, 11);
table.put(2, 12);
table.put(2, 22);
assert_eq!(drop_count.load(Ordering::Relaxed), 1);
assert_eq!(table.size(), 2);
assert_eq!(table.at_capacity(), false);
table.sync(1);
table.put(2, 32);
assert_eq!(table.size(), 3);
assert_eq!(*table.get(&2).unwrap(), 32);
table.sync(2);
assert_eq!(drop_count.load(Ordering::Relaxed), 2);
table.put(2, 52);
table.put(3, 13);
assert_eq!(table.at_capacity(), true);
let mut range_query_ctx = RangeQueryCtx::new(2, 2);
assert_eq!(table.get_range(&mut range_query_ctx), true);
assert_eq!(range_query_ctx.into_results(), vec![(2, 52), (3, 13)]);
assert_eq!(table.sync_id(), 2);
table.clear();
assert_eq!(table.is_empty(), true);
Ok(())
}
}

View File

@ -0,0 +1,79 @@
// SPDX-License-Identifier: MPL-2.0
//! The layer of the transactional LSM-Tree.
//!
//! This module provides the implementation for `TxLsmTree`.
//! `TxLsmTree` is similar to a general-purpose LSM-Tree, supporting `put()`, `get()`, and
//! `get_range()` on key-value records, which are managed in MemTables and SSTables.
//!
//! `TxLsmTree` is transactional in the sense that
//! 1) it supports `sync()` that guarantees changes are persisted atomically and irreversibly,
//!    and synchronized records and unsynchronized records can co-exist;
//! 2) its internal data is securely stored in `TxLogStore` (L3) and updated in transactions for consistency,
//! WALs and SSTables are stored and managed in `TxLogStore`.
//!
//! `TxLsmTree` supports piggybacking callbacks during compaction and recovery.
//!
//! # Usage Example
//!
//! Create a `TxLsmTree` then put some records into it.
//!
//! ```
//! // Prepare an underlying disk (implement `BlockSet`) first
//! let nblocks = 1024;
//! let mem_disk = MemDisk::create(nblocks)?;
//!
//! // Prepare an underlying `TxLogStore` (L3) for storing WALs and SSTs
//! let tx_log_store = Arc::new(TxLogStore::format(mem_disk)?);
//!
//! // Create a `TxLsmTree` with the created `TxLogStore`
//! let tx_lsm_tree: TxLsmTree<BlockId, String, MemDisk> =
//! TxLsmTree::format(tx_log_store, Arc::new(YourFactory), None)?;
//!
//! // Put some key-value records into the tree
//! for i in 0..10 {
//! let k = i as BlockId;
//! let v = i.to_string();
//! tx_lsm_tree.put(k, v)?;
//! }
//!
//! // Issue a sync operation to the tree to ensure persistency
//! tx_lsm_tree.sync()?;
//!
//! // Use `get()` (or `get_range()`) to query the tree
//! let target_value = tx_lsm_tree.get(&5).unwrap();
//! // Check the previously put value
//! assert_eq!(target_value, "5");
//!
//! // `TxLsmTree` supports user-defined per-TX callbacks
//! struct YourFactory;
//! struct YourListener;
//!
//! impl<K, V> TxEventListenerFactory<K, V> for YourFactory {
//!     // Supports creating a per-TX (upon compaction or upon recovery) listener
//! fn new_event_listener(&self, tx_type: TxType) -> Arc<dyn TxEventListener<K, V>> {
//! Arc::new(YourListener::new(tx_type))
//! }
//! }
//!
//! // Supports defining callbacks when a record is added or dropped, or
//! // at some critical points during a TX
//! impl<K, V> TxEventListener<K, V> for YourListener {
//! /* details omitted, see the API for more */
//! }
//! ```
mod compaction;
mod mem_table;
mod range_query_ctx;
mod sstable;
mod tx_lsm_tree;
mod wal;
pub use self::{
range_query_ctx::RangeQueryCtx,
tx_lsm_tree::{
AsKV, LsmLevel, RecordKey, RecordValue, SyncId, SyncIdStore, TxEventListener,
TxEventListenerFactory, TxLsmTree, TxType,
},
};

View File

@ -0,0 +1,96 @@
// SPDX-License-Identifier: MPL-2.0
//! Context for range queries.
use core::ops::RangeInclusive;
use super::{RecordKey, RecordValue};
use crate::{prelude::*, util::BitMap};
/// Context for a range query request.
/// It tracks the completing process of each slot within the range.
/// A "slot" indicates one specific key-value pair of the query.
#[derive(Debug)]
pub struct RangeQueryCtx<K, V> {
start: K,
num_values: usize,
complete_table: BitMap,
min_uncompleted: usize,
res: Vec<(K, V)>,
}
impl<K: RecordKey<K>, V: RecordValue> RangeQueryCtx<K, V> {
/// Create a new context with the given start key,
/// and the number of values for query.
pub fn new(start: K, num_values: usize) -> Self {
Self {
start,
num_values,
complete_table: BitMap::repeat(false, num_values),
min_uncompleted: 0,
res: Vec::with_capacity(num_values),
}
}
    /// Gets the uncompleted range within the whole range, returning `None`
    /// if all slots are already completed.
pub fn range_uncompleted(&self) -> Option<RangeInclusive<K>> {
if self.is_completed() {
return None;
}
debug_assert!(self.min_uncompleted < self.num_values);
let first_uncompleted = self.start + self.min_uncompleted;
let last_uncompleted = self.start + self.complete_table.last_zero()?;
Some(first_uncompleted..=last_uncompleted)
}
/// Whether the uncompleted range contains the target key.
pub fn contains_uncompleted(&self, key: &K) -> bool {
let nth = *key - self.start;
nth < self.num_values && !self.complete_table[nth]
}
    /// Whether the range query context is completed, meaning
    /// all slots are filled with the corresponding values.
pub fn is_completed(&self) -> bool {
self.min_uncompleted == self.num_values
}
/// Complete one slot within the range, with the specific
/// key and the queried value.
pub fn complete(&mut self, key: K, value: V) {
let nth = key - self.start;
if self.complete_table[nth] {
return;
}
self.res.push((key, value));
self.complete_table.set(nth, true);
self.update_min_uncompleted(nth);
}
/// Mark the specific slot as completed.
pub fn mark_completed(&mut self, key: K) {
let nth = key - self.start;
self.complete_table.set(nth, true);
self.update_min_uncompleted(nth);
}
/// Turn the context into final results.
pub fn into_results(self) -> Vec<(K, V)> {
debug_assert!(self.is_completed());
self.res
}
fn update_min_uncompleted(&mut self, completed_nth: usize) {
if self.min_uncompleted == completed_nth {
if let Some(next_uncompleted) = self.complete_table.first_zero(completed_nth) {
self.min_uncompleted = next_uncompleted;
} else {
// Indicate all slots are completed
self.min_uncompleted = self.num_values;
}
}
}
}
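// An illustrative sketch (added for exposition): slots may be completed in
// any order, and the context keeps narrowing the uncompleted range. The toy
// `usize`/`u16` types mirror those used in the `mem_table` tests.
#[cfg(test)]
mod range_query_ctx_sketch {
    use super::*;

    #[test]
    fn complete_out_of_order() {
        // Query 3 values starting at key 10
        let mut ctx = RangeQueryCtx::<usize, u16>::new(10, 3);
        ctx.complete(12, 2);
        assert!(!ctx.is_completed());
        // Keys 10 and 11 are still uncompleted
        assert_eq!(ctx.range_uncompleted(), Some(10..=11));
        ctx.complete(10, 0);
        ctx.complete(11, 1);
        assert!(ctx.is_completed());
        // Results are collected in completion order
        assert_eq!(ctx.into_results(), vec![(12, 2), (10, 0), (11, 1)]);
    }
}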

View File

@ -0,0 +1,779 @@
// SPDX-License-Identifier: MPL-2.0
//! Sorted String Table.
use alloc::vec;
use core::{marker::PhantomData, mem::size_of, num::NonZeroUsize, ops::RangeInclusive};
use lru::LruCache;
use ostd_pod::Pod;
use super::{
mem_table::ValueEx, tx_lsm_tree::AsKVex, RangeQueryCtx, RecordKey, RecordValue, SyncId,
TxEventListener,
};
use crate::{
layers::{
bio::{BlockSet, Buf, BufMut, BufRef, BID_SIZE},
log::{TxLog, TxLogId, TxLogStore},
},
os::Mutex,
prelude::*,
};
/// Sorted String Table (SST) for `TxLsmTree`.
///
/// Responsible for storing, managing key-value records on a `TxLog` (L3).
/// Records are serialized, sorted, organized on the `TxLog`.
/// Supports three access modes: point query, range query and whole scan.
pub(super) struct SSTable<K, V> {
id: TxLogId,
footer: Footer<K>,
cache: Mutex<LruCache<BlockId, Arc<RecordBlock>>>,
phantom: PhantomData<(K, V)>,
}
/// Footer of a `SSTable`, contains its own metadata and
/// index entries for locating record blocks.
#[derive(Debug)]
struct Footer<K> {
meta: FooterMeta,
index: Vec<IndexEntry<K>>,
}
/// Footer metadata to describe a `SSTable`.
#[repr(C)]
#[derive(Copy, Clone, Pod, Debug)]
struct FooterMeta {
num_index: u16,
index_nblocks: u16,
total_records: u32,
record_block_size: u32,
sync_id: SyncId,
}
const FOOTER_META_SIZE: usize = size_of::<FooterMeta>();
/// Index entry to describe a `RecordBlock` in a `SSTable`.
#[derive(Debug)]
struct IndexEntry<K> {
pos: BlockId,
first: K,
last: K,
}
/// A block full of serialized records.
struct RecordBlock {
buf: Vec<u8>,
}
const RECORD_BLOCK_NBLOCKS: usize = 32;
/// The size of a `RecordBlock`, which is a multiple of `BLOCK_SIZE`.
const RECORD_BLOCK_SIZE: usize = RECORD_BLOCK_NBLOCKS * BLOCK_SIZE;
/// Accessor for a query.
enum QueryAccessor<K> {
Point(K),
Range(RangeInclusive<K>),
}
/// Iterator over `RecordBlock` for query purpose.
struct BlockQueryIter<'a, K, V> {
block: &'a RecordBlock,
offset: usize,
accessor: &'a QueryAccessor<K>,
phantom: PhantomData<(K, V)>,
}
/// Accessor for a whole table scan.
struct ScanAccessor<'a, K, V> {
all_synced: bool,
discard_unsynced: bool,
event_listener: Option<&'a Arc<dyn TxEventListener<K, V>>>,
}
/// Iterator over `RecordBlock` for scan purpose.
struct BlockScanIter<'a, K, V> {
block: Arc<RecordBlock>,
offset: usize,
accessor: ScanAccessor<'a, K, V>,
}
/// Iterator over `SSTable`.
pub(super) struct SstIter<'a, K, V, D> {
sst: &'a SSTable<K, V>,
curr_nth_index: usize,
curr_rb_iter: Option<BlockScanIter<'a, K, V>>,
tx_log_store: &'a Arc<TxLogStore<D>>,
}
/// Format on a `TxLog`:
///
/// ```text
/// | [Record] | [Record] |...| Footer |
/// |K|flag|V(V)| ... | [Record] |...| [IndexEntry] | FooterMeta |
/// |RECORD_BLOCK_SIZE|RECORD_BLOCK_SIZE|...| |
/// ```
impl<K: RecordKey<K>, V: RecordValue> SSTable<K, V> {
const K_SIZE: usize = size_of::<K>();
const V_SIZE: usize = size_of::<V>();
const FLAG_SIZE: usize = size_of::<RecordFlag>();
const MIN_RECORD_SIZE: usize = BID_SIZE + Self::FLAG_SIZE + Self::V_SIZE;
const MAX_RECORD_SIZE: usize = BID_SIZE + Self::FLAG_SIZE + 2 * Self::V_SIZE;
const INDEX_ENTRY_SIZE: usize = BID_SIZE + 2 * Self::K_SIZE;
const CACHE_CAP: usize = 1024;
    /// Return the ID of this `SSTable`, which is the same ID
    /// as that of the underlying `TxLog`.
pub fn id(&self) -> TxLogId {
self.id
}
    /// Return the sync ID of this `SSTable`. It may be smaller than the
    /// current master sync ID.
pub fn sync_id(&self) -> SyncId {
self.footer.meta.sync_id
}
/// The range of keys covered by this `SSTable`.
pub fn range(&self) -> RangeInclusive<K> {
RangeInclusive::new(
self.footer.index[0].first,
self.footer.index[self.footer.meta.num_index as usize - 1].last,
)
}
    /// Whether the target key is within the range. Being "within the range"
    /// doesn't mean the `SSTable` does have this key.
pub fn is_within_range(&self, key: &K) -> bool {
self.range().contains(key)
}
/// Whether the target range is overlapped with the range of this `SSTable`.
pub fn overlap_with(&self, rhs_range: &RangeInclusive<K>) -> bool {
let lhs_range = self.range();
!(lhs_range.end() < rhs_range.start() || lhs_range.start() > rhs_range.end())
}
// Accessing functions below
/// Point query.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn access_point<D: BlockSet + 'static>(
&self,
key: &K,
tx_log_store: &Arc<TxLogStore<D>>,
) -> Result<V> {
debug_assert!(self.range().contains(key));
let target_rb_pos = self
.footer
.index
.iter()
.find_map(|entry| {
if entry.is_within_range(key) {
Some(entry.pos)
} else {
None
}
})
.ok_or(Error::with_msg(NotFound, "target key not found in sst"))?;
let accessor = QueryAccessor::Point(*key);
let target_rb = self.target_record_block(target_rb_pos, tx_log_store)?;
let mut iter = BlockQueryIter::<'_, K, V> {
block: &target_rb,
offset: 0,
accessor: &accessor,
phantom: PhantomData,
};
iter.find_map(|(k, v_opt)| if k == *key { v_opt } else { None })
.ok_or(Error::with_msg(NotFound, "target value not found in SST"))
}
/// Range query.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn access_range<D: BlockSet + 'static>(
&self,
range_query_ctx: &mut RangeQueryCtx<K, V>,
tx_log_store: &Arc<TxLogStore<D>>,
) -> Result<()> {
debug_assert!(!range_query_ctx.is_completed());
let range_uncompleted = range_query_ctx.range_uncompleted().unwrap();
let target_rbs = self.footer.index.iter().filter_map(|entry| {
if entry.overlap_with(&range_uncompleted) {
Some(entry.pos)
} else {
None
}
});
let accessor = QueryAccessor::Range(range_uncompleted.clone());
for target_rb_pos in target_rbs {
let target_rb = self.target_record_block(target_rb_pos, tx_log_store)?;
let iter = BlockQueryIter::<'_, K, V> {
block: &target_rb,
offset: 0,
accessor: &accessor,
phantom: PhantomData,
};
let targets: Vec<_> = iter
.filter_map(|(k, v_opt)| {
if range_uncompleted.contains(&k) {
Some((k, v_opt.unwrap()))
} else {
None
}
})
.collect();
for (target_k, target_v) in targets {
range_query_ctx.complete(target_k, target_v);
}
}
Ok(())
}
    /// Locate the target record block given its position; it
    /// resides in either the cache or the log.
fn target_record_block<D: BlockSet + 'static>(
&self,
target_pos: BlockId,
tx_log_store: &Arc<TxLogStore<D>>,
) -> Result<Arc<RecordBlock>> {
let mut cache = self.cache.lock();
if let Some(cached_rb) = cache.get(&target_pos) {
Ok(cached_rb.clone())
} else {
let mut rb = RecordBlock::from_buf(vec![0; RECORD_BLOCK_SIZE]);
// TODO: Avoid opening the log on every call
let tx_log = tx_log_store.open_log(self.id, false)?;
tx_log.read(target_pos, BufMut::try_from(rb.as_mut_slice()).unwrap())?;
let rb = Arc::new(rb);
cache.put(target_pos, rb.clone());
Ok(rb)
}
}
/// Return the iterator over this `SSTable`.
/// The given `event_listener` (optional) is used on dropping records
/// during iteration.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn iter<'a, D: BlockSet + 'static>(
&'a self,
sync_id: SyncId,
discard_unsynced: bool,
tx_log_store: &'a Arc<TxLogStore<D>>,
event_listener: Option<&'a Arc<dyn TxEventListener<K, V>>>,
) -> SstIter<'a, K, V, D> {
let all_synced = sync_id > self.sync_id();
let accessor = ScanAccessor {
all_synced,
discard_unsynced,
event_listener,
};
let first_rb = self
.target_record_block(self.footer.index[0].pos, tx_log_store)
.unwrap();
SstIter {
sst: self,
curr_nth_index: 0,
curr_rb_iter: Some(BlockScanIter {
block: first_rb,
offset: 0,
accessor,
}),
tx_log_store,
}
}
/// Scan the whole SST and collect all records.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn access_scan<D: BlockSet + 'static>(
&self,
sync_id: SyncId,
discard_unsynced: bool,
tx_log_store: &Arc<TxLogStore<D>>,
event_listener: Option<&Arc<dyn TxEventListener<K, V>>>,
) -> Result<Vec<(K, ValueEx<V>)>> {
let all_records = self
.iter(sync_id, discard_unsynced, tx_log_store, event_listener)
.collect();
Ok(all_records)
}
// Building functions below
    /// Builds an SST given a bunch of records; once built, the SST becomes immutable.
/// The given `event_listener` (optional) is used on adding records.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn build<'a, D: BlockSet + 'static, I, KVex>(
records_iter: I,
sync_id: SyncId,
tx_log: &'a Arc<TxLog<D>>,
event_listener: Option<&'a Arc<dyn TxEventListener<K, V>>>,
) -> Result<Self>
where
I: Iterator<Item = KVex>,
KVex: AsKVex<K, V>,
Self: 'a,
{
let mut cache = LruCache::new(NonZeroUsize::new(Self::CACHE_CAP).unwrap());
let (total_records, index_vec) =
Self::build_record_blocks(records_iter, tx_log, &mut cache, event_listener)?;
let footer = Self::build_footer::<D>(index_vec, total_records, sync_id, tx_log)?;
Ok(Self {
id: tx_log.id(),
footer,
cache: Mutex::new(cache),
phantom: PhantomData,
})
}
    /// Builds all the record blocks from the given records. Puts the blocks
    /// into the log and the cache.
fn build_record_blocks<'a, D: BlockSet + 'static, I, KVex>(
records_iter: I,
tx_log: &'a TxLog<D>,
cache: &mut LruCache<BlockId, Arc<RecordBlock>>,
event_listener: Option<&'a Arc<dyn TxEventListener<K, V>>>,
) -> Result<(usize, Vec<IndexEntry<K>>)>
where
I: Iterator<Item = KVex>,
KVex: AsKVex<K, V>,
Self: 'a,
{
let mut index_vec = Vec::new();
let mut total_records = 0;
let mut pos = 0 as BlockId;
let (mut first_k, mut curr_k) = (None, None);
let mut inner_offset = 0;
let mut block_buf = Vec::with_capacity(RECORD_BLOCK_SIZE);
for kv_ex in records_iter {
let (key, value_ex) = (*kv_ex.key(), kv_ex.value_ex());
total_records += 1;
if inner_offset == 0 {
debug_assert!(block_buf.is_empty());
let _ = first_k.insert(key);
}
let _ = curr_k.insert(key);
block_buf.extend_from_slice(key.as_bytes());
inner_offset += Self::K_SIZE;
match value_ex {
ValueEx::Synced(v) => {
block_buf.push(RecordFlag::Synced as u8);
block_buf.extend_from_slice(v.as_bytes());
if let Some(listener) = event_listener {
listener.on_add_record(&(&key, v))?;
}
inner_offset += 1 + Self::V_SIZE;
}
ValueEx::Unsynced(v) => {
block_buf.push(RecordFlag::Unsynced as u8);
block_buf.extend_from_slice(v.as_bytes());
if let Some(listener) = event_listener {
listener.on_add_record(&(&key, v))?;
}
inner_offset += 1 + Self::V_SIZE;
}
ValueEx::SyncedAndUnsynced(sv, usv) => {
block_buf.push(RecordFlag::SyncedAndUnsynced as u8);
block_buf.extend_from_slice(sv.as_bytes());
block_buf.extend_from_slice(usv.as_bytes());
if let Some(listener) = event_listener {
listener.on_add_record(&(&key, sv))?;
listener.on_add_record(&(&key, usv))?;
}
inner_offset += Self::MAX_RECORD_SIZE;
}
}
let cap_remained = RECORD_BLOCK_SIZE - inner_offset;
if cap_remained >= Self::MAX_RECORD_SIZE {
continue;
}
let index_entry = IndexEntry {
pos,
first: first_k.unwrap(),
last: key,
};
build_one_record_block(&index_entry, &mut block_buf, tx_log, cache)?;
index_vec.push(index_entry);
pos += RECORD_BLOCK_NBLOCKS;
inner_offset = 0;
block_buf.clear();
}
debug_assert!(total_records > 0);
if !block_buf.is_empty() {
let last_entry = IndexEntry {
pos,
first: first_k.unwrap(),
last: curr_k.unwrap(),
};
build_one_record_block(&last_entry, &mut block_buf, tx_log, cache)?;
index_vec.push(last_entry);
}
fn build_one_record_block<K: RecordKey<K>, D: BlockSet + 'static>(
entry: &IndexEntry<K>,
buf: &mut Vec<u8>,
tx_log: &TxLog<D>,
cache: &mut LruCache<BlockId, Arc<RecordBlock>>,
) -> Result<()> {
buf.resize(RECORD_BLOCK_SIZE, 0);
let record_block = RecordBlock::from_buf(buf.clone());
tx_log.append(BufRef::try_from(record_block.as_slice()).unwrap())?;
cache.put(entry.pos, Arc::new(record_block));
Ok(())
}
Ok((total_records, index_vec))
}
/// Builds the footer from the given index entries. The footer block will be appended
/// to the SST log's end.
fn build_footer<'a, D: BlockSet + 'static>(
index_vec: Vec<IndexEntry<K>>,
total_records: usize,
sync_id: SyncId,
tx_log: &'a TxLog<D>,
) -> Result<Footer<K>>
where
Self: 'a,
{
let footer_buf_len = align_up(
index_vec.len() * Self::INDEX_ENTRY_SIZE + FOOTER_META_SIZE,
BLOCK_SIZE,
);
let mut append_buf = Vec::with_capacity(footer_buf_len);
for entry in &index_vec {
append_buf.extend_from_slice(&entry.pos.to_le_bytes());
append_buf.extend_from_slice(entry.first.as_bytes());
append_buf.extend_from_slice(entry.last.as_bytes());
}
append_buf.resize(footer_buf_len, 0);
let meta = FooterMeta {
num_index: index_vec.len() as _,
index_nblocks: (footer_buf_len / BLOCK_SIZE) as _,
total_records: total_records as _,
record_block_size: RECORD_BLOCK_SIZE as _,
sync_id,
};
append_buf[footer_buf_len - FOOTER_META_SIZE..].copy_from_slice(meta.as_bytes());
tx_log.append(BufRef::try_from(&append_buf[..]).unwrap())?;
Ok(Footer {
meta,
index: index_vec,
})
}
    /// Builds an SST from a `TxLog`, loading the footer and the index blocks.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn from_log<D: BlockSet + 'static>(tx_log: &Arc<TxLog<D>>) -> Result<Self> {
let nblocks = tx_log.nblocks();
let mut rbuf = Buf::alloc(1)?;
// Load footer block (last block)
tx_log.read(nblocks - 1, rbuf.as_mut())?;
let meta = FooterMeta::from_bytes(&rbuf.as_slice()[BLOCK_SIZE - FOOTER_META_SIZE..]);
let mut rbuf = Buf::alloc(meta.index_nblocks as _)?;
tx_log.read(nblocks - meta.index_nblocks as usize, rbuf.as_mut())?;
let mut index = Vec::with_capacity(meta.num_index as _);
let mut cache = LruCache::new(NonZeroUsize::new(Self::CACHE_CAP).unwrap());
let mut record_block = vec![0; RECORD_BLOCK_SIZE];
for i in 0..meta.num_index as _ {
let buf =
&rbuf.as_slice()[i * Self::INDEX_ENTRY_SIZE..(i + 1) * Self::INDEX_ENTRY_SIZE];
let pos = BlockId::from_le_bytes(buf[..BID_SIZE].try_into().unwrap());
let first = K::from_bytes(&buf[BID_SIZE..BID_SIZE + Self::K_SIZE]);
let last =
K::from_bytes(&buf[Self::INDEX_ENTRY_SIZE - Self::K_SIZE..Self::INDEX_ENTRY_SIZE]);
tx_log.read(pos, BufMut::try_from(&mut record_block[..]).unwrap())?;
let _ = cache.put(pos, Arc::new(RecordBlock::from_buf(record_block.clone())));
index.push(IndexEntry { pos, first, last })
}
let footer = Footer { meta, index };
Ok(Self {
id: tx_log.id(),
footer,
cache: Mutex::new(cache),
phantom: PhantomData,
})
}
}
impl<K: RecordKey<K>> IndexEntry<K> {
pub fn range(&self) -> RangeInclusive<K> {
self.first..=self.last
}
pub fn is_within_range(&self, key: &K) -> bool {
self.range().contains(key)
}
pub fn overlap_with(&self, rhs_range: &RangeInclusive<K>) -> bool {
let lhs_range = self.range();
!(lhs_range.end() < rhs_range.start() || lhs_range.start() > rhs_range.end())
}
}
impl RecordBlock {
pub fn from_buf(buf: Vec<u8>) -> Self {
debug_assert_eq!(buf.len(), RECORD_BLOCK_SIZE);
Self { buf }
}
pub fn as_slice(&self) -> &[u8] {
&self.buf
}
pub fn as_mut_slice(&mut self) -> &mut [u8] {
&mut self.buf
}
}
impl<K: RecordKey<K>> QueryAccessor<K> {
pub fn hit_target(&self, target: &K) -> bool {
match self {
QueryAccessor::Point(k) => k == target,
QueryAccessor::Range(range) => range.contains(target),
}
}
}
impl<K: RecordKey<K>, V: RecordValue> Iterator for BlockQueryIter<'_, K, V> {
type Item = (K, Option<V>);
fn next(&mut self) -> Option<Self::Item> {
let mut offset = self.offset;
let buf_slice = &self.block.buf;
let (k_size, v_size) = (SSTable::<K, V>::K_SIZE, SSTable::<K, V>::V_SIZE);
if offset + SSTable::<K, V>::MIN_RECORD_SIZE > RECORD_BLOCK_SIZE {
return None;
}
let key = K::from_bytes(&buf_slice[offset..offset + k_size]);
offset += k_size;
let flag = RecordFlag::from(buf_slice[offset]);
offset += 1;
if flag == RecordFlag::Invalid {
return None;
}
let hit_target = self.accessor.hit_target(&key);
let value_opt = match flag {
RecordFlag::Synced | RecordFlag::Unsynced => {
let v_opt = if hit_target {
Some(V::from_bytes(&buf_slice[offset..offset + v_size]))
} else {
None
};
offset += v_size;
v_opt
}
RecordFlag::SyncedAndUnsynced => {
let v_opt = if hit_target {
Some(V::from_bytes(
&buf_slice[offset + v_size..offset + 2 * v_size],
))
} else {
None
};
offset += 2 * v_size;
v_opt
}
_ => unreachable!(),
};
self.offset = offset;
Some((key, value_opt))
}
}
impl<K: RecordKey<K>, V: RecordValue> Iterator for BlockScanIter<'_, K, V> {
type Item = (K, ValueEx<V>);
fn next(&mut self) -> Option<Self::Item> {
let mut offset = self.offset;
let buf_slice = &self.block.buf;
let (k_size, v_size) = (SSTable::<K, V>::K_SIZE, SSTable::<K, V>::V_SIZE);
let (all_synced, discard_unsynced, event_listener) = (
self.accessor.all_synced,
self.accessor.discard_unsynced,
&self.accessor.event_listener,
);
let (key, value_ex) = loop {
if offset + SSTable::<K, V>::MIN_RECORD_SIZE > RECORD_BLOCK_SIZE {
return None;
}
let key = K::from_bytes(&buf_slice[offset..offset + k_size]);
offset += k_size;
let flag = RecordFlag::from(buf_slice[offset]);
offset += 1;
if flag == RecordFlag::Invalid {
return None;
}
let v_ex = match flag {
RecordFlag::Synced => {
let v = V::from_bytes(&buf_slice[offset..offset + v_size]);
offset += v_size;
ValueEx::Synced(v)
}
RecordFlag::Unsynced => {
let v = V::from_bytes(&buf_slice[offset..offset + v_size]);
offset += v_size;
if all_synced {
ValueEx::Synced(v)
} else if discard_unsynced {
if let Some(listener) = event_listener {
listener.on_drop_record(&(key, v)).unwrap();
}
continue;
} else {
ValueEx::Unsynced(v)
}
}
RecordFlag::SyncedAndUnsynced => {
let sv = V::from_bytes(&buf_slice[offset..offset + v_size]);
offset += v_size;
let usv = V::from_bytes(&buf_slice[offset..offset + v_size]);
offset += v_size;
if all_synced {
if let Some(listener) = event_listener {
listener.on_drop_record(&(key, sv)).unwrap();
}
ValueEx::Synced(usv)
} else if discard_unsynced {
if let Some(listener) = event_listener {
listener.on_drop_record(&(key, usv)).unwrap();
}
ValueEx::Synced(sv)
} else {
ValueEx::SyncedAndUnsynced(sv, usv)
}
}
_ => unreachable!(),
};
break (key, v_ex);
};
self.offset = offset;
Some((key, value_ex))
}
}
impl<K: RecordKey<K>, V: RecordValue, D: BlockSet + 'static> Iterator for SstIter<'_, K, V, D> {
type Item = (K, ValueEx<V>);
fn next(&mut self) -> Option<Self::Item> {
// Iterate over the current record block first
if let Some(next) = self.curr_rb_iter.as_mut().unwrap().next() {
return Some(next);
}
let curr_rb_iter = self.curr_rb_iter.take().unwrap();
self.curr_nth_index += 1;
// Iteration goes to the end
if self.curr_nth_index >= self.sst.footer.meta.num_index as _ {
return None;
}
// Ready to iterate the next record block
let next_pos = self.sst.footer.index[self.curr_nth_index].pos;
let next_rb = self
.sst
.target_record_block(next_pos, self.tx_log_store)
.unwrap();
let mut next_rb_iter = BlockScanIter {
block: next_rb,
offset: 0,
accessor: curr_rb_iter.accessor,
};
let next = next_rb_iter.next()?;
let _ = self.curr_rb_iter.insert(next_rb_iter);
Some(next)
}
}
impl<K: Debug, V> Debug for SSTable<K, V> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SSTable")
.field("id", &self.id)
.field("footer", &self.footer.meta)
.field(
"range",
&RangeInclusive::new(
&self.footer.index[0].first,
&self.footer.index[self.footer.meta.num_index as usize - 1].last,
),
)
.finish()
}
}
/// Flag byte for records in an SSTable.
#[derive(PartialEq, Eq, Debug)]
#[repr(u8)]
enum RecordFlag {
Synced = 7,
Unsynced = 11,
SyncedAndUnsynced = 19,
Invalid,
}
impl From<u8> for RecordFlag {
fn from(value: u8) -> Self {
match value {
7 => RecordFlag::Synced,
11 => RecordFlag::Unsynced,
19 => RecordFlag::SyncedAndUnsynced,
_ => RecordFlag::Invalid,
}
}
}
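// An illustrative sketch (added for exposition): the flag byte stored after
// each key round-trips through `From<u8>`, and any unknown byte decodes as
// `Invalid`, which terminates record-block iteration.
#[cfg(test)]
mod record_flag_sketch {
    use super::RecordFlag;

    #[test]
    fn flag_round_trip() {
        assert_eq!(RecordFlag::from(RecordFlag::Synced as u8), RecordFlag::Synced);
        assert_eq!(
            RecordFlag::from(RecordFlag::SyncedAndUnsynced as u8),
            RecordFlag::SyncedAndUnsynced
        );
        // Any byte that is not a known flag decodes as `Invalid`
        assert_eq!(RecordFlag::from(0u8), RecordFlag::Invalid);
    }
}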

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,279 @@
// SPDX-License-Identifier: MPL-2.0
//! Transactions in the Write-Ahead Log (WAL).
use alloc::vec;
use core::{fmt::Debug, mem::size_of};
use ostd_pod::Pod;
use super::{AsKV, SyncId};
use crate::{
layers::{
bio::{BlockId, BlockSet, Buf, BufRef},
log::{TxLog, TxLogId, TxLogStore},
},
os::Mutex,
prelude::*,
tx::CurrentTx,
};
/// The bucket name of WAL.
pub(super) const BUCKET_WAL: &str = "WAL";
/// WAL append TX in `TxLsmTree`.
///
/// A `WalAppendTx` is used to append records to, sync, and discard WALs.
/// A WAL stores and manages key-value records which are going to be
/// put into the `MemTable`. Its space is backed by a `TxLog` (L3).
#[derive(Clone)]
pub(super) struct WalAppendTx<D> {
inner: Arc<Mutex<WalTxInner<D>>>,
}
struct WalTxInner<D> {
/// The appended WAL of ongoing Tx.
appended_log: Option<Arc<TxLog<D>>>,
/// Current log ID of WAL for later use.
log_id: Option<TxLogId>,
/// Store current sync ID as the first record of WAL.
sync_id: SyncId,
/// A buffer to cache appended records.
record_buf: Vec<u8>,
/// Store for WALs.
tx_log_store: Arc<TxLogStore<D>>,
}
impl<D: BlockSet + 'static> WalAppendTx<D> {
const BUF_CAP: usize = 1024 * BLOCK_SIZE;
/// Prepare a new WAL TX.
pub fn new(store: &Arc<TxLogStore<D>>, sync_id: SyncId) -> Self {
Self {
inner: Arc::new(Mutex::new(WalTxInner {
appended_log: None,
log_id: None,
sync_id,
record_buf: Vec::with_capacity(Self::BUF_CAP),
tx_log_store: store.clone(),
})),
}
}
    /// Append phase for an Append TX, mainly to append new records to the WAL.
pub fn append<K: Pod, V: Pod>(&self, record: &dyn AsKV<K, V>) -> Result<()> {
let mut inner = self.inner.lock();
if inner.appended_log.is_none() {
inner.prepare()?;
}
{
let record_buf = &mut inner.record_buf;
record_buf.push(WalAppendFlag::Record as u8);
record_buf.extend_from_slice(record.key().as_bytes());
record_buf.extend_from_slice(record.value().as_bytes());
}
const MAX_RECORD_SIZE: usize = 49;
if inner.record_buf.len() <= Self::BUF_CAP - MAX_RECORD_SIZE {
return Ok(());
}
inner.align_record_buf();
let wal_tx = inner.tx_log_store.current_tx();
let wal_log = inner.appended_log.as_ref().unwrap();
self.flush_buf(&inner.record_buf, &wal_tx, wal_log)?;
inner.record_buf.clear();
Ok(())
}
    /// Commit phase for an Append TX, mainly to commit (or abort) the TX.
    /// After committing, the WAL is sealed. Returns the corresponding log ID.
///
/// # Panics
///
/// This method panics if current WAL's TX does not exist.
pub fn commit(&self) -> Result<TxLogId> {
let mut inner = self.inner.lock();
let wal_log = inner
.appended_log
.take()
.expect("current WAL TX must exist");
let wal_id = inner.log_id.take().unwrap();
debug_assert_eq!(wal_id, wal_log.id());
if !inner.record_buf.is_empty() {
inner.align_record_buf();
let wal_tx = inner.tx_log_store.current_tx();
self.flush_buf(&inner.record_buf, &wal_tx, &wal_log)?;
inner.record_buf.clear();
}
drop(wal_log);
inner.tx_log_store.current_tx().commit()?;
Ok(wal_id)
}
    /// Appends the current sync ID to the WAL, then commits the TX to ensure
    /// the WAL's persistency. Saves the log ID for later appending.
pub fn sync(&self, sync_id: SyncId) -> Result<()> {
let mut inner = self.inner.lock();
if inner.appended_log.is_none() {
inner.prepare()?;
}
inner.record_buf.push(WalAppendFlag::Sync as u8);
inner.record_buf.extend_from_slice(&sync_id.to_le_bytes());
inner.sync_id = sync_id;
inner.align_record_buf();
let wal_log = inner.appended_log.take().unwrap();
self.flush_buf(
&inner.record_buf,
&inner.tx_log_store.current_tx(),
&wal_log,
)?;
inner.record_buf.clear();
drop(wal_log);
inner.tx_log_store.current_tx().commit()
}
/// Flushes the buffer to the backed log.
fn flush_buf(
&self,
record_buf: &[u8],
wal_tx: &CurrentTx<'_>,
log: &Arc<TxLog<D>>,
) -> Result<()> {
debug_assert!(!record_buf.is_empty() && record_buf.len() % BLOCK_SIZE == 0);
let res = wal_tx.context(|| {
let buf = BufRef::try_from(record_buf).unwrap();
log.append(buf)
});
if res.is_err() {
wal_tx.abort();
}
res
}
/// Collects the synced records only and the maximum sync ID in the WAL.
pub fn collect_synced_records_and_sync_id<K: Pod, V: Pod>(
wal: &TxLog<D>,
) -> Result<(Vec<(K, V)>, SyncId)> {
let nblocks = wal.nblocks();
let mut records = Vec::new();
// TODO: Allocate separate buffers for large WAL
let mut buf = Buf::alloc(nblocks)?;
wal.read(0 as BlockId, buf.as_mut())?;
let buf_slice = buf.as_slice();
let k_size = size_of::<K>();
let v_size = size_of::<V>();
let total_bytes = nblocks * BLOCK_SIZE;
let mut offset = 0;
let (mut max_sync_id, mut synced_len) = (None, 0);
loop {
const MIN_RECORD_SIZE: usize = 9;
if offset > total_bytes - MIN_RECORD_SIZE {
break;
}
let flag = WalAppendFlag::try_from(buf_slice[offset]);
offset += 1;
if flag.is_err() {
continue;
}
match flag.unwrap() {
WalAppendFlag::Record => {
let record = {
let k = K::from_bytes(&buf_slice[offset..offset + k_size]);
let v =
V::from_bytes(&buf_slice[offset + k_size..offset + k_size + v_size]);
offset += k_size + v_size;
(k, v)
};
records.push(record);
}
WalAppendFlag::Sync => {
let sync_id = SyncId::from_le_bytes(
buf_slice[offset..offset + size_of::<SyncId>()]
.try_into()
.unwrap(),
);
offset += size_of::<SyncId>();
let _ = max_sync_id.insert(sync_id);
synced_len = records.len();
}
}
}
if let Some(max_sync_id) = max_sync_id {
records.truncate(synced_len);
Ok((records, max_sync_id))
} else {
Ok((vec![], 0))
}
}
}
impl<D: BlockSet + 'static> WalTxInner<D> {
    /// Prepare phase for an Append TX, mainly to create a new TX and WAL.
pub fn prepare(&mut self) -> Result<()> {
debug_assert!(self.appended_log.is_none());
let appended_log = {
let store = &self.tx_log_store;
let wal_tx = store.new_tx();
let log_id_opt = self.log_id;
let res = wal_tx.context(|| {
if let Some(log_id) = log_id_opt {
store.open_log(log_id, true)
} else {
store.create_log(BUCKET_WAL)
}
});
if res.is_err() {
wal_tx.abort();
}
let wal_log = res?;
let _ = self.log_id.insert(wal_log.id());
wal_log
};
let _ = self.appended_log.insert(appended_log);
// Record the sync ID at the beginning of the WAL
debug_assert!(self.record_buf.is_empty());
self.record_buf.push(WalAppendFlag::Sync as u8);
self.record_buf
.extend_from_slice(&self.sync_id.to_le_bytes());
Ok(())
}
fn align_record_buf(&mut self) {
let aligned_len = align_up(self.record_buf.len(), BLOCK_SIZE);
self.record_buf.resize(aligned_len, 0);
}
}
/// Two content kinds in a WAL.
#[derive(PartialEq, Eq, Debug)]
#[repr(u8)]
enum WalAppendFlag {
Record = 13,
Sync = 23,
}
impl TryFrom<u8> for WalAppendFlag {
type Error = Error;
fn try_from(value: u8) -> Result<Self> {
match value {
13 => Ok(WalAppendFlag::Record),
23 => Ok(WalAppendFlag::Sync),
_ => Err(Error::new(InvalidArgs)),
}
}
}
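// An illustrative sketch (added for exposition) of the WAL record encoding
// used by `append()` above: a 1-byte flag followed by the raw key and value
// bytes. The `u64` key/value types are assumptions for the example.
#[cfg(test)]
mod wal_record_layout_sketch {
    use super::*;

    #[test]
    fn record_encoding() {
        let (key, value) = (1u64, 2u64);
        let mut buf = Vec::new();
        buf.push(WalAppendFlag::Record as u8);
        buf.extend_from_slice(&key.to_le_bytes());
        buf.extend_from_slice(&value.to_le_bytes());
        // 1-byte flag + 8-byte key + 8-byte value
        assert_eq!(buf.len(), 17);
        assert_eq!(WalAppendFlag::try_from(buf[0]).unwrap(), WalAppendFlag::Record);
        // Unknown flag bytes are rejected, so stale bytes in a
        // block-aligned buffer are skipped during recovery
        assert!(WalAppendFlag::try_from(0u8).is_err());
    }
}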

View File

@ -0,0 +1,291 @@
// SPDX-License-Identifier: MPL-2.0
//! Block I/O (BIO).
use alloc::collections::VecDeque;
use core::{
any::{Any, TypeId},
ptr::NonNull,
sync::atomic::{AtomicUsize, Ordering},
};
use hashbrown::HashMap;
use crate::{
os::{Mutex, MutexGuard},
prelude::*,
Buf,
};
/// A queue for managing block I/O requests (`BioReq`).
/// It provides a concurrency-safe way to store and manage
/// block I/O requests that need to be processed by a block device.
pub struct BioReqQueue {
queue: Mutex<VecDeque<BioReq>>,
num_reqs: AtomicUsize,
}
impl BioReqQueue {
/// Create a new `BioReqQueue` instance.
pub fn new() -> Self {
Self {
queue: Mutex::new(VecDeque::new()),
num_reqs: AtomicUsize::new(0),
}
}
/// Enqueue a block I/O request.
pub fn enqueue(&self, req: BioReq) -> Result<()> {
req.submit();
self.queue.lock().push_back(req);
self.num_reqs.fetch_add(1, Ordering::Release);
Ok(())
}
/// Dequeue a block I/O request.
pub fn dequeue(&self) -> Option<BioReq> {
if let Some(req) = self.queue.lock().pop_front() {
self.num_reqs.fetch_sub(1, Ordering::Release);
Some(req)
} else {
debug_assert_eq!(self.num_reqs.load(Ordering::Acquire), 0);
None
}
}
/// Returns the number of pending requests in this queue.
pub fn num_reqs(&self) -> usize {
self.num_reqs.load(Ordering::Acquire)
}
/// Returns whether there are no pending requests in this queue.
pub fn is_empty(&self) -> bool {
self.num_reqs() == 0
}
}
/// A block I/O request.
pub struct BioReq {
type_: BioType,
addr: BlockId,
nblocks: u32,
bufs: Mutex<Vec<Buf>>,
status: Mutex<BioStatus>,
on_complete: Option<BioReqOnCompleteFn>,
ext: Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}
/// The type of a block request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BioType {
/// A read request.
Read,
/// A write request.
Write,
/// A sync request.
Sync,
}
/// A response from a block device.
pub type BioResp = Result<()>;
/// The type of the callback function invoked upon the completion of
/// a block I/O request.
pub type BioReqOnCompleteFn = fn(/* req = */ &BioReq, /* resp = */ &BioResp);
/// The status describing a block I/O request.
#[derive(Clone, Debug)]
enum BioStatus {
Init,
Submitted,
Completed(BioResp),
}
impl BioReq {
/// Returns the type of the request.
pub fn type_(&self) -> BioType {
self.type_
}
/// Returns the starting address of requested blocks.
///
/// The return value is meaningless if the request is not a read or write.
pub fn addr(&self) -> BlockId {
self.addr
}
/// Access the immutable buffers with a closure.
pub fn access_bufs_with<F, R>(&self, mut f: F) -> R
where
F: FnMut(&[Buf]) -> R,
{
let bufs = self.bufs.lock();
(f)(&bufs)
}
/// Access the mutable buffers with a closure.
pub(super) fn access_mut_bufs_with<F, R>(&self, mut f: F) -> R
where
F: FnMut(&mut [Buf]) -> R,
{
let mut bufs = self.bufs.lock();
(f)(&mut bufs)
}
/// Take the buffers out of the request.
pub(super) fn take_bufs(&self) -> Vec<Buf> {
let mut bufs = self.bufs.lock();
let mut ret_bufs = Vec::new();
core::mem::swap(&mut *bufs, &mut ret_bufs);
ret_bufs
}
/// Returns the number of buffers associated with the request.
///
    /// If the request is a sync, then the returned value is meaningless.
pub fn nbufs(&self) -> usize {
self.bufs.lock().len()
}
/// Returns the number of blocks to read or write by this request.
///
    /// If the request is a sync, then the returned value is meaningless.
pub fn nblocks(&self) -> usize {
self.nblocks as usize
}
/// Returns the extensions of the request.
///
    /// The extensions of a request are a set of objects that may be added, removed,
/// or accessed by block devices and their users. Each of the extension objects
/// must have a different type. To avoid conflicts, it is recommended to use only
/// private types for the extension objects.
pub fn ext(&self) -> MutexGuard<HashMap<TypeId, Box<dyn Any + Send + Sync>>> {
self.ext.lock()
}
/// Update the status of the request to "completed" by giving the response
/// to the request.
///
    /// After invoking this API, the request is considered completed, which
/// means the request must have taken effect. For example, a completed read
/// request must have all its buffers filled with data.
///
/// # Panics
///
/// If the request has not been submitted yet, or has been completed already,
/// this method will panic.
pub(super) fn complete(&self, resp: BioResp) {
let mut status = self.status.lock();
match *status {
BioStatus::Submitted => {
if let Some(on_complete) = self.on_complete {
(on_complete)(self, &resp);
}
*status = BioStatus::Completed(resp);
}
_ => panic!("cannot complete before submitting or complete twice"),
}
}
/// Mark the request as submitted.
pub(super) fn submit(&self) {
let mut status = self.status.lock();
match *status {
BioStatus::Init => *status = BioStatus::Submitted,
_ => unreachable!(),
}
}
}
/// A builder for `BioReq`.
pub struct BioReqBuilder {
type_: BioType,
addr: Option<BlockId>,
bufs: Option<Vec<Buf>>,
on_complete: Option<BioReqOnCompleteFn>,
ext: Option<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}
impl BioReqBuilder {
/// Creates a builder of a block request of the given type.
pub fn new(type_: BioType) -> Self {
Self {
type_,
addr: None,
bufs: None,
on_complete: None,
ext: None,
}
}
/// Specify the block address of the request.
pub fn addr(mut self, addr: BlockId) -> Self {
self.addr = Some(addr);
self
}
/// Give the buffers of the request.
pub fn bufs(mut self, bufs: Vec<Buf>) -> Self {
self.bufs = Some(bufs);
self
}
/// Specify a callback invoked when the request is complete.
pub fn on_complete(mut self, on_complete: BioReqOnCompleteFn) -> Self {
self.on_complete = Some(on_complete);
self
}
/// Add an extension object to the request.
pub fn ext<T: Any + Send + Sync + Sized>(mut self, obj: T) -> Self {
if self.ext.is_none() {
self.ext = Some(HashMap::new());
}
let _ = self
.ext
.as_mut()
.unwrap()
.insert(TypeId::of::<T>(), Box::new(obj));
self
}
/// Build the request.
pub fn build(mut self) -> BioReq {
let type_ = self.type_;
if type_ == BioType::Sync {
debug_assert!(
self.addr.is_none(),
"addr is only meaningful for a read or write",
);
debug_assert!(
self.bufs.is_none(),
"bufs is only meaningful for a read or write",
);
}
let addr = self.addr.unwrap_or(0 as BlockId);
let bufs = self.bufs.take().unwrap_or_default();
let nblocks = {
let nbytes = bufs
.iter()
.map(|buf| buf.as_slice().len())
.fold(0_usize, |sum, len| sum.saturating_add(len));
(nbytes / BLOCK_SIZE) as u32
};
let ext = self.ext.take().unwrap_or_default();
let on_complete = self.on_complete.take();
BioReq {
type_,
addr,
nblocks,
bufs: Mutex::new(bufs),
status: Mutex::new(BioStatus::Init),
on_complete,
ext: Mutex::new(ext),
}
}
}
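// A minimal usage sketch of `BioReqBuilder` (illustrative only; assumes a
// device `dev` exposing `submit_bio_sync`, like `SwornDisk` in this layer):
//
// ```
// let mut buf = Buf::alloc(1)?;
// buf.as_mut_slice().fill(0xab);
// let req = BioReqBuilder::new(BioType::Write)
// .addr(0 as BlockId)
// .bufs(vec![buf])
// .build();
// dev.submit_bio_sync(req)?;
// ```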

View File

@ -0,0 +1,403 @@
// SPDX-License-Identifier: MPL-2.0
//! Block allocation.
use alloc::vec;
use core::{
mem::size_of,
num::NonZeroUsize,
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};
use ostd_pod::Pod;
use serde::{Deserialize, Serialize};
use super::sworndisk::Hba;
use crate::{
layers::{
bio::{BlockSet, Buf, BufRef, BID_SIZE},
log::{TxLog, TxLogStore},
},
os::{BTreeMap, Condvar, CvarMutex, Mutex},
prelude::*,
util::BitMap,
};
/// The bucket name of block validity table.
const BUCKET_BLOCK_VALIDITY_TABLE: &str = "BVT";
/// The bucket name of block alloc/dealloc log.
const BUCKET_BLOCK_ALLOC_LOG: &str = "BAL";
/// Block validity table. Global allocator for `SwornDisk`,
/// which manages validities of user data blocks.
pub(super) struct AllocTable {
bitmap: Mutex<BitMap>,
next_avail: AtomicUsize,
nblocks: NonZeroUsize,
is_dirty: AtomicBool,
cvar: Condvar,
num_free: CvarMutex<usize>,
}
/// Per-TX block allocator in `SwornDisk`, recording validities
/// of user data blocks within each TX. All metadata will be stored in
/// `TxLog`s of bucket `BAL` during TX for durability and recovery purpose.
pub(super) struct BlockAlloc<D> {
alloc_table: Arc<AllocTable>, // Point to the global allocator
diff_table: Mutex<BTreeMap<Hba, AllocDiff>>, // Per-TX diffs of block validity
store: Arc<TxLogStore<D>>, // Store for diff log from L3
diff_log: Mutex<Option<Arc<TxLog<D>>>>, // Opened diff log (currently not in-use)
}
/// Incremental diff of block validity.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
enum AllocDiff {
Alloc = 3,
Dealloc = 7,
Invalid,
}
const DIFF_RECORD_SIZE: usize = size_of::<AllocDiff>() + size_of::<Hba>();
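// Each persisted diff record is a 1-byte `AllocDiff` tag followed by the raw
// `Pod` bytes of the `Hba` (8 bytes on 64-bit targets), so `DIFF_RECORD_SIZE`
// is 9. A hedged encoding sketch mirroring `update_diff_log` below:
//
// ```
// let mut record = Vec::with_capacity(DIFF_RECORD_SIZE);
// record.push(AllocDiff::Alloc as u8);
// record.extend_from_slice((42 as Hba).as_bytes());
// assert_eq!(record.len(), DIFF_RECORD_SIZE);
// ```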
impl AllocTable {
/// Create a new `AllocTable` given the total number of blocks.
pub fn new(nblocks: NonZeroUsize) -> Self {
Self {
bitmap: Mutex::new(BitMap::repeat(true, nblocks.get())),
next_avail: AtomicUsize::new(0),
nblocks,
is_dirty: AtomicBool::new(false),
cvar: Condvar::new(),
num_free: CvarMutex::new(nblocks.get()),
}
}
/// Allocate a free slot for a new block, returning `None`
/// if no free slot is available.
pub fn alloc(&self) -> Option<Hba> {
let mut bitmap = self.bitmap.lock();
let next_avail = self.next_avail.load(Ordering::Acquire);
let hba = if let Some(hba) = bitmap.first_one(next_avail) {
hba
} else {
bitmap.first_one(0)?
};
bitmap.set(hba, false);
self.next_avail.store(hba + 1, Ordering::Release);
Some(hba as Hba)
}
/// Allocate multiple free slots for a batch of new blocks, blocking
/// until enough free slots become available.
pub fn alloc_batch(&self, count: NonZeroUsize) -> Result<Vec<Hba>> {
let cnt = count.get();
let mut num_free = self.num_free.lock().unwrap();
while *num_free < cnt {
// TODO: May not be woken, may require manual triggering of a compaction in L4
num_free = self.cvar.wait(num_free).unwrap();
}
debug_assert!(*num_free >= cnt);
let hbas = self.do_alloc_batch(count).unwrap();
debug_assert_eq!(hbas.len(), cnt);
*num_free -= cnt;
let _ = self
.is_dirty
.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed);
Ok(hbas)
}
fn do_alloc_batch(&self, count: NonZeroUsize) -> Option<Vec<Hba>> {
let count = count.get();
debug_assert!(count > 0);
let mut bitmap = self.bitmap.lock();
let mut next_avail = self.next_avail.load(Ordering::Acquire);
if next_avail + count > self.nblocks.get() {
next_avail = bitmap.first_one(0)?;
}
let hbas = if let Some(hbas) = bitmap.first_ones(next_avail, count) {
hbas
} else {
next_avail = bitmap.first_one(0)?;
bitmap.first_ones(next_avail, count)?
};
hbas.iter().for_each(|hba| bitmap.set(*hba, false));
next_avail = hbas.last().unwrap() + 1;
self.next_avail.store(next_avail, Ordering::Release);
Some(hbas)
}
/// Recover the `AllocTable` from the latest `BVT` log and the set of `BAL` logs
/// in the given store.
pub fn recover<D: BlockSet + 'static>(
nblocks: NonZeroUsize,
store: &Arc<TxLogStore<D>>,
) -> Result<Self> {
let tx = store.new_tx();
let res: Result<_> = tx.context(|| {
// Recover the block validity table from `BVT` log first
let bvt_log_res = store.open_log_in(BUCKET_BLOCK_VALIDITY_TABLE);
let mut bitmap = match bvt_log_res {
Ok(bvt_log) => {
let mut buf = Buf::alloc(bvt_log.nblocks())?;
bvt_log.read(0 as BlockId, buf.as_mut())?;
postcard::from_bytes(buf.as_slice()).map_err(|_| {
Error::with_msg(InvalidArgs, "deserialize block validity table failed")
})?
}
Err(e) => {
if e.errno() != NotFound {
return Err(e);
}
BitMap::repeat(true, nblocks.get())
}
};
// Iterate each `BAL` log and apply each diff, from older to newer
let bal_log_ids_res = store.list_logs_in(BUCKET_BLOCK_ALLOC_LOG);
if let Err(e) = &bal_log_ids_res
&& e.errno() == NotFound
{
let next_avail = bitmap.first_one(0).unwrap_or(0);
let num_free = bitmap.count_ones();
return Ok(Self {
bitmap: Mutex::new(bitmap),
next_avail: AtomicUsize::new(next_avail),
nblocks,
is_dirty: AtomicBool::new(false),
cvar: Condvar::new(),
num_free: CvarMutex::new(num_free),
});
}
let mut bal_log_ids = bal_log_ids_res?;
bal_log_ids.sort();
for bal_log_id in bal_log_ids {
let bal_log_res = store.open_log(bal_log_id, false);
if let Err(e) = &bal_log_res
&& e.errno() == NotFound
{
continue;
}
let bal_log = bal_log_res?;
let log_nblocks = bal_log.nblocks();
let mut buf = Buf::alloc(log_nblocks)?;
bal_log.read(0 as BlockId, buf.as_mut())?;
let buf_slice = buf.as_slice();
let mut offset = 0;
while offset <= log_nblocks * BLOCK_SIZE - DIFF_RECORD_SIZE {
let diff = AllocDiff::from(buf_slice[offset]);
offset += 1;
if diff == AllocDiff::Invalid {
continue;
}
let bid = BlockId::from_bytes(&buf_slice[offset..offset + BID_SIZE]);
offset += BID_SIZE;
match diff {
AllocDiff::Alloc => bitmap.set(bid, false),
AllocDiff::Dealloc => bitmap.set(bid, true),
_ => unreachable!(),
}
}
}
let next_avail = bitmap.first_one(0).unwrap_or(0);
let num_free = bitmap.count_ones();
Ok(Self {
bitmap: Mutex::new(bitmap),
next_avail: AtomicUsize::new(next_avail),
nblocks,
is_dirty: AtomicBool::new(false),
cvar: Condvar::new(),
num_free: CvarMutex::new(num_free),
})
});
let recov_self = res.map_err(|_| {
tx.abort();
Error::with_msg(TxAborted, "recover block validity table TX aborted")
})?;
tx.commit()?;
Ok(recov_self)
}
/// Persist the block validity table to a new `BVT` log. GC all existing
/// `BVT` and `BAL` logs.
pub fn do_compaction<D: BlockSet + 'static>(&self, store: &Arc<TxLogStore<D>>) -> Result<()> {
if !self.is_dirty.load(Ordering::Relaxed) {
return Ok(());
}
// Serialize the block validity table
let bitmap = self.bitmap.lock();
const BITMAP_MAX_SIZE: usize = 1792 * BLOCK_SIZE; // TBD
let mut ser_buf = vec![0; BITMAP_MAX_SIZE];
let ser_len = postcard::to_slice::<BitMap>(&bitmap, &mut ser_buf)
.map_err(|_| Error::with_msg(InvalidArgs, "serialize block validity table failed"))?
.len();
ser_buf.resize(align_up(ser_len, BLOCK_SIZE), 0);
drop(bitmap);
// Persist the serialized block validity table to `BVT` log
// and GC any old `BVT` logs and `BAL` logs
let tx = store.new_tx();
let res: Result<_> = tx.context(|| {
if let Ok(bvt_log_ids) = store.list_logs_in(BUCKET_BLOCK_VALIDITY_TABLE) {
for bvt_log_id in bvt_log_ids {
store.delete_log(bvt_log_id)?;
}
}
let bvt_log = store.create_log(BUCKET_BLOCK_VALIDITY_TABLE)?;
bvt_log.append(BufRef::try_from(&ser_buf[..]).unwrap())?;
if let Ok(bal_log_ids) = store.list_logs_in(BUCKET_BLOCK_ALLOC_LOG) {
for bal_log_id in bal_log_ids {
store.delete_log(bal_log_id)?;
}
}
Ok(())
});
if res.is_err() {
tx.abort();
return_errno_with_msg!(TxAborted, "persist block validity table TX aborted");
}
tx.commit()?;
self.is_dirty.store(false, Ordering::Relaxed);
Ok(())
}
/// Mark a specific slot deallocated.
pub fn set_deallocated(&self, nth: usize) {
let mut num_free = self.num_free.lock().unwrap();
self.bitmap.lock().set(nth, true);
*num_free += 1;
const AVG_ALLOC_COUNT: usize = 1024;
if *num_free >= AVG_ALLOC_COUNT {
self.cvar.notify_one();
}
}
}
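// A hedged sketch of the allocator's lifecycle (TX persistence elided):
//
// ```
// let table = AllocTable::new(NonZeroUsize::new(1024).unwrap());
// let hbas = table.alloc_batch(NonZeroUsize::new(8).unwrap())?;
// assert_eq!(hbas.len(), 8);
// // Slots are handed back when records are dropped from the `MemTable`:
// for hba in hbas {
// table.set_deallocated(hba);
// }
// ```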
impl<D: BlockSet + 'static> BlockAlloc<D> {
/// Create a new `BlockAlloc` with the given global allocator and store.
pub fn new(alloc_table: Arc<AllocTable>, store: Arc<TxLogStore<D>>) -> Self {
Self {
alloc_table,
diff_table: Mutex::new(BTreeMap::new()),
store,
diff_log: Mutex::new(None),
}
}
/// Record a diff of `Alloc`.
pub fn alloc_block(&self, block_id: Hba) -> Result<()> {
let mut diff_table = self.diff_table.lock();
let replaced = diff_table.insert(block_id, AllocDiff::Alloc);
debug_assert!(
replaced != Some(AllocDiff::Alloc),
"can't allocate a block twice"
);
Ok(())
}
/// Record a diff of `Dealloc`.
pub fn dealloc_block(&self, block_id: Hba) -> Result<()> {
let mut diff_table = self.diff_table.lock();
let replaced = diff_table.insert(block_id, AllocDiff::Dealloc);
debug_assert!(
replaced != Some(AllocDiff::Dealloc),
"can't deallocate a block twice"
);
Ok(())
}
/// Prepare the block validity diff log.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn prepare_diff_log(&self) -> Result<()> {
// Do nothing for now
Ok(())
}
/// Persist the metadata in the diff table to the block validity diff log.
///
/// # Panics
///
/// This method must be called within a TX. Otherwise, this method panics.
pub fn update_diff_log(&self) -> Result<()> {
let diff_table = self.diff_table.lock();
if diff_table.is_empty() {
return Ok(());
}
let diff_log = self.store.create_log(BUCKET_BLOCK_ALLOC_LOG)?;
const MAX_BUF_SIZE: usize = 1024 * BLOCK_SIZE;
let mut diff_buf = Vec::with_capacity(MAX_BUF_SIZE);
for (block_id, block_diff) in diff_table.iter() {
diff_buf.push(*block_diff as u8);
diff_buf.extend_from_slice(block_id.as_bytes());
if diff_buf.len() + DIFF_RECORD_SIZE > MAX_BUF_SIZE {
diff_buf.resize(align_up(diff_buf.len(), BLOCK_SIZE), 0);
diff_log.append(BufRef::try_from(&diff_buf[..]).unwrap())?;
diff_buf.clear();
}
}
if diff_buf.is_empty() {
return Ok(());
}
diff_buf.resize(align_up(diff_buf.len(), BLOCK_SIZE), 0);
diff_log.append(BufRef::try_from(&diff_buf[..]).unwrap())
}
/// Apply the metadata in the diff table to the in-memory block validity table.
pub fn update_alloc_table(&self) {
let diff_table = self.diff_table.lock();
let alloc_table = &self.alloc_table;
let mut num_free = alloc_table.num_free.lock().unwrap();
let mut bitmap = alloc_table.bitmap.lock();
let mut num_dealloc = 0_usize;
for (block_id, block_diff) in diff_table.iter() {
match block_diff {
AllocDiff::Alloc => {
debug_assert!(!bitmap[*block_id]);
}
AllocDiff::Dealloc => {
debug_assert!(!bitmap[*block_id]);
bitmap.set(*block_id, true);
num_dealloc += 1;
}
AllocDiff::Invalid => unreachable!(),
};
}
*num_free += num_dealloc;
const AVG_ALLOC_COUNT: usize = 1024;
if *num_free >= AVG_ALLOC_COUNT {
alloc_table.cvar.notify_one();
}
}
}
impl From<u8> for AllocDiff {
fn from(value: u8) -> Self {
match value {
3 => AllocDiff::Alloc,
7 => AllocDiff::Dealloc,
_ => AllocDiff::Invalid,
}
}
}

View File

@ -0,0 +1,137 @@
// SPDX-License-Identifier: MPL-2.0
//! Data buffering.
use core::ops::RangeInclusive;
use super::sworndisk::RecordKey;
use crate::{
layers::bio::{BufMut, BufRef},
os::{BTreeMap, Condvar, CvarMutex, Mutex},
prelude::*,
};
/// A buffer to cache data blocks before they are written to disk.
#[derive(Debug)]
pub(super) struct DataBuf {
buf: Mutex<BTreeMap<RecordKey, Arc<DataBlock>>>,
cap: usize,
cvar: Condvar,
is_full: CvarMutex<bool>,
}
/// User data block.
pub(super) struct DataBlock([u8; BLOCK_SIZE]);
impl DataBuf {
/// Create a new empty data buffer with a given capacity.
pub fn new(cap: usize) -> Self {
Self {
buf: Mutex::new(BTreeMap::new()),
cap,
cvar: Condvar::new(),
is_full: CvarMutex::new(false),
}
}
/// Get the buffered data block with the key and copy
/// the content into `buf`.
pub fn get(&self, key: RecordKey, buf: &mut BufMut) -> Option<()> {
debug_assert_eq!(buf.nblocks(), 1);
if let Some(block) = self.buf.lock().get(&key) {
buf.as_mut_slice().copy_from_slice(block.as_slice());
Some(())
} else {
None
}
}
/// Get the buffered data blocks whose keys are within the given range.
pub fn get_range(&self, range: RangeInclusive<RecordKey>) -> Vec<(RecordKey, Arc<DataBlock>)> {
self.buf
.lock()
.iter()
.filter_map(|(k, v)| {
if range.contains(k) {
Some((*k, v.clone()))
} else {
None
}
})
.collect()
}
/// Put the data block in `buf` into the buffer. Return
/// whether the buffer is full after insertion.
pub fn put(&self, key: RecordKey, buf: BufRef) -> bool {
debug_assert_eq!(buf.nblocks(), 1);
let mut is_full = self.is_full.lock().unwrap();
while *is_full {
is_full = self.cvar.wait(is_full).unwrap();
}
debug_assert!(!*is_full);
let mut data_buf = self.buf.lock();
let _ = data_buf.insert(key, DataBlock::from_buf(buf));
if data_buf.len() >= self.cap {
*is_full = true;
}
*is_full
}
/// Return the number of data blocks in the buffer.
pub fn nblocks(&self) -> usize {
self.buf.lock().len()
}
/// Return whether the buffer is full.
pub fn at_capacity(&self) -> bool {
self.nblocks() >= self.cap
}
/// Return whether the buffer is empty.
pub fn is_empty(&self) -> bool {
self.nblocks() == 0
}
/// Empty the buffer.
pub fn clear(&self) {
let mut is_full = self.is_full.lock().unwrap();
self.buf.lock().clear();
if *is_full {
*is_full = false;
self.cvar.notify_all();
}
}
/// Return all the buffered data blocks.
pub fn all_blocks(&self) -> Vec<(RecordKey, Arc<DataBlock>)> {
self.buf
.lock()
.iter()
.map(|(k, v)| (*k, v.clone()))
.collect()
}
}
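// A short put-then-get sketch of `DataBuf` (hedged; assumes single-block
// buffers, as the debug assertions above require):
//
// ```
// let data_buf = DataBuf::new(4);
// let mut block = Buf::alloc(1)?;
// block.as_mut_slice().fill(7);
// let is_full = data_buf.put(RecordKey { lba: 0 }, block.as_ref());
// assert!(!is_full);
// let mut out = Buf::alloc(1)?;
// assert!(data_buf.get(RecordKey { lba: 0 }, &mut out.as_mut()).is_some());
// assert_eq!(out.as_slice()[0], 7);
// ```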
impl DataBlock {
/// Create a new data block from the given `buf`.
pub fn from_buf(buf: BufRef) -> Arc<Self> {
debug_assert_eq!(buf.nblocks(), 1);
Arc::new(DataBlock(buf.as_slice().try_into().unwrap()))
}
/// Return the immutable slice of the data block.
pub fn as_slice(&self) -> &[u8] {
&self.0
}
}
impl Debug for DataBlock {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("DataBlock")
.field("first 16 bytes", &&self.0[..16])
.finish()
}
}

View File

@ -0,0 +1,41 @@
// SPDX-License-Identifier: MPL-2.0
//! The layer of secure virtual disk.
//!
//! `SwornDisk` provides three block I/O interfaces: `read()`, `write()`, and `sync()`.
//! `SwornDisk` protects a logical block of user data using authenticated encryption.
//! The metadata of the encrypted logical blocks are inserted into a secure index `TxLsmTree`.
//!
//! The untrusted host disk space backing `SwornDisk` is managed by `BlockAlloc`. Block reclamation can be
//! delayed via user-defined callbacks on `TxLsmTree`.
//! `SwornDisk` supports buffering written logical blocks.
//!
//! # Usage Example
//!
//! Write, sync then read blocks from `SwornDisk`.
//!
//! ```
//! let nblocks = 1024;
//! let mem_disk = MemDisk::create(nblocks)?;
//! let root_key = Key::random();
//! let sworndisk = SwornDisk::create(mem_disk.clone(), root_key, None)?;
//!
//! let num_rw = 128;
//! let mut rw_buf = Buf::alloc(1)?;
//! for i in 0..num_rw {
//! rw_buf.as_mut_slice().fill(i as u8);
//! sworndisk.write(i as Lba, rw_buf.as_ref())?;
//! }
//! sworndisk.sync()?;
//! for i in 0..num_rw {
//! sworndisk.read(i as Lba, rw_buf.as_mut())?;
//! assert_eq!(rw_buf.as_slice()[0], i as u8);
//! }
//! ```
mod bio;
mod block_alloc;
mod data_buf;
mod sworndisk;
pub use self::sworndisk::SwornDisk;

View File

@ -0,0 +1,881 @@
// SPDX-License-Identifier: MPL-2.0
//! SwornDisk as a block device.
//!
//! API: submit_bio(), submit_bio_sync(), create(), open(),
//! read(), readv(), write(), writev(), sync().
//!
//! Responsible for managing a `TxLsmTree`, where the TX logs (WAL and SSTs)
//! are stored; an untrusted disk storing user data; and a `BlockAlloc` managing the
//! allocation metadata of data blocks. `TxLsmTree` and `BlockAlloc` are manipulated
//! within internal transactions.
use core::{
num::NonZeroUsize,
ops::{Add, Sub},
sync::atomic::{AtomicBool, Ordering},
};
use ostd::mm::VmIo;
use ostd_pod::Pod;
use super::{
bio::{BioReq, BioReqQueue, BioResp, BioType},
block_alloc::{AllocTable, BlockAlloc},
data_buf::DataBuf,
};
use crate::{
layers::{
bio::{BlockId, BlockSet, Buf, BufMut, BufRef},
log::TxLogStore,
lsm::{
AsKV, LsmLevel, RangeQueryCtx, RecordKey as RecordK, RecordValue as RecordV,
SyncIdStore, TxEventListener, TxEventListenerFactory, TxLsmTree, TxType,
},
},
os::{Aead, AeadIv as Iv, AeadKey as Key, AeadMac as Mac, RwLock},
prelude::*,
tx::CurrentTx,
};
/// Logical Block Address.
pub type Lba = BlockId;
/// Host Block Address.
pub type Hba = BlockId;
/// SwornDisk.
pub struct SwornDisk<D: BlockSet> {
inner: Arc<DiskInner<D>>,
}
/// Inner structures of `SwornDisk`.
struct DiskInner<D: BlockSet> {
/// Block I/O request queue.
bio_req_queue: BioReqQueue,
/// A `TxLsmTree` to store metadata of the logical blocks.
logical_block_table: TxLsmTree<RecordKey, RecordValue, D>,
/// The underlying disk where user data is stored.
user_data_disk: D,
/// Manages the space of the data disk.
block_validity_table: Arc<AllocTable>,
/// TX log store for managing logs in `TxLsmTree` and block alloc logs.
tx_log_store: Arc<TxLogStore<D>>,
/// A buffer to cache data blocks.
data_buf: DataBuf,
/// Root encryption key.
root_key: Key,
/// Whether `SwornDisk` is dropped.
is_dropped: AtomicBool,
/// Scope lock for controlling write and sync operations.
write_sync_region: RwLock<()>,
}
impl<D: BlockSet + 'static> aster_block::BlockDevice for SwornDisk<D> {
fn enqueue(
&self,
bio: aster_block::bio::SubmittedBio,
) -> core::result::Result<(), aster_block::bio::BioEnqueueError> {
use aster_block::bio::{BioStatus, BioType, SubmittedBio};
if bio.type_() == BioType::Discard {
warn!("discard operation not supported");
bio.complete(BioStatus::NotSupported);
return Ok(());
}
if bio.type_() == BioType::Flush {
let status = match self.sync() {
Ok(_) => BioStatus::Complete,
Err(_) => BioStatus::IoError,
};
bio.complete(status);
return Ok(());
}
let start_offset = bio.sid_range().start.to_offset();
let start_lba = start_offset / BLOCK_SIZE;
let end_offset = bio.sid_range().end.to_offset();
let end_lba = end_offset.div_ceil(BLOCK_SIZE);
let nblocks = end_lba - start_lba;
let Ok(buf) = Buf::alloc(nblocks) else {
bio.complete(BioStatus::NoSpace);
return Ok(());
};
let handle_read_bio = |mut buf: Buf| {
if self.read(start_lba, buf.as_mut()).is_err() {
return BioStatus::IoError;
}
let mut base = start_offset % BLOCK_SIZE;
bio.segments().iter().for_each(|seg| {
let offset = seg.nbytes();
let _ = seg.write_bytes(0, &buf.as_slice()[base..base + offset]);
base += offset;
});
BioStatus::Complete
};
let handle_write_bio = |mut buf: Buf| {
let mut base = start_offset % BLOCK_SIZE;
// Read the first unaligned block.
if base != 0 {
let buf_mut = BufMut::try_from(&mut buf.as_mut_slice()[..BLOCK_SIZE]).unwrap();
if self.read(start_lba, buf_mut).is_err() {
return BioStatus::IoError;
}
}
// Read the last unaligned block.
if end_offset % BLOCK_SIZE != 0 {
let offset = buf.as_slice().len() - BLOCK_SIZE;
let buf_mut = BufMut::try_from(&mut buf.as_mut_slice()[offset..]).unwrap();
if self.read(end_lba - 1, buf_mut).is_err() {
return BioStatus::IoError;
}
}
bio.segments().iter().for_each(|seg| {
let offset = seg.nbytes();
let _ = seg.read_bytes(0, &mut buf.as_mut_slice()[base..base + offset]);
base += offset;
});
if self.write(start_lba, buf.as_ref()).is_err() {
return BioStatus::IoError;
}
BioStatus::Complete
};
let status = match bio.type_() {
BioType::Read => handle_read_bio(buf),
BioType::Write => handle_write_bio(buf),
_ => BioStatus::NotSupported,
};
bio.complete(status);
Ok(())
}
fn metadata(&self) -> aster_block::BlockDeviceMeta {
use aster_block::{BlockDeviceMeta, BLOCK_SIZE, SECTOR_SIZE};
BlockDeviceMeta {
max_nr_segments_per_bio: usize::MAX,
nr_sectors: (BLOCK_SIZE / SECTOR_SIZE) * self.total_blocks(),
}
}
}
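// Worked example of the sector-to-block widening above (hedged; assumes
// 512-B sectors and 4-KiB blocks): a bio covering bytes [2048, 6144) is
// widened to whole blocks, which forces the read-modify-write path for the
// unaligned head and tail.
//
// ```
// let (start_offset, end_offset) = (2048usize, 6144usize);
// let start_lba = start_offset / BLOCK_SIZE; // 0
// let end_lba = end_offset.div_ceil(BLOCK_SIZE); // 2
// assert_eq!(end_lba - start_lba, 2); // nblocks
// ```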
impl<D: BlockSet + 'static> SwornDisk<D> {
/// Read a specified number of blocks at a logical block address on the device.
/// The block contents will be read into a single contiguous buffer.
pub fn read(&self, lba: Lba, buf: BufMut) -> Result<()> {
self.check_rw_args(lba, buf.nblocks())?;
self.inner.read(lba, buf)
}
/// Read multiple blocks at a logical block address on the device.
/// The block contents will be read into several scattered buffers.
pub fn readv<'a>(&self, lba: Lba, bufs: &'a mut [BufMut<'a>]) -> Result<()> {
self.check_rw_args(lba, bufs.iter().fold(0, |acc, buf| acc + buf.nblocks()))?;
self.inner.readv(lba, bufs)
}
/// Write a specified number of blocks at a logical block address on the device.
/// The block contents reside in a single contiguous buffer.
pub fn write(&self, lba: Lba, buf: BufRef) -> Result<()> {
self.check_rw_args(lba, buf.nblocks())?;
let _rguard = self.inner.write_sync_region.read();
self.inner.write(lba, buf)
}
/// Write multiple blocks at a logical block address on the device.
/// The block contents reside in several scattered buffers.
pub fn writev(&self, lba: Lba, bufs: &[BufRef]) -> Result<()> {
self.check_rw_args(lba, bufs.iter().fold(0, |acc, buf| acc + buf.nblocks()))?;
let _rguard = self.inner.write_sync_region.read();
self.inner.writev(lba, bufs)
}
/// Sync all cached data in the device to the storage medium for durability.
pub fn sync(&self) -> Result<()> {
let _wguard = self.inner.write_sync_region.write();
// TODO: Error handling for the sync operation
self.inner.sync().unwrap();
trace!("[SwornDisk] Sync completed. {self:?}");
Ok(())
}
/// Returns the total number of blocks in the device.
pub fn total_blocks(&self) -> usize {
self.inner.user_data_disk.nblocks()
}
/// Creates a new `SwornDisk` on the given disk, with the root encryption key.
pub fn create(
disk: D,
root_key: Key,
sync_id_store: Option<Arc<dyn SyncIdStore>>,
) -> Result<Self> {
let data_disk = Self::subdisk_for_data(&disk)?;
let lsm_tree_disk = Self::subdisk_for_logical_block_table(&disk)?;
let tx_log_store = Arc::new(TxLogStore::format(lsm_tree_disk, root_key)?);
let block_validity_table = Arc::new(AllocTable::new(
NonZeroUsize::new(data_disk.nblocks()).unwrap(),
));
let listener_factory = Arc::new(TxLsmTreeListenerFactory::new(
tx_log_store.clone(),
block_validity_table.clone(),
));
let logical_block_table = {
let table = block_validity_table.clone();
let on_drop_record_in_memtable = move |record: &dyn AsKV<RecordKey, RecordValue>| {
// Deallocate the host block when the corresponding record is dropped in `MemTable`
table.set_deallocated(record.value().hba);
};
TxLsmTree::format(
tx_log_store.clone(),
listener_factory,
Some(Arc::new(on_drop_record_in_memtable)),
sync_id_store,
)?
};
let new_self = Self {
inner: Arc::new(DiskInner {
bio_req_queue: BioReqQueue::new(),
logical_block_table,
user_data_disk: data_disk,
block_validity_table,
tx_log_store,
data_buf: DataBuf::new(DATA_BUF_CAP),
root_key,
is_dropped: AtomicBool::new(false),
write_sync_region: RwLock::new(()),
}),
};
info!("[SwornDisk] Created successfully! {:?}", &new_self);
// XXX: Would `disk::drop()` bring unexpected behavior?
Ok(new_self)
}
/// Opens the `SwornDisk` on the given disk, with the root encryption key.
pub fn open(
disk: D,
root_key: Key,
sync_id_store: Option<Arc<dyn SyncIdStore>>,
) -> Result<Self> {
let data_disk = Self::subdisk_for_data(&disk)?;
let lsm_tree_disk = Self::subdisk_for_logical_block_table(&disk)?;
let tx_log_store = Arc::new(TxLogStore::recover(lsm_tree_disk, root_key)?);
let block_validity_table = Arc::new(AllocTable::recover(
NonZeroUsize::new(data_disk.nblocks()).unwrap(),
&tx_log_store,
)?);
let listener_factory = Arc::new(TxLsmTreeListenerFactory::new(
tx_log_store.clone(),
block_validity_table.clone(),
));
let logical_block_table = {
let table = block_validity_table.clone();
let on_drop_record_in_memtable = move |record: &dyn AsKV<RecordKey, RecordValue>| {
// Deallocate the host block when the corresponding record is dropped in `MemTable`
table.set_deallocated(record.value().hba);
};
TxLsmTree::recover(
tx_log_store.clone(),
listener_factory,
Some(Arc::new(on_drop_record_in_memtable)),
sync_id_store,
)?
};
let opened_self = Self {
inner: Arc::new(DiskInner {
bio_req_queue: BioReqQueue::new(),
logical_block_table,
user_data_disk: data_disk,
block_validity_table,
data_buf: DataBuf::new(DATA_BUF_CAP),
tx_log_store,
root_key,
is_dropped: AtomicBool::new(false),
write_sync_region: RwLock::new(()),
}),
};
info!("[SwornDisk] Opened successfully! {:?}", &opened_self);
Ok(opened_self)
}
/// Submit a new block I/O request and wait for its completion (synchronous).
pub fn submit_bio_sync(&self, bio_req: BioReq) -> BioResp {
bio_req.submit();
self.inner.handle_bio_req(&bio_req)
}
// TODO: Support handling requests asynchronously
/// Check whether the arguments are valid for read/write operations.
fn check_rw_args(&self, lba: Lba, buf_nblocks: usize) -> Result<()> {
if lba + buf_nblocks > self.inner.user_data_disk.nblocks() {
Err(Error::with_msg(
OutOfDisk,
"read/write out of disk capacity",
))
} else {
Ok(())
}
}
fn subdisk_for_data(disk: &D) -> Result<D> {
disk.subset(0..disk.nblocks() * 15 / 16) // TBD
}
fn subdisk_for_logical_block_table(disk: &D) -> Result<D> {
disk.subset(disk.nblocks() * 15 / 16..disk.nblocks()) // TBD
}
}
/// Capacity of the user data blocks buffer.
const DATA_BUF_CAP: usize = 1024;
impl<D: BlockSet + 'static> DiskInner<D> {
/// Read a specified number of blocks at a logical block address on the device.
/// The block contents will be read into a single contiguous buffer.
pub fn read(&self, lba: Lba, buf: BufMut) -> Result<()> {
let nblocks = buf.nblocks();
let res = if nblocks == 1 {
self.read_one_block(lba, buf)
} else {
self.read_multi_blocks(lba, &mut [buf])
};
// Allow empty read
if let Err(e) = &res
&& e.errno() == NotFound
{
warn!("[SwornDisk] read contains empty read on lba {lba}");
return Ok(());
}
res
}
/// Read multiple blocks at a logical block address on the device.
/// The block contents will be read into several scattered buffers.
pub fn readv<'a>(&self, lba: Lba, bufs: &'a mut [BufMut<'a>]) -> Result<()> {
let res = self.read_multi_blocks(lba, bufs);
// Allow empty read
if let Err(e) = &res
&& e.errno() == NotFound
{
warn!("[SwornDisk] readv contains empty read on lba {lba}");
return Ok(());
}
res
}
fn read_one_block(&self, lba: Lba, mut buf: BufMut) -> Result<()> {
debug_assert_eq!(buf.nblocks(), 1);
// Search in `DataBuf` first
if self.data_buf.get(RecordKey { lba }, &mut buf).is_some() {
return Ok(());
}
// Then search in `TxLsmTree`
let value = self.logical_block_table.get(&RecordKey { lba })?;
// Perform disk read and decryption
let mut cipher = Buf::alloc(1)?;
self.user_data_disk.read(value.hba, cipher.as_mut())?;
Aead::new().decrypt(
cipher.as_slice(),
&value.key,
&Iv::new_zeroed(),
&[],
&value.mac,
buf.as_mut_slice(),
)?;
Ok(())
}
fn read_multi_blocks<'a>(&self, lba: Lba, bufs: &'a mut [BufMut<'a>]) -> Result<()> {
let mut buf_vec = BufMutVec::from_bufs(bufs);
let nblocks = buf_vec.nblocks();
let mut range_query_ctx =
RangeQueryCtx::<RecordKey, RecordValue>::new(RecordKey { lba }, nblocks);
// Search in `DataBuf` first
for (key, data_block) in self
.data_buf
.get_range(range_query_ctx.range_uncompleted().unwrap())
{
buf_vec
.nth_buf_mut_slice(key.lba - lba)
.copy_from_slice(data_block.as_slice());
range_query_ctx.mark_completed(key);
}
if range_query_ctx.is_completed() {
return Ok(());
}
// Then search in `TxLsmTree`
self.logical_block_table.get_range(&mut range_query_ctx)?;
// Allow empty read
debug_assert!(range_query_ctx.is_completed());
let mut res = range_query_ctx.into_results();
let record_batches = {
res.sort_by(|(_, v1), (_, v2)| v1.hba.cmp(&v2.hba));
res.chunk_by(|(_, v1), (_, v2)| v2.hba - v1.hba == 1)
};
// Perform disk read in batches and decryption
let mut cipher_buf = Buf::alloc(nblocks)?;
let cipher_slice = cipher_buf.as_mut_slice();
for record_batch in record_batches {
self.user_data_disk.read(
record_batch.first().unwrap().1.hba,
BufMut::try_from(&mut cipher_slice[..record_batch.len() * BLOCK_SIZE]).unwrap(),
)?;
for (nth, (key, value)) in record_batch.iter().enumerate() {
Aead::new().decrypt(
&cipher_slice[nth * BLOCK_SIZE..(nth + 1) * BLOCK_SIZE],
&value.key,
&Iv::new_zeroed(),
&[],
&value.mac,
buf_vec.nth_buf_mut_slice(key.lba - lba),
)?;
}
}
Ok(())
}
/// Write a specified number of blocks at a logical block address on the device.
/// The block contents reside in a single contiguous buffer.
pub fn write(&self, mut lba: Lba, buf: BufRef) -> Result<()> {
// Write block contents to `DataBuf` directly
for block_buf in buf.iter() {
let buf_at_capacity = self.data_buf.put(RecordKey { lba }, block_buf);
// Flush all data blocks in `DataBuf` to disk if it's full
if buf_at_capacity {
// TODO: Error handling: Should discard current write in `DataBuf`
self.flush_data_buf()?;
}
lba += 1;
}
Ok(())
}
/// Write multiple blocks at a logical block address on the device.
/// The block contents reside in several scattered buffers.
pub fn writev(&self, mut lba: Lba, bufs: &[BufRef]) -> Result<()> {
for buf in bufs {
self.write(lba, *buf)?;
lba += buf.nblocks();
}
Ok(())
}
fn flush_data_buf(&self) -> Result<()> {
let records = self.write_blocks_from_data_buf()?;
// Insert new records of data blocks to `TxLsmTree`
for (key, value) in records {
// TODO: Error handling: Should dealloc the written blocks
self.logical_block_table.put(key, value)?;
}
self.data_buf.clear();
Ok(())
}
fn write_blocks_from_data_buf(&self) -> Result<Vec<(RecordKey, RecordValue)>> {
let data_blocks = self.data_buf.all_blocks();
let num_write = data_blocks.len();
let mut records = Vec::with_capacity(num_write);
if num_write == 0 {
return Ok(records);
}
// Allocate slots for data blocks
let hbas = self
.block_validity_table
.alloc_batch(NonZeroUsize::new(num_write).unwrap())?;
debug_assert_eq!(hbas.len(), num_write);
let hba_batches = hbas.chunk_by(|hba1, hba2| hba2 - hba1 == 1);
// Perform encryption and batch disk write
let mut cipher_buf = Buf::alloc(num_write)?;
let mut cipher_slice = cipher_buf.as_mut_slice();
let mut nth = 0;
for hba_batch in hba_batches {
for (i, &hba) in hba_batch.iter().enumerate() {
let (lba, data_block) = &data_blocks[nth];
let key = Key::random();
let mac = Aead::new().encrypt(
data_block.as_slice(),
&key,
&Iv::new_zeroed(),
&[],
&mut cipher_slice[i * BLOCK_SIZE..(i + 1) * BLOCK_SIZE],
)?;
records.push((*lba, RecordValue { hba, key, mac }));
nth += 1;
}
self.user_data_disk.write(
*hba_batch.first().unwrap(),
BufRef::try_from(&cipher_slice[..hba_batch.len() * BLOCK_SIZE]).unwrap(),
)?;
cipher_slice = &mut cipher_slice[hba_batch.len() * BLOCK_SIZE..];
}
Ok(records)
}
/// Sync all cached data in the device to the storage medium for durability.
pub fn sync(&self) -> Result<()> {
self.flush_data_buf()?;
debug_assert!(self.data_buf.is_empty());
self.logical_block_table.sync()?;
// XXX: May impact performance under frequent syncs
self.block_validity_table
.do_compaction(&self.tx_log_store)?;
self.tx_log_store.sync()?;
self.user_data_disk.flush()
}
/// Handle one block I/O request. Mark the request completed when finished
/// and return any error that occurs.
pub fn handle_bio_req(&self, req: &BioReq) -> BioResp {
let res = match req.type_() {
BioType::Read => self.do_read(req),
BioType::Write => self.do_write(req),
BioType::Sync => self.do_sync(req),
};
req.complete(res.clone());
res
}
/// Handle a read I/O request.
fn do_read(&self, req: &BioReq) -> BioResp {
debug_assert_eq!(req.type_(), BioType::Read);
let lba = req.addr() as Lba;
let mut req_bufs = req.take_bufs();
let mut bufs = {
let mut bufs = Vec::with_capacity(req.nbufs());
for buf in req_bufs.iter_mut() {
bufs.push(BufMut::try_from(buf.as_mut_slice())?);
}
bufs
};
if bufs.len() == 1 {
let buf = bufs.remove(0);
return self.read(lba, buf);
}
self.readv(lba, &mut bufs)
}
/// Handle a write I/O request.
fn do_write(&self, req: &BioReq) -> BioResp {
debug_assert_eq!(req.type_(), BioType::Write);
let lba = req.addr() as Lba;
let req_bufs = req.take_bufs();
let bufs = {
let mut bufs = Vec::with_capacity(req.nbufs());
for buf in req_bufs.iter() {
bufs.push(BufRef::try_from(buf.as_slice())?);
}
bufs
};
self.writev(lba, &bufs)
}
/// Handle a sync I/O request.
fn do_sync(&self, req: &BioReq) -> BioResp {
debug_assert_eq!(req.type_(), BioType::Sync);
self.sync()
}
}
impl<D: BlockSet> Drop for SwornDisk<D> {
fn drop(&mut self) {
self.inner.is_dropped.store(true, Ordering::Release);
}
}
impl<D: BlockSet + 'static> Debug for SwornDisk<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SwornDisk")
.field("user_data_nblocks", &self.inner.user_data_disk.nblocks())
.field("logical_block_table", &self.inner.logical_block_table)
.finish()
}
}
/// A wrapper for a slice of `BufMut`s, used in `readv()`.
struct BufMutVec<'a> {
bufs: &'a mut [BufMut<'a>],
nblocks: usize,
}
impl<'a> BufMutVec<'a> {
pub fn from_bufs(bufs: &'a mut [BufMut<'a>]) -> Self {
debug_assert!(!bufs.is_empty());
let nblocks = bufs
.iter()
.map(|buf| buf.nblocks())
.fold(0_usize, |sum, nblocks| sum.saturating_add(nblocks));
Self { bufs, nblocks }
}
pub fn nblocks(&self) -> usize {
self.nblocks
}
pub fn nth_buf_mut_slice(&mut self, mut nth: usize) -> &mut [u8] {
debug_assert!(nth < self.nblocks);
for buf in self.bufs.iter_mut() {
let nblocks = buf.nblocks();
if nth >= nblocks {
nth -= nblocks;
} else {
return &mut buf.as_mut_slice()[nth * BLOCK_SIZE..(nth + 1) * BLOCK_SIZE];
}
}
&mut []
}
}
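// A hedged indexing sketch for `BufMutVec` across scattered buffers:
//
// ```
// let (mut a, mut b) = (Buf::alloc(1)?, Buf::alloc(2)?);
// let mut bufs = [a.as_mut(), b.as_mut()];
// let mut vec = BufMutVec::from_bufs(&mut bufs);
// assert_eq!(vec.nblocks(), 3);
// // Block #2 resolves to the second block of `b`:
// vec.nth_buf_mut_slice(2).fill(0xff);
// ```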
/// Listener factory for `TxLsmTree`.
struct TxLsmTreeListenerFactory<D> {
store: Arc<TxLogStore<D>>,
alloc_table: Arc<AllocTable>,
}
impl<D> TxLsmTreeListenerFactory<D> {
fn new(store: Arc<TxLogStore<D>>, alloc_table: Arc<AllocTable>) -> Self {
Self { store, alloc_table }
}
}
impl<D: BlockSet + 'static> TxEventListenerFactory<RecordKey, RecordValue>
for TxLsmTreeListenerFactory<D>
{
fn new_event_listener(
&self,
tx_type: TxType,
) -> Arc<dyn TxEventListener<RecordKey, RecordValue>> {
Arc::new(TxLsmTreeListener::new(
tx_type,
Arc::new(BlockAlloc::new(
self.alloc_table.clone(),
self.store.clone(),
)),
))
}
}
/// Event listener for `TxLsmTree`.
struct TxLsmTreeListener<D> {
tx_type: TxType,
block_alloc: Arc<BlockAlloc<D>>,
}
impl<D> TxLsmTreeListener<D> {
fn new(tx_type: TxType, block_alloc: Arc<BlockAlloc<D>>) -> Self {
Self {
tx_type,
block_alloc,
}
}
}
/// Register callbacks for different TXs in `TxLsmTree`.
impl<D: BlockSet + 'static> TxEventListener<RecordKey, RecordValue> for TxLsmTreeListener<D> {
fn on_add_record(&self, record: &dyn AsKV<RecordKey, RecordValue>) -> Result<()> {
match self.tx_type {
TxType::Compaction {
to_level: LsmLevel::L0,
} => self.block_alloc.alloc_block(record.value().hba),
// Major Compaction TX and Migration TX do not allocate new blocks
TxType::Compaction { .. } | TxType::Migration => {
// Do nothing
Ok(())
}
}
}
fn on_drop_record(&self, record: &dyn AsKV<RecordKey, RecordValue>) -> Result<()> {
match self.tx_type {
// Minor Compaction TX does not drop records
TxType::Compaction {
to_level: LsmLevel::L0,
} => {
unreachable!();
}
TxType::Compaction { .. } | TxType::Migration => {
self.block_alloc.dealloc_block(record.value().hba)
}
}
}
fn on_tx_begin(&self, tx: &mut CurrentTx<'_>) -> Result<()> {
match self.tx_type {
TxType::Compaction { .. } | TxType::Migration => {
tx.context(|| self.block_alloc.prepare_diff_log().unwrap())
}
}
Ok(())
}
fn on_tx_precommit(&self, tx: &mut CurrentTx<'_>) -> Result<()> {
match self.tx_type {
TxType::Compaction { .. } | TxType::Migration => {
tx.context(|| self.block_alloc.update_diff_log().unwrap())
}
}
Ok(())
}
fn on_tx_commit(&self) {
match self.tx_type {
TxType::Compaction { .. } | TxType::Migration => self.block_alloc.update_alloc_table(),
}
}
}
/// Key-Value record for `TxLsmTree`.
pub(super) struct Record {
key: RecordKey,
value: RecordValue,
}
/// The key of a `Record`.
#[repr(C)]
#[derive(Clone, Copy, Pod, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub(super) struct RecordKey {
/// Logical block address of user data block.
pub lba: Lba,
}
/// The value of a `Record`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Debug)]
pub(super) struct RecordValue {
/// Host block address of user data block.
pub hba: Hba,
/// Encryption key of the data block.
pub key: Key,
/// MAC of the encrypted data block.
pub mac: Mac,
}
impl Add<usize> for RecordKey {
type Output = Self;
fn add(self, other: usize) -> Self::Output {
Self {
lba: self.lba + other,
}
}
}
impl Sub<RecordKey> for RecordKey {
type Output = usize;
fn sub(self, other: RecordKey) -> Self::Output {
self.lba - other.lba
}
}
impl RecordK<RecordKey> for RecordKey {}
impl RecordV for RecordValue {}
impl AsKV<RecordKey, RecordValue> for Record {
fn key(&self) -> &RecordKey {
&self.key
}
fn value(&self) -> &RecordValue {
&self.value
}
}
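// Size sketch of a record on 64-bit targets (hedged; `#[repr(C)]` + `Pod`
// keep these layouts stable): the key is an 8-byte LBA and the value packs
// an 8-byte HBA, a 16-byte key, and a 16-byte MAC with no padding.
//
// ```
// assert_eq!(core::mem::size_of::<RecordKey>(), 8);
// assert_eq!(core::mem::size_of::<RecordValue>(), 40);
// ```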
#[cfg(test)]
mod tests {
use core::ptr::NonNull;
use std::thread;
use super::*;
use crate::layers::{bio::MemDisk, disk::bio::BioReqBuilder};
#[test]
fn sworndisk_fns() -> Result<()> {
let nblocks = 64 * 1024;
let mem_disk = MemDisk::create(nblocks)?;
let root_key = Key::random();
// Create a new `SwornDisk` then do some writes
let sworndisk = SwornDisk::create(mem_disk.clone(), root_key, None)?;
let num_rw = 1024;
// Submit a write block I/O request
let mut bufs = Vec::with_capacity(num_rw);
(0..num_rw).for_each(|i| {
let mut buf = Buf::alloc(1).unwrap();
buf.as_mut_slice().fill(i as u8);
bufs.push(buf);
});
let bio_req = BioReqBuilder::new(BioType::Write)
.addr(0 as BlockId)
.bufs(bufs)
.build();
sworndisk.submit_bio_sync(bio_req)?;
// Sync the `SwornDisk` then do some reads
sworndisk.submit_bio_sync(BioReqBuilder::new(BioType::Sync).build())?;
let mut rbuf = Buf::alloc(1)?;
for i in 0..num_rw {
sworndisk.read(i as Lba, rbuf.as_mut())?;
assert_eq!(rbuf.as_slice()[0], i as u8);
}
// Open the closed `SwornDisk` then test its data's existence
drop(sworndisk);
thread::spawn(move || -> Result<()> {
let opened_sworndisk = SwornDisk::open(mem_disk, root_key, None)?;
let mut rbuf = Buf::alloc(2)?;
opened_sworndisk.read(5 as Lba, rbuf.as_mut())?;
assert_eq!(rbuf.as_slice()[0], 5u8);
assert_eq!(rbuf.as_slice()[4096], 6u8);
Ok(())
})
.join()
.unwrap()
}
}

View File

@ -0,0 +1,14 @@
// SPDX-License-Identifier: MPL-2.0
#[path = "0-bio/mod.rs"]
pub mod bio;
#[path = "1-crypto/mod.rs"]
pub mod crypto;
#[path = "5-disk/mod.rs"]
pub mod disk;
#[path = "2-edit/mod.rs"]
pub mod edit;
#[path = "3-log/mod.rs"]
pub mod log;
#[path = "4-lsm/mod.rs"]
pub mod lsm;

View File

@ -0,0 +1,27 @@
// SPDX-License-Identifier: MPL-2.0
#![no_std]
#![deny(unsafe_code)]
#![feature(let_chains)]
#![feature(negative_impls)]
#![feature(slice_as_chunks)]
#![allow(dead_code, unused_imports)]
mod error;
mod layers;
mod os;
mod prelude;
mod tx;
mod util;
extern crate alloc;
pub use self::{
error::{Errno, Error},
layers::{
bio::{BlockId, BlockSet, Buf, BufMut, BufRef, BLOCK_SIZE},
disk::SwornDisk,
},
os::{Aead, AeadIv, AeadKey, AeadMac, Rng},
util::{Aead as _, RandomInit, Rng as _},
};

View File

@ -0,0 +1,404 @@
// SPDX-License-Identifier: MPL-2.0
//! OS-specific or OS-dependent APIs.
pub use alloc::{
boxed::Box,
collections::BTreeMap,
string::{String, ToString},
sync::{Arc, Weak},
vec::Vec,
};
use core::{
fmt,
sync::atomic::{AtomicBool, Ordering},
};
use aes_gcm::{
aead::{AeadInPlace, Key, NewAead, Nonce, Tag},
aes::Aes128,
Aes128Gcm,
};
use ctr::cipher::{NewCipher, StreamCipher};
pub use hashbrown::{HashMap, HashSet};
pub use ostd::sync::{Mutex, MutexGuard, RwLock, SpinLock};
use ostd::{
arch::read_random,
sync::{self, PreemptDisabled, WaitQueue},
task::{Task, TaskOptions},
};
use ostd_pod::Pod;
use serde::{Deserialize, Serialize};
use crate::{
error::{Errno, Error},
prelude::Result,
};
pub type RwLockReadGuard<'a, T> = sync::RwLockReadGuard<'a, T, PreemptDisabled>;
pub type RwLockWriteGuard<'a, T> = sync::RwLockWriteGuard<'a, T, PreemptDisabled>;
pub type SpinLockGuard<'a, T> = sync::SpinLockGuard<'a, T, PreemptDisabled>;
pub type Tid = u32;
/// A struct to get a unique identifier for the current thread.
pub struct CurrentThread;
impl CurrentThread {
/// Returns the Tid of the current kernel thread.
pub fn id() -> Tid {
let Some(task) = Task::current() else {
return 0;
};
task.data() as *const _ as u32
}
}
/// A `Condvar` (Condition Variable) is a synchronization primitive that can block threads
/// until a certain condition becomes true.
///
/// This is a copy from `aster-nix`.
pub struct Condvar {
waitqueue: Arc<WaitQueue>,
counter: SpinLock<Inner>,
}
struct Inner {
waiter_count: u64,
notify_count: u64,
}
impl Condvar {
/// Creates a new condition variable.
pub fn new() -> Self {
Condvar {
waitqueue: Arc::new(WaitQueue::new()),
counter: SpinLock::new(Inner {
waiter_count: 0,
notify_count: 0,
}),
}
}
/// Atomically releases the given `MutexGuard`,
/// blocking the current thread until the condition variable
/// is notified, after which the mutex will be reacquired.
///
/// Returns a new `MutexGuard` once the thread is woken up
/// and the mutex is reacquired (this implementation never fails).
pub fn wait<'a, T>(&self, guard: MutexGuard<'a, T>) -> Result<MutexGuard<'a, T>> {
let cond = || {
// Check if the notify counter is greater than 0.
let mut counter = self.counter.lock();
if counter.notify_count > 0 {
// Decrement the notify counter.
counter.notify_count -= 1;
Some(())
} else {
None
}
};
{
let mut counter = self.counter.lock();
counter.waiter_count += 1;
}
let lock = MutexGuard::get_lock(&guard);
drop(guard);
self.waitqueue.wait_until(cond);
Ok(lock.lock())
}
/// Wakes up one blocked thread waiting on this condition variable.
///
/// If there is a waiting thread, it will be unblocked
/// and allowed to reacquire the associated mutex.
/// If no threads are waiting, this function is a no-op.
pub fn notify_one(&self) {
let mut counter = self.counter.lock();
if counter.waiter_count == 0 {
return;
}
counter.notify_count += 1;
self.waitqueue.wake_one();
counter.waiter_count -= 1;
}
/// Wakes up all blocked threads waiting on this condition variable.
///
/// This method will unblock all waiting threads
/// and they will be allowed to reacquire the associated mutex.
/// If no threads are waiting, this function is a no-op.
pub fn notify_all(&self) {
let mut counter = self.counter.lock();
if counter.waiter_count == 0 {
return;
}
counter.notify_count = counter.waiter_count;
self.waitqueue.wake_all();
counter.waiter_count = 0;
}
}
impl fmt::Debug for Condvar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Condvar").finish_non_exhaustive()
}
}
/// A wrapper of the kernel-provided `Mutex`, used together with `Condvar`.
#[repr(transparent)]
pub struct CvarMutex<T> {
inner: Mutex<T>,
}
// TODO: Add a distinct guard type for `CvarMutex` if needed.
impl<T> CvarMutex<T> {
/// Constructs a new mutex lock, backed by the kernel's `Mutex`.
pub fn new(t: T) -> Self {
Self {
inner: Mutex::new(t),
}
}
/// Acquires the lock and gives the caller access to the data protected by it.
pub fn lock(&self) -> Result<MutexGuard<'_, T>> {
let guard = self.inner.lock();
Ok(guard)
}
}
impl<T: fmt::Debug> fmt::Debug for CvarMutex<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("No data, since `CvarMutex` does't support `try_lock` now")
}
}
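// A hedged sketch pairing `CvarMutex` with `Condvar`, mirroring how
// `AllocTable::alloc_batch` waits for free slots:
//
// ```
// let free = CvarMutex::new(0usize);
// let cvar = Condvar::new();
// // Waiter side:
// let mut guard = free.lock().unwrap();
// while *guard == 0 {
// guard = cvar.wait(guard).unwrap();
// }
// // Notifier side (on another thread):
// // *free.lock().unwrap() += 1; cvar.notify_one();
// ```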
/// Spawns a new thread, returning a `JoinHandle` for it.
pub fn spawn<F, T>(f: F) -> JoinHandle<T>
where
F: FnOnce() -> T + Send + Sync + 'static,
T: Send + 'static,
{
let is_finished = Arc::new(AtomicBool::new(false));
let data = Arc::new(SpinLock::new(None));
let is_finished_clone = is_finished.clone();
let data_clone = data.clone();
let task = TaskOptions::new(move || {
let data = f();
*data_clone.lock() = Some(data);
is_finished_clone.store(true, Ordering::Release);
})
.spawn()
.unwrap();
JoinHandle {
task,
is_finished,
data,
}
}
/// An owned permission to join on a thread (block on its termination).
///
/// This struct is created by the `spawn` function.
pub struct JoinHandle<T> {
task: Arc<Task>,
is_finished: Arc<AtomicBool>,
data: Arc<SpinLock<Option<T>>>,
}
impl<T> JoinHandle<T> {
/// Checks if the associated thread has finished running its main function.
pub fn is_finished(&self) -> bool {
self.is_finished.load(Ordering::Acquire)
}
/// Waits for the associated thread to finish.
pub fn join(self) -> Result<T> {
while !self.is_finished() {
Task::yield_now();
}
let data = self.data.lock().take().unwrap();
Ok(data)
}
}
impl<T> fmt::Debug for JoinHandle<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("JoinHandle").finish_non_exhaustive()
}
}
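// Spawn-and-join sketch (hedged; mirrors `std::thread` semantics):
//
// ```
// let handle = spawn(|| 1 + 41);
// assert_eq!(handle.join()?, 42);
// ```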
/// A random number generator.
pub struct Rng;
impl crate::util::Rng for Rng {
fn new(_seed: &[u8]) -> Self {
Self
}
fn fill_bytes(&self, dest: &mut [u8]) -> Result<()> {
let (chunks, remain) = dest.as_chunks_mut::<8>();
chunks.iter_mut().for_each(|chunk| {
chunk.copy_from_slice(read_random().unwrap_or(0u64).as_bytes());
});
remain.copy_from_slice(&read_random().unwrap_or(0u64).as_bytes()[..remain.len()]);
Ok(())
}
}
/// A macro to define byte array types used by `Aead` or `Skcipher`.
macro_rules! new_byte_array_type {
($name:ident, $n:expr) => {
#[repr(C)]
#[derive(Copy, Clone, Pod, Debug, Default, Deserialize, Serialize)]
pub struct $name([u8; $n]);
impl core::ops::Deref for $name {
type Target = [u8];
fn deref(&self) -> &Self::Target {
self.0.as_slice()
}
}
impl core::ops::DerefMut for $name {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut_slice()
}
}
impl crate::util::RandomInit for $name {
fn random() -> Self {
use crate::util::Rng;
let mut result = Self::default();
let rng = self::Rng::new(&[]);
rng.fill_bytes(&mut result).unwrap_or_default();
result
}
}
};
}
const AES_GCM_KEY_SIZE: usize = 16;
const AES_GCM_IV_SIZE: usize = 12;
const AES_GCM_MAC_SIZE: usize = 16;
new_byte_array_type!(AeadKey, AES_GCM_KEY_SIZE);
new_byte_array_type!(AeadIv, AES_GCM_IV_SIZE);
new_byte_array_type!(AeadMac, AES_GCM_MAC_SIZE);
/// An `AEAD` cipher.
#[derive(Debug, Default)]
pub struct Aead;
impl Aead {
/// Construct an `Aead` instance.
pub fn new() -> Self {
Self
}
}
impl crate::util::Aead for Aead {
type Key = AeadKey;
type Iv = AeadIv;
type Mac = AeadMac;
fn encrypt(
&self,
input: &[u8],
key: &AeadKey,
iv: &AeadIv,
aad: &[u8],
output: &mut [u8],
) -> Result<AeadMac> {
let key = Key::<Aes128Gcm>::from_slice(key);
let nonce = Nonce::<Aes128Gcm>::from_slice(iv);
let cipher = Aes128Gcm::new(key);
output.copy_from_slice(input);
let tag = cipher
.encrypt_in_place_detached(nonce, aad, output)
.map_err(|_| Error::with_msg(Errno::EncryptFailed, "aes-128-gcm encryption failed"))?;
let mut aead_mac = AeadMac::new_zeroed();
aead_mac.copy_from_slice(&tag);
Ok(aead_mac)
}
fn decrypt(
&self,
input: &[u8],
key: &AeadKey,
iv: &AeadIv,
aad: &[u8],
mac: &AeadMac,
output: &mut [u8],
) -> Result<()> {
let key = Key::<Aes128Gcm>::from_slice(key);
let nonce = Nonce::<Aes128Gcm>::from_slice(iv);
let tag = Tag::<Aes128Gcm>::from_slice(mac);
let cipher = Aes128Gcm::new(key);
output.copy_from_slice(input);
cipher
.decrypt_in_place_detached(nonce, aad, output, tag)
.map_err(|_| Error::with_msg(Errno::DecryptFailed, "aes-128-gcm decryption failed"))
}
}
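// An encrypt-then-decrypt round trip (hedged sketch; zeroed IV and empty
// AAD, as `SwornDisk` uses for per-block keys that are never reused):
//
// ```
// use crate::util::{Aead as _, RandomInit};
// let key = AeadKey::random();
// let plain = [1u8; 16];
// let (mut cipher, mut out) = ([0u8; 16], [0u8; 16]);
// let mac = Aead::new().encrypt(&plain, &key, &AeadIv::new_zeroed(), &[], &mut cipher)?;
// Aead::new().decrypt(&cipher, &key, &AeadIv::new_zeroed(), &[], &mac, &mut out)?;
// assert_eq!(plain, out);
// ```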
type Aes128Ctr = ctr::Ctr128LE<Aes128>;
const AES_CTR_KEY_SIZE: usize = 16;
const AES_CTR_IV_SIZE: usize = 16;
new_byte_array_type!(SkcipherKey, AES_CTR_KEY_SIZE);
new_byte_array_type!(SkcipherIv, AES_CTR_IV_SIZE);
/// A symmetric key cipher.
#[derive(Debug, Default)]
pub struct Skcipher;
// TODO: impl `Skcipher` with linux kernel Crypto API.
impl Skcipher {
/// Construct a `Skcipher` instance.
pub fn new() -> Self {
Self
}
}
impl crate::util::Skcipher for Skcipher {
type Key = SkcipherKey;
type Iv = SkcipherIv;
fn encrypt(
&self,
input: &[u8],
key: &SkcipherKey,
iv: &SkcipherIv,
output: &mut [u8],
) -> Result<()> {
let mut cipher = Aes128Ctr::new_from_slices(key, iv).unwrap();
output.copy_from_slice(input);
cipher.apply_keystream(output);
Ok(())
}
fn decrypt(
&self,
input: &[u8],
key: &SkcipherKey,
iv: &SkcipherIv,
output: &mut [u8],
) -> Result<()> {
let mut cipher = Aes128Ctr::new_from_slices(key, iv).unwrap();
output.copy_from_slice(input);
cipher.apply_keystream(output);
Ok(())
}
}

View File

@ -0,0 +1,15 @@
// SPDX-License-Identifier: MPL-2.0
pub(crate) use crate::{
error::{Errno::*, Error},
layers::bio::{BlockId, BLOCK_SIZE},
os::{Arc, Box, String, ToString, Vec, Weak},
return_errno, return_errno_with_msg,
util::{align_down, align_up, Aead as _, RandomInit, Rng as _, Skcipher as _},
};
pub(crate) type Result<T> = core::result::Result<T, Error>;
pub(crate) use core::fmt::{self, Debug};
pub(crate) use log::{debug, error, info, trace, warn};

View File

@ -0,0 +1,143 @@
// SPDX-License-Identifier: MPL-2.0
//! Get and set the current transaction of the current thread.
use core::sync::atomic::Ordering::{Acquire, Release};
use super::{Tx, TxData, TxId, TxProvider, TxStatus};
use crate::{os::CurrentThread, prelude::*};
/// The current transaction on a thread.
#[derive(Clone)]
pub struct CurrentTx<'a> {
provider: &'a TxProvider,
}
// CurrentTx is only useful and valid for the current thread
impl !Send for CurrentTx<'_> {}
impl !Sync for CurrentTx<'_> {}
impl<'a> CurrentTx<'a> {
pub(super) fn new(provider: &'a TxProvider) -> Self {
Self { provider }
}
/// Enter the context of the current TX.
///
/// While within the context of a TX, the implementation side of a TX
/// can get the current TX via `TxProvider::current`.
pub fn context<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
let tx_table = self.provider.tx_table.lock();
let tid = CurrentThread::id();
if !tx_table.contains_key(&tid) {
panic!("there should be one Tx exited on the current thread");
}
assert!(tx_table.get(&tid).unwrap().status() == TxStatus::Ongoing);
drop(tx_table);
f()
}
/// Commits the current TX.
///
/// If the returned value is `Ok`, then the TX is committed successfully.
/// Otherwise, the TX is aborted.
pub fn commit(&self) -> Result<()> {
let mut tx_table = self.provider.tx_table.lock();
let Some(mut tx) = tx_table.remove(&CurrentThread::id()) else {
panic!("there should be one Tx exited on the current thread");
};
debug_assert!(tx.status() == TxStatus::Ongoing);
let res = self.provider.call_precommit_handlers();
if res.is_ok() {
self.provider.call_commit_handlers();
tx.set_status(TxStatus::Committed);
} else {
self.provider.call_abort_handlers();
tx.set_status(TxStatus::Aborted);
}
res
}
/// Aborts the current TX.
pub fn abort(&self) {
let mut tx_table = self.provider.tx_table.lock();
let Some(mut tx) = tx_table.remove(&CurrentThread::id()) else {
panic!("there should be one Tx exited on the current thread");
};
debug_assert!(tx.status() == TxStatus::Ongoing);
self.provider.call_abort_handlers();
tx.set_status(TxStatus::Aborted);
}
/// The ID of the transaction.
pub fn id(&self) -> TxId {
self.get_current_mut_with(|tx| tx.id())
}
/// Get immutable access to some type of the per-transaction data within a closure.
///
/// # Panics
///
/// The `data_with` method must _not_ be called recursively.
pub fn data_with<T: TxData, F, R>(&self, f: F) -> R
where
F: FnOnce(&T) -> R,
{
self.get_current_mut_with(|tx| {
let data = tx.data::<T>();
f(data)
})
}
/// Get mutable access to some type of the per-transaction data within a closure.
pub fn data_mut_with<T: TxData, F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut T) -> R,
{
self.get_current_mut_with(|tx| {
let data = tx.data_mut::<T>();
f(data)
})
}
/// Get a _mutable_ reference to the current transaction of the current thread,
/// passing it to a given closure.
///
/// # Panics
///
/// The `get_current_mut_with` method must be called when a TX is present
/// on the current thread.
///
/// In addition, the `get_current_mut_with` method must _not_ be called
/// recursively.
#[allow(dropping_references)]
fn get_current_mut_with<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut Tx) -> R,
{
let mut tx_table = self.provider.tx_table.lock();
let Some(tx) = tx_table.get_mut(&CurrentThread::id()) else {
panic!("there should be one Tx exited on the current thread");
};
if tx.is_accessing_data.swap(true, Acquire) {
panic!("get_current_mut_with must not be called recursively");
}
let retval: R = f(tx);
// SAFETY: At any given time, at most one mutable reference is constructed
// within the Acquire-Release section. And it is safe to drop `&mut Tx` after
// the `Release`, since dropping the reference does nothing to the `Tx` itself.
tx.is_accessing_data.store(false, Release);
retval
}
}
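// Typical TX lifecycle driven through `CurrentTx` (hedged sketch; `provider`
// is a `TxProvider`):
//
// ```
// let tx = provider.new_tx();
// let res: Result<()> = tx.context(|| {
// /* transactional work */
// Ok(())
// });
// if res.is_ok() { tx.commit()?; } else { tx.abort(); }
// ```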

View File

@ -0,0 +1,435 @@
// SPDX-License-Identifier: MPL-2.0
//! Transaction management.
//!
//! Transaction management APIs serve two sides:
//!
//! * The user side of TXs uses `Tx` to use, commit, or abort TXs.
//! * The implementation side of TXs uses `TxProvider` to get notified
//! when TXs are created, committed, or aborted by registering callbacks.
mod current;
use core::{
any::{Any, TypeId},
sync::atomic::{AtomicBool, AtomicU64, Ordering},
};
pub use self::current::CurrentTx;
use crate::{
os::{CurrentThread, HashMap, Mutex, RwLock, Tid},
prelude::*,
};
/// A transaction provider.
#[allow(clippy::type_complexity)]
pub struct TxProvider {
id: u64,
initializer_map: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
precommit_handlers: RwLock<Vec<Box<dyn Fn(CurrentTx<'_>) -> Result<()> + Send + Sync>>>,
commit_handlers: RwLock<Vec<Box<dyn Fn(CurrentTx<'_>) + Send + Sync>>>,
abort_handlers: RwLock<Vec<Box<dyn Fn(CurrentTx<'_>) + Send + Sync>>>,
weak_self: Weak<Self>,
tx_table: Mutex<HashMap<Tid, Tx>>,
}
impl TxProvider {
/// Creates a new TX provider.
pub fn new() -> Arc<Self> {
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
Arc::new_cyclic(|weak_self| Self {
id: NEXT_ID.fetch_add(1, Ordering::Release),
initializer_map: RwLock::new(HashMap::new()),
precommit_handlers: RwLock::new(Vec::new()),
commit_handlers: RwLock::new(Vec::new()),
abort_handlers: RwLock::new(Vec::new()),
weak_self: weak_self.clone(),
tx_table: Mutex::new(HashMap::new()),
})
}
/// Creates a new TX that is attached to this TX provider.
pub fn new_tx(&self) -> CurrentTx<'_> {
let mut tx_table = self.tx_table.lock();
let tid = CurrentThread::id();
if tx_table.contains_key(&tid) {
return self.current();
}
let tx = Tx::new(self.weak_self.clone());
let _ = tx_table.insert(tid, tx);
self.current()
}
/// Get the current TX.
///
/// # Panics
///
/// The caller of this method must be within the closure passed to
/// `CurrentTx::context`. Otherwise, the method would panic.
pub fn current(&self) -> CurrentTx<'_> {
CurrentTx::new(self)
}
/// Register a per-TX data initializer.
///
/// The registered initializer function will be called lazily to create the
/// per-TX data upon its first access within a TX.
pub fn register_data_initializer<T>(&self, f: Box<dyn Fn() -> T + Send + Sync>)
where
T: TxData,
{
let mut initializer_map = self.initializer_map.write();
initializer_map.insert(TypeId::of::<T>(), Box::new(f));
}
fn init_data<T>(&self) -> T
where
T: TxData,
{
let initializer_map = self.initializer_map.read();
let init_fn = initializer_map
.get(&TypeId::of::<T>())
.unwrap()
.downcast_ref::<Box<dyn Fn() -> T>>()
.unwrap();
init_fn()
}
/// Register a callback for the pre-commit stage,
/// which is before the commit stage.
///
/// Committing a TX triggers the pre-commit stage as well as the commit
/// stage of the TX.
/// On the pre-commit stage, the registered callbacks will be called.
/// Pre-commit callbacks are allowed to fail (unlike commit callbacks).
/// If any pre-commit callback fails, the TX is aborted and
/// the commit callbacks do not get called.
pub fn register_precommit_handler<F>(&self, f: F)
where
F: Fn(CurrentTx<'_>) -> Result<()> + Send + Sync + 'static,
{
let f = Box::new(f);
let mut precommit_handlers = self.precommit_handlers.write();
precommit_handlers.push(f);
}
fn call_precommit_handlers(&self) -> Result<()> {
let current = self.current();
let precommit_handlers = self.precommit_handlers.read();
for precommit_func in precommit_handlers.iter().rev() {
precommit_func(current.clone())?;
}
Ok(())
}
/// Register a callback for the commit stage,
/// which is after the pre-commit stage.
///
/// Committing a TX triggers first the pre-commit stage of the TX and then
/// the commit stage. The callbacks for the commit stage are not allowed
/// to fail.
pub fn register_commit_handler<F>(&self, f: F)
where
F: Fn(CurrentTx<'_>) + Send + Sync + 'static,
{
let f = Box::new(f);
let mut commit_handlers = self.commit_handlers.write();
commit_handlers.push(f);
}
fn call_commit_handlers(&self) {
let current = self.current();
let commit_handlers = self.commit_handlers.read();
for commit_func in commit_handlers.iter().rev() {
commit_func(current.clone())
}
}
/// Register a callback for the abort stage.
///
/// A TX enters the abort stage when the TX is aborted by the user
/// (via `Tx::abort`) or by a callback in the pre-commit stage.
pub fn register_abort_handler<F>(&self, f: F)
where
F: Fn(CurrentTx<'_>) + Send + Sync + 'static,
{
let f = Box::new(f);
let mut abort_handlers = self.abort_handlers.write();
abort_handlers.push(f);
}
fn call_abort_handlers(&self) {
let current = self.current();
let abort_handlers = self.abort_handlers.read();
for abort_func in abort_handlers.iter().rev() {
abort_func(current.clone())
}
}
}
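// Registration sketch (hedged): wiring per-TX data and a commit handler,
// as the `Db` example in the tests below does in full.
//
// ```
// struct Delta(u64);
// impl TxData for Delta {}
//
// let provider = TxProvider::new();
// provider.register_data_initializer(Box::new(|| Delta(0)));
// provider.register_commit_handler(|mut tx: CurrentTx<'_>| {
// tx.data_mut_with(|d: &mut Delta| { /* publish d.0 here */ });
// });
// ```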
/// A transaction.
pub struct Tx {
id: TxId,
provider: Weak<TxProvider>,
data_map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
status: TxStatus,
is_accessing_data: AtomicBool,
}
impl Tx {
fn new(provider: Weak<TxProvider>) -> Self {
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
Self {
id: NEXT_ID.fetch_add(1, Ordering::Release),
provider,
data_map: HashMap::new(),
status: TxStatus::Ongoing,
is_accessing_data: AtomicBool::new(false),
}
}
/// Returns the TX ID.
pub fn id(&self) -> TxId {
self.id
}
/// Returns the status of the TX.
pub fn status(&self) -> TxStatus {
self.status
}
/// Sets the status of the TX.
pub fn set_status(&mut self, status: TxStatus) {
self.status = status;
}
fn provider(&self) -> Arc<TxProvider> {
self.provider.upgrade().unwrap()
}
fn data<T>(&mut self) -> &T
where
T: TxData,
{
self.data_mut::<T>()
}
fn data_mut<T>(&mut self) -> &mut T
where
T: TxData,
{
let exists = self.data_map.contains_key(&TypeId::of::<T>());
if !exists {
// Slow path, need to initialize the data
let provider = self.provider();
let data: T = provider.init_data::<T>();
self.data_map.insert(TypeId::of::<T>(), Box::new(data));
}
// Fast path
self.data_map
.get_mut(&TypeId::of::<T>())
.unwrap()
.downcast_mut::<T>()
.unwrap()
}
}
impl Drop for Tx {
fn drop(&mut self) {
assert!(
self.status() != TxStatus::Ongoing,
"transactions must be committed or aborted explicitly"
);
}
}
/// The status of a transaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TxStatus {
Ongoing,
Committed,
Aborted,
}
/// The ID of a transaction.
pub type TxId = u64;
/// Per-transaction data.
///
/// Use `TxProvider::register_data_initializer` to inject per-transaction data
/// and `CurrentTx::data_with` or `CurrentTx::data_mut_with` to access
/// per-transaction data.
pub trait TxData: Any + Send + Sync {}
#[cfg(test)]
mod tests {
use alloc::collections::BTreeSet;
use super::*;
/// `Db<T>` is a toy implementation of in-memory database for
/// a set of items of type `T`.
///
/// The most interesting feature of `Db<T>` is its support
/// for transactions. All queries and insertions to the database must
/// be performed within transactions. These transactions ensure
/// the atomicity of insertions even in the presence of concurrent execution.
/// If transactions are aborted, their changes won't take effect.
///
/// The main limitation of `Db<T>` is that it only supports
/// querying and inserting items, but not deleting.
/// The lack of support for deletions rules out the possibility
/// of concurrent transactions conflicting with each other.
pub struct Db<T> {
all_items: Arc<Mutex<BTreeSet<T>>>,
tx_provider: Arc<TxProvider>,
}
struct DbUpdate<T> {
new_items: BTreeSet<T>,
}
impl<T: 'static> TxData for DbUpdate<T> {}
impl<T> Db<T>
where
T: Ord + 'static,
{
/// Creates an empty database.
pub fn new() -> Self {
let new_self = Self {
all_items: Arc::new(Mutex::new(BTreeSet::new())),
tx_provider: TxProvider::new(),
};
new_self
.tx_provider
.register_data_initializer(Box::new(|| DbUpdate {
new_items: BTreeSet::<T>::new(),
}));
new_self.tx_provider.register_commit_handler({
let all_items = new_self.all_items.clone();
move |mut current: CurrentTx<'_>| {
current.data_mut_with(|update: &mut DbUpdate<T>| {
let mut all_items = all_items.lock();
all_items.append(&mut update.new_items);
});
}
});
new_self
}
/// Creates a new DB transaction.
pub fn new_tx(&self) -> CurrentTx<'_> {
self.tx_provider.new_tx()
}
/// Returns whether an item is contained.
///
/// # Transaction
///
/// This method must be called within the context of a transaction.
pub fn contains(&self, item: &T) -> bool {
let is_new_item = {
let current_tx = self.tx_provider.current();
current_tx.data_with(|update: &DbUpdate<T>| update.new_items.contains(item))
};
if is_new_item {
return true;
}
let all_items = self.all_items.lock();
all_items.contains(item)
}
/// Inserts a new item into the DB.
///
/// # Transaction
///
/// This method must be called within the context of a transaction.
pub fn insert(&self, item: T) {
let all_items = self.all_items.lock();
if all_items.contains(&item) {
return;
}
let mut current_tx = self.tx_provider.current();
current_tx.data_mut_with(|update: &mut DbUpdate<_>| {
update.new_items.insert(item);
});
}
/// Collects all items of the DB.
///
/// # Transaction
///
/// This method must be called within the context of a transaction.
pub fn collect(&self) -> Vec<T>
where
T: Copy,
{
let all_items = self.all_items.lock();
let current_tx = self.tx_provider.current();
current_tx.data_with(|update: &DbUpdate<T>| {
all_items.union(&update.new_items).cloned().collect()
})
}
/// Returns the number of items in the DB.
///
/// # Transaction
///
/// This method must be called within the context of a transaction.
pub fn len(&self) -> usize {
let all_items = self.all_items.lock();
let current_tx = self.tx_provider.current();
let new_items_len = current_tx.data_with(|update: &DbUpdate<T>| update.new_items.len());
all_items.len() + new_items_len
}
}
#[test]
fn commit_takes_effect() {
let db: Db<u32> = Db::new();
let items = vec![1, 2, 3];
new_tx_and_insert_items::<u32, alloc::vec::IntoIter<u32>>(&db, items.clone().into_iter())
.commit()
.unwrap();
assert!(collect_items(&db) == items);
}
#[test]
fn abort_has_no_effect() {
let db: Db<u32> = Db::new();
let items = vec![1, 2, 3];
new_tx_and_insert_items::<u32, alloc::vec::IntoIter<u32>>(&db, items.into_iter()).abort();
assert!(collect_items(&db).len() == 0);
}
fn new_tx_and_insert_items<T, I>(db: &Db<T>, new_items: I) -> CurrentTx<'_>
where
I: Iterator<Item = T>,
T: Copy + Ord + 'static,
{
let mut tx = db.new_tx();
tx.context(move || {
for new_item in new_items {
db.insert(new_item);
}
});
tx
}
fn collect_items<T>(db: &Db<T>) -> Vec<T>
where
T: Copy + Ord + 'static,
{
let mut tx = db.new_tx();
let items = tx.context(|| db.collect());
tx.commit().unwrap();
items
}
}


@ -0,0 +1,302 @@
// SPDX-License-Identifier: MPL-2.0
use core::ops::Index;
use bittle::{Bits, BitsMut};
use serde::{Deserialize, Serialize};
use crate::prelude::*;
/// A compact array of bits.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BitMap {
bits: Vec<u64>,
nbits: usize,
}
impl BitMap {
/// A one bit represents `true`.
const ONE: bool = true;
/// A zero bit represents `false`.
const ZERO: bool = false;
/// Creates a new `BitMap` by repeating the `value` for the desired length.
pub fn repeat(value: bool, nbits: usize) -> Self {
let vec_len = nbits.div_ceil(64);
let mut bits = Vec::with_capacity(vec_len);
if value == Self::ONE {
bits.resize(vec_len, !0u64);
} else {
bits.resize(vec_len, 0u64);
}
// Zero the unused bits in the last u64.
if nbits % 64 != 0 {
// Collect the indexes first to avoid borrowing `bits` immutably
// (through the iterator) while mutating it in the loop body.
let unused_ones: Vec<u32> = bits[vec_len - 1]
.iter_ones()
.filter(|index| (*index as usize) >= nbits % 64)
.collect();
for index in unused_ones {
bits[vec_len - 1].clear_bit(index);
}
}
Self { bits, nbits }
}
/// Returns the total number of bits.
pub fn len(&self) -> usize {
self.nbits
}
fn check_index(&self, index: usize) {
if index >= self.len() {
panic!(
"bitmap index {} is out of range, total bits {}",
index, self.nbits,
);
}
}
/// Tests if the given bit is set.
///
/// Returns `true` if the given bit is one.
///
/// # Panics
///
/// The `index` must be within the total number of bits. Otherwise, this method panics.
pub fn test_bit(&self, index: usize) -> bool {
self.check_index(index);
self.bits.test_bit(index as _)
}
/// Sets the given bit to one.
///
/// # Panics
///
/// The `index` must be within the total number of bits. Otherwise, this method panics.
pub fn set_bit(&mut self, index: usize) {
self.check_index(index);
self.bits.set_bit(index as _);
}
/// Clears the given bit, setting it to zero.
///
/// # Panics
///
/// The `index` must be within the total number of bits. Otherwise, this method panics.
pub fn clear_bit(&mut self, index: usize) {
self.check_index(index);
self.bits.clear_bit(index as _)
}
/// Sets the given bit to `value`.
///
/// A one bit is set for `true`, and a zero bit for `false`.
///
/// # Panics
///
/// The `index` must be within the total number of bits. Otherwise, this method panics.
pub fn set(&mut self, index: usize, value: bool) {
if value == Self::ONE {
self.set_bit(index);
} else {
self.clear_bit(index);
}
}
fn bits_not_in_use(&self) -> usize {
self.bits.len() * 64 - self.nbits
}
/// Returns the number of one bits in the bitmap.
pub fn count_ones(&self) -> usize {
self.bits.count_ones() as _
}
/// Returns the number of zero bits in the bitmap.
pub fn count_zeros(&self) -> usize {
let total_zeros = self.bits.count_zeros() as usize;
total_zeros - self.bits_not_in_use()
}
/// Finds the index of the first one bit, starting from the given index (inclusive).
///
/// Returns `None` if no one bit is found.
///
/// # Panics
///
/// The `from` index must be within the total number of bits. Otherwise, this method panics.
pub fn first_one(&self, from: usize) -> Option<usize> {
self.check_index(from);
let first_u64_index = from / 64;
self.bits[first_u64_index..]
.iter_ones()
.map(|index| first_u64_index * 64 + (index as usize))
.find(|&index| index >= from)
}
/// Finds the indexes of the first `count` one bits, starting from the given index (inclusive).
///
/// Returns `None` if fewer than `count` one bits are found.
///
/// # Panics
///
/// The index `from + count - 1` must be within the total number of bits. Otherwise, this method panics.
pub fn first_ones(&self, from: usize, count: usize) -> Option<Vec<usize>> {
self.check_index(from + count - 1);
let first_u64_index = from / 64;
let ones: Vec<_> = self.bits[first_u64_index..]
.iter_ones()
.map(|index| first_u64_index * 64 + (index as usize))
.filter(|&index| index >= from)
.take(count)
.collect();
if ones.len() == count {
Some(ones)
} else {
None
}
}
/// Finds the index of the last one bit.
///
/// Returns `None` if no one bit is found.
pub fn last_one(&self) -> Option<usize> {
self.bits
.iter_ones()
.rev()
.map(|index| index as usize)
.next()
}
/// Finds the index of the first zero bit, starting from the given index (inclusive).
///
/// Returns `None` if no zero bit is found.
///
/// # Panics
///
/// The `from` index must be within the total number of bits. Otherwise, this method panics.
pub fn first_zero(&self, from: usize) -> Option<usize> {
self.check_index(from);
let first_u64_index = from / 64;
self.bits[first_u64_index..]
.iter_zeros()
.map(|index| first_u64_index * 64 + (index as usize))
.find(|&index| index >= from && index < self.len())
}
/// Finds the indexes of the first `count` zero bits, starting from the given index (inclusive).
///
/// Returns `None` if fewer than `count` zero bits are found.
///
/// # Panics
///
/// The index `from + count - 1` must be within the total number of bits. Otherwise, this method panics.
pub fn first_zeros(&self, from: usize, count: usize) -> Option<Vec<usize>> {
self.check_index(from + count - 1);
let first_u64_index = from / 64;
let zeros: Vec<_> = self.bits[first_u64_index..]
.iter_zeros()
.map(|index| first_u64_index * 64 + (index as usize))
.filter(|&index| index >= from && index < self.len())
.take(count)
.collect();
if zeros.len() == count {
Some(zeros)
} else {
None
}
}
/// Finds the index of the last zero bit.
///
/// Returns `None` if no zero bit is found.
pub fn last_zero(&self) -> Option<usize> {
self.bits
.iter_zeros()
.rev()
.skip(self.bits_not_in_use())
.map(|index| index as usize)
.next()
}
}
impl Index<usize> for BitMap {
type Output = bool;
fn index(&self, index: usize) -> &Self::Output {
if self.test_bit(index) {
&BitMap::ONE
} else {
&BitMap::ZERO
}
}
}
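// As a usage sketch, a bitmap like this one naturally supports slot
// allocation: find the first zero bit and flip it to one. The `alloc_one`
// helper below is illustrative and not part of this module.
#[allow(dead_code)]
fn alloc_one(bm: &mut BitMap) -> Option<usize> {
// `first_zero` scans from the given index (inclusive).
let free = bm.first_zero(0)?;
bm.set_bit(free);
Some(free)
}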
#[cfg(test)]
mod tests {
use super::BitMap;
#[test]
fn all_true() {
let bm = BitMap::repeat(true, 100);
assert_eq!(bm.len(), 100);
assert_eq!(bm.count_ones(), 100);
assert_eq!(bm.count_zeros(), 0);
}
#[test]
fn all_false() {
let bm = BitMap::repeat(false, 100);
assert_eq!(bm.len(), 100);
assert_eq!(bm.count_ones(), 0);
assert_eq!(bm.count_zeros(), 100);
}
#[test]
fn bit_ops() {
let mut bm = BitMap::repeat(false, 100);
assert_eq!(bm.count_ones(), 0);
bm.set_bit(32);
assert_eq!(bm.count_ones(), 1);
assert_eq!(bm.test_bit(32), true);
bm.set(64, true);
assert_eq!(bm.count_ones(), 2);
assert_eq!(bm.test_bit(64), true);
bm.clear_bit(32);
assert_eq!(bm.count_ones(), 1);
assert_eq!(bm.test_bit(32), false);
bm.set(64, false);
assert_eq!(bm.count_ones(), 0);
assert_eq!(bm.test_bit(64), false);
}
#[test]
fn find_first_last() {
let mut bm = BitMap::repeat(false, 100);
bm.set_bit(64);
assert_eq!(bm.first_one(0), Some(64));
assert_eq!(bm.first_one(64), Some(64));
assert_eq!(bm.first_one(65), None);
assert_eq!(bm.first_ones(0, 1), Some(vec![64]));
assert_eq!(bm.first_ones(0, 2), None);
assert_eq!(bm.last_one(), Some(64));
let mut bm = BitMap::repeat(true, 100);
bm.clear_bit(64);
assert_eq!(bm.first_zero(0), Some(64));
assert_eq!(bm.first_zero(64), Some(64));
assert_eq!(bm.first_zero(65), None);
assert_eq!(bm.first_zeros(0, 1), Some(vec![64]));
assert_eq!(bm.first_zeros(0, 2), None);
assert_eq!(bm.last_zero(), Some(64));
}
}


@ -0,0 +1,89 @@
// SPDX-License-Identifier: MPL-2.0
use core::ops::Deref;
use crate::prelude::Result;
/// Random initialization for `Key`, `Iv`, and `Mac` types.
pub trait RandomInit: Default {
fn random() -> Self;
}
/// Authenticated Encryption with Associated Data (AEAD) algorithm.
pub trait Aead {
type Key: Deref<Target = [u8]> + RandomInit;
type Iv: Deref<Target = [u8]> + RandomInit;
type Mac: Deref<Target = [u8]> + RandomInit;
/// Encrypts the plaintext referred to by `input`, with a secret `Key`,
/// an initialization vector `Iv`, and additional associated data `aad`.
///
/// If the operation succeeds, the ciphertext is written to `output`
/// and a message authentication code `Mac` is returned. Otherwise,
/// an `Error` is returned.
fn encrypt(
&self,
input: &[u8],
key: &Self::Key,
iv: &Self::Iv,
aad: &[u8],
output: &mut [u8],
) -> Result<Self::Mac>;
/// Decrypts the ciphertext referred to by `input`, with a secret `Key`,
/// a message authentication code `Mac`, an initialization vector `Iv`, and
/// additional associated data `aad`.
///
/// If the operation succeeds, the plaintext is written to `output`.
/// Otherwise, an `Error` is returned.
fn decrypt(
&self,
input: &[u8],
key: &Self::Key,
iv: &Self::Iv,
aad: &[u8],
mac: &Self::Mac,
output: &mut [u8],
) -> Result<()>;
}
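// As a sketch of what an `Aead` implementor might do internally, the free
// function below encrypts with AES-128-GCM via the `aes-gcm` crate (already
// a dependency of this crate). The function name and the fixed-size key/IV
// parameters are illustrative assumptions; a real implementor would define
// `Key`/`Iv`/`Mac` newtypes satisfying the trait bounds above.
#[allow(dead_code)]
fn aes128_gcm_encrypt(
input: &[u8],
key: &[u8; 16],
iv: &[u8; 12],
aad: &[u8],
output: &mut [u8],
) -> core::result::Result<[u8; 16], aes_gcm::Error> {
use aes_gcm::{
aead::{AeadInPlace, NewAead},
Aes128Gcm, Key, Nonce,
};
debug_assert_eq!(input.len(), output.len());
let cipher = Aes128Gcm::new(Key::from_slice(key));
// GCM keeps the ciphertext the same length as the plaintext;
// the authentication tag is produced in detached form.
output.copy_from_slice(input);
let tag = cipher.encrypt_in_place_detached(Nonce::from_slice(iv), aad, output)?;
let mut mac = [0u8; 16];
mac.copy_from_slice(tag.as_slice());
Ok(mac)
}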
/// Symmetric key cipher algorithm.
pub trait Skcipher {
type Key: Deref<Target = [u8]> + RandomInit;
type Iv: Deref<Target = [u8]> + RandomInit;
/// Encrypts the plaintext referred to by `input`, with a secret `Key` and
/// an initialization vector `Iv`.
///
/// If the operation succeeds, the ciphertext is written to `output`.
/// Otherwise, an `Error` is returned.
fn encrypt(
&self,
input: &[u8],
key: &Self::Key,
iv: &Self::Iv,
output: &mut [u8],
) -> Result<()>;
/// Decrypts the ciphertext referred to by `input`, with a secret `Key` and
/// an initialization vector `Iv`.
///
/// If the operation succeeds, the plaintext is written to `output`.
/// Otherwise, an `Error` is returned.
fn decrypt(
&self,
input: &[u8],
key: &Self::Key,
iv: &Self::Iv,
output: &mut [u8],
) -> Result<()>;
}
/// Random number generator.
pub trait Rng {
/// Creates an instance, with `seed` providing secure entropy.
fn new(seed: &[u8]) -> Self;
/// Fills `dest` with random bytes.
fn fill_bytes(&self, dest: &mut [u8]) -> Result<()>;
}
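// A minimal sketch of an `Rng` implementor: a xorshift64 generator seeded
// from the caller-provided bytes. Illustrative only; it is NOT
// cryptographically secure and is not part of this module.
#[allow(dead_code)]
struct XorShiftRng(core::sync::atomic::AtomicU64);
impl Rng for XorShiftRng {
fn new(seed: &[u8]) -> Self {
// Fold the seed bytes into a non-zero 64-bit state.
let mut state = 0x9E37_79B9_7F4A_7C15_u64;
for &byte in seed {
state = state.rotate_left(8) ^ u64::from(byte);
}
Self(core::sync::atomic::AtomicU64::new(state | 1))
}
fn fill_bytes(&self, dest: &mut [u8]) -> Result<()> {
use core::sync::atomic::Ordering;
let mut x = self.0.load(Ordering::Relaxed);
for byte in dest.iter_mut() {
// One xorshift64 step per output byte.
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
*byte = x as u8;
}
self.0.store(x, Ordering::Relaxed);
Ok(())
}
}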


@ -0,0 +1,105 @@
// SPDX-License-Identifier: MPL-2.0
use core::{
fmt,
ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, Ordering},
};
use crate::prelude::*;
/// An object that may be deleted lazily.
///
/// Lazy-deletion is a technique to postpone the real deletion of an object.
/// This technique allows an object to remain usable even after a decision
/// to delete the object has been made. Of course, after the "real" deletion
/// is carried out, the object is no longer usable.
///
/// A classic example is file deletion in UNIX file systems.
///
/// ```ignore
/// int fd = open("path/to/my_file", O_RDONLY);
/// unlink("path/to/my_file");
/// // fd is still valid after unlink
/// ```
///
/// `LazyDelete<T>` enables lazy deletion of any object of `T`.
/// Here is a simple example.
///
/// ```ignore
/// use aster_mlsdisk::util::LazyDelete;
///
/// let lazy_delete_u32 = LazyDelete::new(123_u32, |obj| {
/// println!("the real deletion happens in this closure");
/// });
///
/// // The object is still usable after it is deleted (lazily)
/// LazyDelete::delete(&lazy_delete_u32);
/// assert!(*lazy_delete_u32 == 123);
///
/// // The deletion operation will be carried out when it is dropped
/// drop(lazy_delete_u32);
/// ```
#[allow(clippy::type_complexity)]
pub struct LazyDelete<T> {
obj: T,
is_deleted: AtomicBool,
delete_fn: Option<Box<dyn FnOnce(&mut T) + Send + Sync>>,
}
impl<T: fmt::Debug> fmt::Debug for LazyDelete<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LazyDelete")
.field("obj", &self.obj)
.field("is_deleted", &Self::is_deleted(self))
.finish()
}
}
impl<T> LazyDelete<T> {
/// Creates a new instance of `LazyDelete`.
///
/// The `delete_fn` will be called only if this instance of `LazyDelete` is
/// marked deleted by the `delete` method and only when this instance
/// of `LazyDelete` is dropped.
pub fn new<F: FnOnce(&mut T) + Send + Sync + 'static>(obj: T, delete_fn: F) -> Self {
Self {
obj,
is_deleted: AtomicBool::new(false),
delete_fn: Some(Box::new(delete_fn) as _),
}
}
/// Marks this instance as deleted.
pub fn delete(this: &Self) {
this.is_deleted.store(true, Ordering::Release);
}
/// Returns whether this instance has been marked deleted.
pub fn is_deleted(this: &Self) -> bool {
this.is_deleted.load(Ordering::Acquire)
}
}
impl<T> Deref for LazyDelete<T> {
type Target = T;
fn deref(&self) -> &T {
&self.obj
}
}
impl<T> DerefMut for LazyDelete<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.obj
}
}
impl<T> Drop for LazyDelete<T> {
fn drop(&mut self) {
if Self::is_deleted(self) {
let delete_fn = self.delete_fn.take().unwrap();
(delete_fn)(&mut self.obj);
}
}
}


@ -0,0 +1,22 @@
// SPDX-License-Identifier: MPL-2.0
//! Utilities.
mod bitmap;
mod crypto;
mod lazy_delete;
pub use self::{
bitmap::BitMap,
crypto::{Aead, RandomInit, Rng, Skcipher},
lazy_delete::LazyDelete,
};
/// Aligns `x` up to the next multiple of `align`.
pub(crate) const fn align_up(x: usize, align: usize) -> usize {
x.div_ceil(align) * align
}
/// Aligns `x` down to the previous multiple of `align`.
pub(crate) const fn align_down(x: usize, align: usize) -> usize {
(x / align) * align
}
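// A quick sanity check of the alignment helpers (an illustrative test,
// not part of the original module).
#[cfg(test)]
mod tests {
use super::{align_down, align_up};
#[test]
fn align_helpers() {
assert_eq!(align_up(0, 4096), 0);
assert_eq!(align_up(4096, 4096), 4096);
assert_eq!(align_up(4097, 4096), 8192);
assert_eq!(align_down(4095, 4096), 0);
assert_eq!(align_down(8192, 4096), 8192);
}
}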