From 56a137dc56fae5bec523b145b700d9eb6d3f606d Mon Sep 17 00:00:00 2001 From: Qingsong Chen <changxian.cqs@antgroup.com> Date: Fri, 27 Dec 2024 11:49:46 +0000 Subject: [PATCH] Add mlsdisk as a component Co-authored-by: Shaowei Song <songshaowei.ssw@antgroup.com> --- Cargo.lock | 602 +++++-- Cargo.toml | 1 + Components.toml | 1 + Makefile | 1 + kernel/Cargo.toml | 1 + kernel/comps/mlsdisk/Cargo.toml | 22 + kernel/comps/mlsdisk/src/error.rs | 95 ++ .../mlsdisk/src/layers/0-bio/block_buf.rs | 243 +++ .../mlsdisk/src/layers/0-bio/block_log.rs | 133 ++ .../mlsdisk/src/layers/0-bio/block_ring.rs | 114 ++ .../mlsdisk/src/layers/0-bio/block_set.rs | 227 +++ kernel/comps/mlsdisk/src/layers/0-bio/mod.rs | 24 + .../src/layers/1-crypto/crypto_blob.rs | 358 ++++ .../src/layers/1-crypto/crypto_chain.rs | 401 +++++ .../mlsdisk/src/layers/1-crypto/crypto_log.rs | 1192 +++++++++++++ .../comps/mlsdisk/src/layers/1-crypto/mod.rs | 18 + .../comps/mlsdisk/src/layers/2-edit/edits.rs | 154 ++ .../mlsdisk/src/layers/2-edit/journal.rs | 1081 ++++++++++++ kernel/comps/mlsdisk/src/layers/2-edit/mod.rs | 13 + .../comps/mlsdisk/src/layers/3-log/chunk.rs | 480 ++++++ kernel/comps/mlsdisk/src/layers/3-log/mod.rs | 14 + .../comps/mlsdisk/src/layers/3-log/raw_log.rs | 1235 ++++++++++++++ .../comps/mlsdisk/src/layers/3-log/tx_log.rs | 1491 +++++++++++++++++ .../mlsdisk/src/layers/4-lsm/compaction.rs | 132 ++ .../mlsdisk/src/layers/4-lsm/mem_table.rs | 402 +++++ kernel/comps/mlsdisk/src/layers/4-lsm/mod.rs | 79 + .../src/layers/4-lsm/range_query_ctx.rs | 96 ++ .../comps/mlsdisk/src/layers/4-lsm/sstable.rs | 779 +++++++++ .../mlsdisk/src/layers/4-lsm/tx_lsm_tree.rs | 1037 ++++++++++++ kernel/comps/mlsdisk/src/layers/4-lsm/wal.rs | 279 +++ kernel/comps/mlsdisk/src/layers/5-disk/bio.rs | 291 ++++ .../mlsdisk/src/layers/5-disk/block_alloc.rs | 403 +++++ .../mlsdisk/src/layers/5-disk/data_buf.rs | 137 ++ kernel/comps/mlsdisk/src/layers/5-disk/mod.rs | 41 + .../mlsdisk/src/layers/5-disk/sworndisk.rs | 881 ++++++++++ kernel/comps/mlsdisk/src/layers/mod.rs | 14 + kernel/comps/mlsdisk/src/lib.rs | 27 + kernel/comps/mlsdisk/src/os/mod.rs | 404 +++++ kernel/comps/mlsdisk/src/prelude.rs | 15 + kernel/comps/mlsdisk/src/tx/current.rs | 143 ++ kernel/comps/mlsdisk/src/tx/mod.rs | 435 +++++ kernel/comps/mlsdisk/src/util/bitmap.rs | 302 ++++ kernel/comps/mlsdisk/src/util/crypto.rs | 89 + kernel/comps/mlsdisk/src/util/lazy_delete.rs | 105 ++ kernel/comps/mlsdisk/src/util/mod.rs | 22 + 45 files changed, 13832 insertions(+), 182 deletions(-) create mode 100644 kernel/comps/mlsdisk/Cargo.toml create mode 100644 kernel/comps/mlsdisk/src/error.rs create mode 100644 kernel/comps/mlsdisk/src/layers/0-bio/block_buf.rs create mode 100644 kernel/comps/mlsdisk/src/layers/0-bio/block_log.rs create mode 100644 kernel/comps/mlsdisk/src/layers/0-bio/block_ring.rs create mode 100644 kernel/comps/mlsdisk/src/layers/0-bio/block_set.rs create mode 100644 kernel/comps/mlsdisk/src/layers/0-bio/mod.rs create mode 100644 kernel/comps/mlsdisk/src/layers/1-crypto/crypto_blob.rs create mode 100644 kernel/comps/mlsdisk/src/layers/1-crypto/crypto_chain.rs create mode 100644 kernel/comps/mlsdisk/src/layers/1-crypto/crypto_log.rs create mode 100644 kernel/comps/mlsdisk/src/layers/1-crypto/mod.rs create mode 100644 kernel/comps/mlsdisk/src/layers/2-edit/edits.rs create mode 100644 kernel/comps/mlsdisk/src/layers/2-edit/journal.rs create mode 100644 kernel/comps/mlsdisk/src/layers/2-edit/mod.rs create mode 100644 kernel/comps/mlsdisk/src/layers/3-log/chunk.rs create mode 
100644 kernel/comps/mlsdisk/src/layers/3-log/mod.rs create mode 100644 kernel/comps/mlsdisk/src/layers/3-log/raw_log.rs create mode 100644 kernel/comps/mlsdisk/src/layers/3-log/tx_log.rs create mode 100644 kernel/comps/mlsdisk/src/layers/4-lsm/compaction.rs create mode 100644 kernel/comps/mlsdisk/src/layers/4-lsm/mem_table.rs create mode 100644 kernel/comps/mlsdisk/src/layers/4-lsm/mod.rs create mode 100644 kernel/comps/mlsdisk/src/layers/4-lsm/range_query_ctx.rs create mode 100644 kernel/comps/mlsdisk/src/layers/4-lsm/sstable.rs create mode 100644 kernel/comps/mlsdisk/src/layers/4-lsm/tx_lsm_tree.rs create mode 100644 kernel/comps/mlsdisk/src/layers/4-lsm/wal.rs create mode 100644 kernel/comps/mlsdisk/src/layers/5-disk/bio.rs create mode 100644 kernel/comps/mlsdisk/src/layers/5-disk/block_alloc.rs create mode 100644 kernel/comps/mlsdisk/src/layers/5-disk/data_buf.rs create mode 100644 kernel/comps/mlsdisk/src/layers/5-disk/mod.rs create mode 100644 kernel/comps/mlsdisk/src/layers/5-disk/sworndisk.rs create mode 100644 kernel/comps/mlsdisk/src/layers/mod.rs create mode 100644 kernel/comps/mlsdisk/src/lib.rs create mode 100644 kernel/comps/mlsdisk/src/os/mod.rs create mode 100644 kernel/comps/mlsdisk/src/prelude.rs create mode 100644 kernel/comps/mlsdisk/src/tx/current.rs create mode 100644 kernel/comps/mlsdisk/src/tx/mod.rs create mode 100644 kernel/comps/mlsdisk/src/util/bitmap.rs create mode 100644 kernel/comps/mlsdisk/src/util/crypto.rs create mode 100644 kernel/comps/mlsdisk/src/util/lazy_delete.rs create mode 100644 kernel/comps/mlsdisk/src/util/mod.rs diff --git a/Cargo.lock b/Cargo.lock index c3558ee3..f368dd39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,10 +20,45 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] -name = "ahash" -version = "0.8.9" +name = "aead" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f" +checksum = "0b613b8e1e3cf911a086f53f03bf286f52fd7a7258e4fa606f0ef220d39d8877" +dependencies = [ + "generic-array", +] + +[[package]] +name = "aes" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", + "opaque-debug", +] + +[[package]] +name = "aes-gcm" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df5f85a83a7d8b0442b6aa7b504b8212c1733da07b98aae43d4bc21b2cb3cdf6" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "once_cell", @@ -37,9 +72,9 @@ version = "0.1.0" [[package]] name = "allocator-api2" -version = "0.2.16" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "aml" @@ -74,7 +109,7 @@ dependencies = [ "jhash", "ostd", "smoltcp", - "spin 0.9.8", + "spin", "static_assertions", "takeable", ] @@ -91,7 +126,7 @@ dependencies = [ "int-to-c-enum", "log", "ostd", - "spin 0.9.8", + 
"spin", "static_assertions", ] @@ -104,7 +139,7 @@ dependencies = [ "component", "log", "ostd", - "spin 0.9.8", + "spin", ] [[package]] @@ -116,7 +151,7 @@ dependencies = [ "font8x8", "log", "ostd", - "spin 0.9.8", + "spin", ] [[package]] @@ -131,7 +166,7 @@ dependencies = [ "int-to-c-enum", "log", "ostd", - "spin 0.9.8", + "spin", ] [[package]] @@ -145,7 +180,27 @@ dependencies = [ "log", "ostd", "owo-colors 3.5.0", - "spin 0.9.8", + "spin", +] + +[[package]] +name = "aster-mlsdisk" +version = "0.1.0" +dependencies = [ + "aes-gcm", + "aster-block", + "bittle", + "ctr", + "hashbrown 0.14.5", + "inherit-methods-macro", + "lending-iterator", + "log", + "lru", + "ostd", + "ostd-pod", + "postcard", + "serde", + "static_assertions", ] [[package]] @@ -162,7 +217,7 @@ dependencies = [ "int-to-c-enum", "log", "ostd", - "spin 0.9.8", + "spin", ] [[package]] @@ -177,6 +232,7 @@ dependencies = [ "aster-framebuffer", "aster-input", "aster-logger", + "aster-mlsdisk", "aster-network", "aster-rights", "aster-rights-proc", @@ -194,7 +250,7 @@ dependencies = [ "cpio-decoder", "fixed", "getset", - "hashbrown", + "hashbrown 0.14.5", "id-alloc", "inherit-methods-macro", "int-to-c-enum", @@ -208,7 +264,7 @@ dependencies = [ "paste", "rand", "riscv", - "spin 0.9.8", + "spin", "static_assertions", "takeable", "tdx-guest", @@ -246,7 +302,7 @@ dependencies = [ "component", "intrusive-collections", "ostd", - "spin 0.9.8", + "spin", ] [[package]] @@ -259,7 +315,7 @@ dependencies = [ "component", "log", "ostd", - "spin 0.9.8", + "spin", ] [[package]] @@ -293,7 +349,7 @@ dependencies = [ "int-to-c-enum", "log", "ostd", - "spin 0.9.8", + "spin", "typeflags-util", ] @@ -303,14 +359,23 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", +] + +[[package]] +name = "atomic-polyfill" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" +dependencies = [ + "critical-section", ] [[package]] name = "autocfg" -version = "1.1.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "az" @@ -336,6 +401,12 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "bittle" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4650bc513f4078b7ad8dfbbca0455a1ee6cc1325dc21b5995340a7ab3d80ac5b" + [[package]] name = "bitvec" version = "1.0.1" @@ -356,35 +427,35 @@ checksum = "a7913f22349ffcfc6ca0ca9a656ec26cfbba538ed49c31a273dff2c5d1ea83d9" [[package]] name = "bytemuck" -version = "1.17.0" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fd4c6dcc3b0aea2f5c0b4b82c2b15fe39ddbc76041a310848f4706edf76bb31" +checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.7.1" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" +checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + 
"syn 2.0.91", ] [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.4.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cfg-if" @@ -394,13 +465,28 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "num-traits", ] +[[package]] +name = "cipher" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +dependencies = [ + "generic-array", +] + +[[package]] +name = "cobs" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15" + [[package]] name = "component" version = "0.1.0" @@ -457,19 +543,28 @@ dependencies = [ ] [[package]] -name = "crc32fast" -version = "1.3.2" +name = "cpufeatures" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" dependencies = [ "cfg-if", ] [[package]] name = "critical-section" -version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64009896348fc5af4222e9cf7d7d82a95a256c634ebcf61c53e4ea461422242" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" [[package]] name = "crunchy" @@ -486,6 +581,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ctr" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "049bb91fb4aaf0e3c7efa6cd5ef877dbbbd15b39dad06d9948de4ec8a75761ea" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.13.4" @@ -523,15 +627,15 @@ dependencies = [ [[package]] name = "dary_heap" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" [[package]] name = "defmt" -version = "0.3.8" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a99dd22262668b887121d4672af5a64b238f026099f1a2a1b322066c9ecfe9e0" +checksum = "86f6162c53f659f65d00619fe31f14556a6e9f8752ccc4a41bd177ffcf3d6130" dependencies = [ "bitflags 1.3.2", "defmt-macros", @@ -539,31 +643,34 @@ dependencies = [ [[package]] name = "defmt-macros" -version = 
"0.3.6" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54f0216f6c5acb5ae1a47050a6645024e6edafc2ee32d421955eccfef12ef92e" +checksum = "9d135dd939bad62d7490b0002602d35b358dce5fd9233a709d3c1ef467d4bde6" dependencies = [ "defmt-parser", - "proc-macro-error", + "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", ] [[package]] name = "defmt-parser" -version = "0.3.3" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "269924c02afd7f94bc4cecbfa5c379f6ffcf9766b3408fe63d22c728654eccd0" +checksum = "3983b127f13995e68c1e29071e5d115cd96f215ccb5e6812e3728cd6f92653b3" dependencies = [ "thiserror", ] [[package]] name = "deranged" -version = "0.3.7" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7684a49fb1af197853ef7b2ee694bc1f5b4179556f1e5710e1760c5db6f5e929" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] [[package]] name = "derive_more" @@ -582,15 +689,15 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", "unicode-xid", ] [[package]] name = "either" -version = "1.9.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "embedded-hal" @@ -657,6 +764,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "font8x8" version = "0.2.7" @@ -670,10 +783,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] -name = "getrandom" -version = "0.2.10" +name = "generic-array" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", @@ -682,32 +805,42 @@ dependencies = [ [[package]] name = "getset" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9" +checksum = "f636605b743120a8d32ed92fc27b6cde1a769f8f936c065151eb66f88ded513c" dependencies = [ - "proc-macro-error", + "proc-macro-error2", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.91", +] + +[[package]] +name = "ghash" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99" +dependencies = [ + "opaque-debug", + "polyval", ] [[package]] name = "ghost" -version = "0.1.14" +version = "0.1.18" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba330b70a5341d3bc730b8e205aaee97ddab5d9c448c4f51a7c2d924266fa8f9" +checksum = "39b697dbd8bfcc35d0ee91698aaa379af096368ba8837d279cc097b276edda45" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", ] [[package]] name = "gimli" -version = "0.28.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "gimli" @@ -725,6 +858,15 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hash32" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" +dependencies = [ + "byteorder", +] + [[package]] name = "hash32" version = "0.3.1" @@ -736,12 +878,38 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", "allocator-api2", + "serde", +] + +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + +[[package]] +name = "heapless" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +dependencies = [ + "atomic-polyfill", + "hash32 0.2.1", + "rustc_version", + "serde", + "spin", + "stable_deref_trait", ] [[package]] @@ -750,7 +918,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" dependencies = [ - "hash32", + "hash32 0.3.1", "stable_deref_trait", ] @@ -778,12 +946,12 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "indexmap" -version = "2.2.3" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233cf39063f058ea2caae4091bf4a3ef70a653afbc026f5c4a4135d114e3c177" +checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -810,14 +978,14 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", ] [[package]] name = "intrusive-collections" -version = "0.9.6" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b694dc9f70c3bda874626d2aed13b780f137aab435f4e9814121955cf706122e" +checksum = "189d0897e4cbe8c75efedf3502c18c887b05046e59d28404d4d8e46cbc4d1e86" dependencies = [ "memoffset", ] @@ -856,11 +1024,11 @@ version = "0.1.0" [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin 0.5.2", + "spin", ] [[package]] @@ -890,9 +1058,9 @@ dependencies = [ 
[[package]] name = "libc" -version = "0.2.153" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libflate" @@ -914,7 +1082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" dependencies = [ "core2", - "hashbrown", + "hashbrown 0.14.5", "rle-decode-fast", ] @@ -953,9 +1121,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" dependencies = [ "autocfg", "scopeguard", @@ -963,17 +1131,17 @@ dependencies = [ [[package]] name = "log" -version = "0.4.20" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lru" -version = "0.12.3" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -1000,9 +1168,9 @@ checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" [[package]] name = "memchr" -version = "2.7.1" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memoffset" @@ -1015,9 +1183,9 @@ dependencies = [ [[package]] name = "multiboot2" -version = "0.23.0" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8bcd36c1256cbfbabbd8da34d487a64b45bea009a869185708735951792f0f3" +checksum = "43fcee184de68e344a888bc4f2b9d6b2f2f527cae8cedbb4d62a4df727d1ceae" dependencies = [ "bitflags 2.6.0", "derive_more", @@ -1029,9 +1197,9 @@ dependencies = [ [[package]] name = "multiboot2-common" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9f510cab65715c2b358a50d7d03c022e7c1499084ae6d420b0e594a9ca30853" +checksum = "fbabf8d9980d55576ba487924fa0b9a467fc0012b996b93d2319904f168ed8ab" dependencies = [ "derive_more", "ptr_meta", @@ -1086,6 +1254,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-derive" version = "0.4.2" @@ -1094,7 +1268,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", ] [[package]] @@ -1138,16 +1312,22 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] name = "osdk-test-kernel" version = "0.11.1" dependencies = [ "ostd", - "owo-colors 4.0.0", + "owo-colors 4.1.0", ] [[package]] @@ -1164,7 +1344,7 @@ dependencies = [ "cfg-if", "const-assert", "fdt", - "gimli 0.28.0", + "gimli 0.28.1", "iced-x86", "id-alloc", "inherit-methods-macro", @@ -1182,7 +1362,7 @@ dependencies = [ "riscv", "sbi-rt", "smallvec", - "spin 0.9.8", + "spin", "static_assertions", "tdx-guest", "unwinding", @@ -1199,7 +1379,7 @@ dependencies = [ "proc-macro2", "quote", "rand", - "syn 2.0.79", + "syn 2.0.91", ] [[package]] @@ -1232,15 +1412,15 @@ checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" [[package]] name = "owo-colors" -version = "4.0.0" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caff54706df99d2a78a5a4e3455ff45448d81ef1bb63c22cd14052ca0e993a3f" +checksum = "fb37767f6569cd834a413442455e0f066d0d522de8630436e2a1761d9726ba56" [[package]] name = "paste" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "polonius-the-crab" @@ -1249,40 +1429,70 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a69ee997a6282f8462abf1e0d8c38c965e968799e912b3bed8c9e8a28c2f9f" [[package]] -name = "ppv-lite86" -version = "0.2.17" +name = "polyval" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +checksum = "8419d2b623c7c0896ff2d5d96e2cb4ede590fed28fcc34934f4c33c036e620a1" dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", ] [[package]] -name = "proc-macro-error-attr" -version = "1.0.4" +name = "postcard" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +checksum = "170a2601f67cc9dba8edd8c4870b15f71a6a2dc196daec8c83f72b59dff628a8" +dependencies = [ + "cobs", + "heapless 0.7.17", + "serde", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" dependencies = [ 
"proc-macro2", "quote", - "version_check", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.91", ] [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -1315,9 +1525,9 @@ checksum = "8bb0fd6580eeed0103c054e3fba2c2618ff476943762f28a645b63b8692b21c9" [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -1384,10 +1594,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" [[package]] -name = "rustversion" -version = "1.0.14" +name = "rustc_version" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustversion" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "sbi-rt" @@ -1411,30 +1630,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] -name = "serde" -version = "1.0.196" +name = "semver" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" + +[[package]] +name = "serde" +version = "1.0.216" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.196" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", ] [[package]] name = "serde_spanned" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" dependencies = [ "serde", ] @@ -1454,17 +1679,11 @@ dependencies = [ "byteorder", "cfg-if", "defmt", - "heapless", + "heapless 0.8.0", "log", "managed", ] -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" - [[package]] name = "spin" version = "0.9.8" @@ -1501,6 +1720,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "subtle" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" + [[package]] name = "syn" version = "1.0.109" @@ -1514,9 +1739,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "d53cbcb5a243bd33b7858b1d7f4aca2153490815872d86d955d6ea29f743c035" dependencies = [ "proc-macro2", "quote", @@ -1549,46 +1774,48 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.44" +version = "2.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "611040a08a0439f8248d1990b111c95baa9c704c805fa1f62104b39655fd7f90" +checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.44" +version = "2.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "090198534930841fab3a5d1bb637cde49e339654e606195f8d9c76eeb081dc96" +checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", ] [[package]] name = "time" -version = "0.3.25" +version = "0.3.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fdd63d58b18d663fbdf70e049f00a22c8e42be082203be7f26589213cd75ea" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" dependencies = [ "deranged", + "num-conv", + "powerfmt", "serde", "time-core", ] [[package]] name = "time-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "toml" -version = "0.7.6" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17e963a819c331dcacd7ab957d80bc2b9a9c1e71c804826d2f283dd65306542" +checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" dependencies = [ "serde", "serde_spanned", @@ -1598,18 +1825,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.19.14" +version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ "indexmap", "serde", @@ -1641,11 +1868,11 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "uart_16550" -version = "0.3.0" +version = "0.3.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dc00444796f6c71f47c85397a35e9c4dbf9901902ac02386940d178e2b78687" +checksum = "e492212ac378a5e00da953718dafb1340d9fbaf4f27d6f3c5cab03d931d1c049" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", "rustversion", "x86", ] @@ -1684,7 +1911,7 @@ checksum = "c19ee3a01d435eda42cb9931269b349d28a1762f91ddf01c68d276f74b957cc3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", ] [[package]] @@ -1700,21 +1927,31 @@ dependencies = [ [[package]] name = "uguid" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ef516f0806c5f61da6aa95125d0eb2d91cc95b2df426c06bde8be657282aee5" +checksum = "ab14ea9660d240e7865ce9d54ecdbd1cd9fa5802ae6f4512f093c7907e921533" [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-xid" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "universal-hash" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f214e8f697e925001e66ec2c6e37a4ef93f0f78c2eed7814394e10c62025b05" +dependencies = [ + "generic-array", + "subtle", +] [[package]] name = "unwinding" @@ -1727,15 +1964,15 @@ dependencies = [ [[package]] name = "utf8parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "volatile" @@ -1756,9 +1993,9 @@ dependencies = [ [[package]] name = "vte_generate_state_changes" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d257817081c7dffcdbab24b9e62d2def62e2ff7d00b1c20062551e6cccc145ff" +checksum = "2e369bee1b05d510a7b4ed645f5faa90619e05437111783ea5848f28d97d3c2e" dependencies = [ "proc-macro2", "quote", @@ -1772,9 +2009,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "winnow" -version = "0.5.7" +version = "0.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19f495880723d0999eb3500a9064d8dbcf836460b24c17df80ea7b5794053aac" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" dependencies = [ "memchr", ] @@ -1857,20 +2094,21 @@ checksum = "2fe21bcc34ca7fe6dd56cc2cb1261ea59d6b93620215aefb5ea6032265527784" [[package]] name = "zerocopy" -version = "0.7.32" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" 
dependencies = [ + "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.32" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.91", ] diff --git a/Cargo.toml b/Cargo.toml index 9eda94ff..d780f6c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ members = [ "kernel/comps/network", "kernel/comps/softirq", "kernel/comps/logger", + "kernel/comps/mlsdisk", "kernel/comps/time", "kernel/comps/virtio", "kernel/libs/cpio-decoder", diff --git a/Components.toml b/Components.toml index 866bd652..0c84da50 100644 --- a/Components.toml +++ b/Components.toml @@ -10,6 +10,7 @@ logger = { name = "aster-logger" } time = { name = "aster-time" } framebuffer = { name = "aster-framebuffer" } network = { name = "aster-network" } +mlsdisk = { name = "aster-mlsdisk" } [whitelist] [whitelist.nix.main] diff --git a/Makefile b/Makefile index 1fd16a08..6616579e 100644 --- a/Makefile +++ b/Makefile @@ -149,6 +149,7 @@ OSDK_CRATES := \ kernel/comps/network \ kernel/comps/softirq \ kernel/comps/logger \ + kernel/comps/mlsdisk \ kernel/comps/time \ kernel/comps/virtio \ kernel/libs/aster-util \ diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 5dabca61..c2ff90b1 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -14,6 +14,7 @@ aster-console = { path = "comps/console" } aster-framebuffer = { path = "comps/framebuffer" } aster-softirq = { path = "comps/softirq" } aster-logger = { path = "comps/logger" } +aster-mlsdisk = { path = "comps/mlsdisk" } aster-time = { path = "comps/time" } aster-virtio = { path = "comps/virtio" } aster-rights = { path = "libs/aster-rights" } diff --git a/kernel/comps/mlsdisk/Cargo.toml b/kernel/comps/mlsdisk/Cargo.toml new file mode 100644 index 00000000..897f3d56 --- /dev/null +++ b/kernel/comps/mlsdisk/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "aster-mlsdisk" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +inherit-methods-macro = {git = "https://github.com/asterinas/inherit-methods-macro", rev = "98f7e3e"} +ostd-pod = { git = "https://github.com/asterinas/ostd-pod", rev = "c4644be", version = "0.1.1" } +aster-block = { path = "../block" } +ostd = { path = "../../../ostd" } +aes-gcm = { version = "0.9.4", features = ["force-soft"] } +bittle = "0.5.6" +ctr = "0.8.0" +hashbrown = { version = "0.14.3", features = ["serde"] } +lending-iterator = "0.1.7" +log = "0.4" +lru = "0.12.3" +postcard = "1.0.6" +serde = { version = "1.0.192", default-features = false, features = ["alloc", "derive"] } +static_assertions = "1.1.0" diff --git a/kernel/comps/mlsdisk/src/error.rs b/kernel/comps/mlsdisk/src/error.rs new file mode 100644 index 00000000..c20dee94 --- /dev/null +++ b/kernel/comps/mlsdisk/src/error.rs @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::fmt; + +/// The error types used in this crate. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum Errno { + /// Transaction aborted. + TxAborted, + /// Not found. + NotFound, + /// Invalid arguments. + InvalidArgs, + /// Out of memory. + OutOfMemory, + /// Out of disk space. + OutOfDisk, + /// IO error. + IoFailed, + /// Permission denied. + PermissionDenied, + /// Unsupported. 
+    Unsupported,
+    /// OS-specific unknown error.
+    OsSpecUnknown,
+    /// Encryption operation failed.
+    EncryptFailed,
+    /// Decryption operation failed.
+    DecryptFailed,
+    /// MAC (Message Authentication Code) mismatched.
+    MacMismatched,
+    /// Not aligned to `BLOCK_SIZE`.
+    NotBlockSizeAligned,
+    /// Try lock failed.
+    TryLockFailed,
+}
+
+/// The error with an error type and an error message used in this crate.
+#[derive(Clone, Debug)]
+pub struct Error {
+    errno: Errno,
+    msg: Option<&'static str>,
+}
+
+impl Error {
+    /// Creates a new error with the given error type and no error message.
+    pub const fn new(errno: Errno) -> Self {
+        Error { errno, msg: None }
+    }
+
+    /// Creates a new error with the given error type and the error message.
+    pub const fn with_msg(errno: Errno, msg: &'static str) -> Self {
+        Error {
+            errno,
+            msg: Some(msg),
+        }
+    }
+
+    /// Returns the error type.
+    pub fn errno(&self) -> Errno {
+        self.errno
+    }
+}
+
+impl From<Errno> for Error {
+    fn from(errno: Errno) -> Self {
+        Error::new(errno)
+    }
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl fmt::Display for Errno {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+#[macro_export]
+macro_rules! return_errno {
+    ($errno: expr) => {
+        return core::result::Result::Err($crate::Error::new($errno))
+    };
+}
+
+#[macro_export]
+macro_rules! return_errno_with_msg {
+    ($errno: expr, $msg: expr) => {
+        return core::result::Result::Err($crate::Error::with_msg($errno, $msg))
+    };
+}
diff --git a/kernel/comps/mlsdisk/src/layers/0-bio/block_buf.rs b/kernel/comps/mlsdisk/src/layers/0-bio/block_buf.rs
new file mode 100644
index 00000000..e60d6972
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/layers/0-bio/block_buf.rs
@@ -0,0 +1,243 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! This module provides APIs to represent buffers whose
+//! sizes are block aligned. The advantage of using the
+//! APIs provided by this module over Rust std's counterparts
+//! is to ensure the invariant of block-aligned length
+//! at the type level, eliminating the need for runtime checks.
+//!
+//! There are three main types:
+//! * `Buf`: An owned buffer backed by `Pages`, whose length is
+//!   a multiple of the block size.
+//! * `BufRef`: An immutably-borrowed buffer whose length
+//!   is a multiple of the block size.
+//! * `BufMut`: A mutably-borrowed buffer whose length is
+//!   a multiple of the block size.
+//!
+//! The basic usage is simple: replace the usage of `Box<[u8]>`
+//! with `Buf`, `&[u8]` with `BufRef`,
+//! and `&mut [u8]` with `BufMut`.
+
+use alloc::vec;
+use core::convert::TryFrom;
+
+use lending_iterator::prelude::*;
+
+use super::BLOCK_SIZE;
+use crate::prelude::*;
+
+/// An owned buffer whose length is a multiple of the block size.
+pub struct Buf(Vec<u8>);
+
+impl Buf {
+    /// Allocates the specified number of blocks as a memory buffer.
+    pub fn alloc(num_blocks: usize) -> Result<Self> {
+        if num_blocks == 0 {
+            return_errno_with_msg!(
+                InvalidArgs,
+                "num_blocks must be greater than 0 for allocation"
+            )
+        }
+        let buffer = vec![0; num_blocks * BLOCK_SIZE];
+        Ok(Self(buffer))
+    }
+
+    /// Returns the number of blocks of the owned buffer.
+    pub fn nblocks(&self) -> usize {
+        self.0.len() / BLOCK_SIZE
+    }
+
+    /// Returns the immutable slice of the owned buffer.
+    pub fn as_slice(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+
+    /// Returns the mutable slice of the owned buffer.
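Taken together, `Errno`, `Error`, and the two macros from error.rs above give fallible functions in this crate a uniform shape. The following is a minimal usage sketch, not part of the patch: `read_header` is a hypothetical helper, and it assumes the crate prelude brings `Result`, the `Errno` variants, and the error macros into scope, as the files above do.

fn read_header(buf: &[u8]) -> Result<u32> {
    if buf.len() < 4 {
        // Expands to `return Err($crate::Error::with_msg(InvalidArgs, "..."))`.
        return_errno_with_msg!(InvalidArgs, "header shorter than 4 bytes");
    }
    Ok(u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]))
}

// Callers can branch on the error kind via `Error::errno`, e.g.
// `Err(e) if e.errno() == Errno::InvalidArgs => ...`.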
+ pub fn as_mut_slice(&mut self) -> &mut [u8] { + self.0.as_mut_slice() + } + + /// Converts to immutably-borrowed buffer `BufRef`. + pub fn as_ref(&self) -> BufRef<'_> { + BufRef(self.as_slice()) + } + + /// Coverts to mutably-borrowed buffer `BufMut`. + pub fn as_mut(&mut self) -> BufMut<'_> { + BufMut(self.as_mut_slice()) + } +} + +/// An immutably-borrowed buffer whose length is a multiple of the block size. +#[derive(Clone, Copy)] +pub struct BufRef<'a>(&'a [u8]); + +impl BufRef<'_> { + /// Returns the immutable slice of borrowed buffer. + pub fn as_slice(&self) -> &[u8] { + self.0 + } + + /// Returns the number of blocks of borrowed buffer. + pub fn nblocks(&self) -> usize { + self.0.len() / BLOCK_SIZE + } + + /// Returns an iterator for immutable buffers of `BLOCK_SIZE`. + pub fn iter(&self) -> BufIter<'_> { + BufIter { + buf: BufRef(self.as_slice()), + offset: 0, + } + } +} + +impl<'a> TryFrom<&'a [u8]> for BufRef<'a> { + type Error = crate::error::Error; + + fn try_from(buf: &'a [u8]) -> Result<Self> { + if buf.is_empty() { + return_errno_with_msg!(InvalidArgs, "empty buf in `BufRef::try_from`"); + } + if buf.len() % BLOCK_SIZE != 0 { + return_errno_with_msg!( + NotBlockSizeAligned, + "buf not block size aligned `BufRef::try_from`" + ); + } + + let new_self = Self(buf); + Ok(new_self) + } +} + +/// A mutably-borrowed buffer whose length is a multiple of the block size. +pub struct BufMut<'a>(&'a mut [u8]); + +impl BufMut<'_> { + /// Returns the immutable slice of borrowed buffer. + pub fn as_slice(&self) -> &[u8] { + self.0 + } + + /// Returns the mutable slice of borrowed buffer. + pub fn as_mut_slice(&mut self) -> &mut [u8] { + self.0 + } + + /// Returns the number of blocks of borrowed buffer. + pub fn nblocks(&self) -> usize { + self.0.len() / BLOCK_SIZE + } + + /// Returns an iterator for immutable buffers of `BLOCK_SIZE`. + pub fn iter(&self) -> BufIter<'_> { + BufIter { + buf: BufRef(self.as_slice()), + offset: 0, + } + } + + /// Returns an iterator for mutable buffers of `BLOCK_SIZE`. + pub fn iter_mut(&mut self) -> BufIterMut<'_> { + BufIterMut { + buf: BufMut(self.as_mut_slice()), + offset: 0, + } + } +} + +impl<'a> TryFrom<&'a mut [u8]> for BufMut<'a> { + type Error = crate::error::Error; + + fn try_from(buf: &'a mut [u8]) -> Result<Self> { + if buf.is_empty() { + return_errno_with_msg!(InvalidArgs, "empty buf in `BufMut::try_from`"); + } + if buf.len() % BLOCK_SIZE != 0 { + return_errno_with_msg!( + NotBlockSizeAligned, + "buf not block size aligned `BufMut::try_from`" + ); + } + + let new_self = Self(buf); + Ok(new_self) + } +} + +/// Iterator for immutable buffers of `BLOCK_SIZE`. +pub struct BufIter<'a> { + buf: BufRef<'a>, + offset: usize, +} + +impl<'a> Iterator for BufIter<'a> { + type Item = BufRef<'a>; + + fn next(&mut self) -> Option<Self::Item> { + if self.offset >= self.buf.0.len() { + return None; + } + + let offset = self.offset; + self.offset += BLOCK_SIZE; + BufRef::try_from(&self.buf.0[offset..offset + BLOCK_SIZE]).ok() + } +} + +/// Iterator for mutable buffers of `BLOCK_SIZE`. 
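Because the `TryFrom` conversions above reject empty or unaligned slices, misuse surfaces at conversion time rather than inside I/O paths. A small illustrative sketch, not part of the patch, assuming `BLOCK_SIZE` and `Errno` are imported as in the surrounding files:

fn alignment_is_checked_at_conversion() {
    // A slice whose length is a multiple of the block size converts fine.
    let aligned = [0u8; 2 * BLOCK_SIZE];
    assert!(BufRef::try_from(aligned.as_slice()).is_ok());

    // An unaligned slice is rejected with a typed error.
    let unaligned = [0u8; BLOCK_SIZE + 1];
    match BufRef::try_from(unaligned.as_slice()) {
        Err(e) => assert_eq!(e.errno(), Errno::NotBlockSizeAligned),
        Ok(_) => unreachable!("unaligned slice must be rejected"),
    }
}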
+pub struct BufIterMut<'a> { + buf: BufMut<'a>, + offset: usize, +} + +#[gat] +impl LendingIterator for BufIterMut<'_> { + type Item<'next> = BufMut<'next>; + + fn next(&mut self) -> Option<Self::Item<'_>> { + if self.offset >= self.buf.0.len() { + return None; + } + + let offset = self.offset; + self.offset += BLOCK_SIZE; + BufMut::try_from(&mut self.buf.0[offset..offset + BLOCK_SIZE]).ok() + } +} + +#[cfg(test)] +mod tests { + use lending_iterator::LendingIterator; + + use super::{Buf, BufMut, BufRef, BLOCK_SIZE}; + + fn iterate_buf_ref<'a>(buf: BufRef<'a>) { + for block in buf.iter() { + assert_eq!(block.as_slice().len(), BLOCK_SIZE); + assert_eq!(block.nblocks(), 1); + } + } + + fn iterate_buf_mut<'a>(mut buf: BufMut<'a>) { + let mut iter_mut = buf.iter_mut(); + while let Some(mut block) = iter_mut.next() { + assert_eq!(block.as_mut_slice().len(), BLOCK_SIZE); + assert_eq!(block.nblocks(), 1); + } + } + + #[test] + fn buf() { + let mut buf = Buf::alloc(10).unwrap(); + assert_eq!(buf.nblocks(), 10); + assert_eq!(buf.as_slice().len(), 10 * BLOCK_SIZE); + iterate_buf_ref(buf.as_ref()); + iterate_buf_mut(buf.as_mut()); + + let mut buf = [0u8; BLOCK_SIZE]; + iterate_buf_ref(BufRef::try_from(buf.as_slice()).unwrap()); + iterate_buf_mut(BufMut::try_from(buf.as_mut_slice()).unwrap()); + } +} diff --git a/kernel/comps/mlsdisk/src/layers/0-bio/block_log.rs b/kernel/comps/mlsdisk/src/layers/0-bio/block_log.rs new file mode 100644 index 00000000..59cb5f36 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/0-bio/block_log.rs @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::sync::atomic::{AtomicUsize, Ordering}; + +use inherit_methods_macro::inherit_methods; + +use super::{Buf, BufMut, BufRef}; +use crate::{os::Mutex, prelude::*}; + +/// A log of data blocks that can support random reads and append-only +/// writes. +/// +/// # Thread safety +/// +/// `BlockLog` is a data structure of interior mutability. +/// It is ok to perform I/O on a `BlockLog` concurrently in multiple threads. +/// `BlockLog` promises the serialization of the append operations, i.e., +/// concurrent appends are carried out as if they are done one by one. +pub trait BlockLog: Sync + Send { + /// Read one or multiple blocks at a specified position. + fn read(&self, pos: BlockId, buf: BufMut) -> Result<()>; + + /// Append one or multiple blocks at the end, + /// returning the ID of the first newly-appended block. + fn append(&self, buf: BufRef) -> Result<BlockId>; + + /// Ensure that blocks are persisted to the disk. + fn flush(&self) -> Result<()>; + + /// Returns the number of blocks. + fn nblocks(&self) -> usize; +} + +macro_rules! impl_blocklog_for { + ($typ:ty,$from:tt) => { + #[inherit_methods(from = $from)] + impl<T: BlockLog> BlockLog for $typ { + fn read(&self, pos: BlockId, buf: BufMut) -> Result<()>; + fn append(&self, buf: BufRef) -> Result<BlockId>; + fn flush(&self) -> Result<()>; + fn nblocks(&self) -> usize; + } + }; +} + +impl_blocklog_for!(&T, "(**self)"); +impl_blocklog_for!(&mut T, "(**self)"); +impl_blocklog_for!(Box<T>, "(**self)"); +impl_blocklog_for!(Arc<T>, "(**self)"); + +/// An in-memory log that impls `BlockLog`. 
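Because `BlockLog` is also implemented for references, `Box`, and `Arc` via the `impl_blocklog_for!` macro above, helpers can be written once against the trait and reused over any backend, including the in-memory `MemLog` defined next. A sketch, not part of the patch; `copy_log` is a hypothetical helper:

// Copies every block of `src` onto the end of `dst`, one block at a time,
// then persists the destination.
fn copy_log<S: BlockLog, D: BlockLog>(src: &S, dst: &D) -> Result<()> {
    let mut buf = Buf::alloc(1)?;
    for pos in 0..src.nblocks() {
        src.read(pos, buf.as_mut())?;
        // The returned `BlockId` of the appended block is ignored here.
        dst.append(buf.as_ref())?;
    }
    dst.flush()
}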
+pub struct MemLog { + log: Mutex<Buf>, + append_pos: AtomicUsize, +} + +impl BlockLog for MemLog { + fn read(&self, pos: BlockId, mut buf: BufMut) -> Result<()> { + let nblocks = buf.nblocks(); + if pos + nblocks > self.nblocks() { + return_errno_with_msg!(InvalidArgs, "read range out of bound"); + } + let log = self.log.lock(); + let read_buf = &log.as_slice()[Self::offset(pos)..Self::offset(pos) + nblocks * BLOCK_SIZE]; + buf.as_mut_slice().copy_from_slice(read_buf); + Ok(()) + } + + fn append(&self, buf: BufRef) -> Result<BlockId> { + let nblocks = buf.nblocks(); + let mut log = self.log.lock(); + let pos = self.append_pos.load(Ordering::Acquire); + if pos + nblocks > log.nblocks() { + return_errno_with_msg!(InvalidArgs, "append range out of bound"); + } + let write_buf = + &mut log.as_mut_slice()[Self::offset(pos)..Self::offset(pos) + nblocks * BLOCK_SIZE]; + write_buf.copy_from_slice(buf.as_slice()); + self.append_pos.fetch_add(nblocks, Ordering::Release); + Ok(pos) + } + + fn flush(&self) -> Result<()> { + Ok(()) + } + + fn nblocks(&self) -> usize { + self.append_pos.load(Ordering::Acquire) + } +} + +impl MemLog { + pub fn create(num_blocks: usize) -> Result<Self> { + Ok(Self { + log: Mutex::new(Buf::alloc(num_blocks)?), + append_pos: AtomicUsize::new(0), + }) + } + + fn offset(pos: BlockId) -> usize { + pos * BLOCK_SIZE + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mem_log() -> Result<()> { + let total_blocks = 64; + let append_nblocks = 8; + let mem_log = MemLog::create(total_blocks)?; + assert_eq!(mem_log.nblocks(), 0); + + let mut append_buf = Buf::alloc(append_nblocks)?; + let content = 5_u8; + append_buf.as_mut_slice().fill(content); + let append_pos = mem_log.append(append_buf.as_ref())?; + assert_eq!(append_pos, 0); + assert_eq!(mem_log.nblocks(), append_nblocks); + + mem_log.flush()?; + let mut read_buf = Buf::alloc(1)?; + let read_pos = 7 as BlockId; + mem_log.read(read_pos, read_buf.as_mut())?; + assert_eq!( + read_buf.as_slice(), + &append_buf.as_slice()[read_pos * BLOCK_SIZE..(read_pos + 1) * BLOCK_SIZE] + ); + Ok(()) + } +} diff --git a/kernel/comps/mlsdisk/src/layers/0-bio/block_ring.rs b/kernel/comps/mlsdisk/src/layers/0-bio/block_ring.rs new file mode 100644 index 00000000..2d03c9ef --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/0-bio/block_ring.rs @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::{BlockLog, BlockSet, BufMut, BufRef}; +use crate::{os::Mutex, prelude::*}; + +/// `BlockRing<S>` emulates a blocks log (`BlockLog`) with infinite +/// storage capacity by using a block set (`S: BlockSet`) of finite storage +/// capacity. +/// +/// `BlockRing<S>` uses the entire storage space provided by the underlying +/// block set (`S`) for user data, maintaining no extra metadata. +/// Having no metadata, `BlockRing<S>` has to put three responsibilities to +/// its user: +/// +/// 1. Tracking the valid block range for read. +/// `BlockRing<S>` accepts reads at any position regardless of whether the +/// position refers to a valid block. It blindly redirects the read request to +/// the underlying block set after moduloing the target position by the +/// size of the block set. +/// +/// 2. Setting the cursor for appending new blocks. +/// `BlockRing<S>` won't remember the progress of writing blocks after reboot. +/// Thus, after a `BlockRing<S>` is instantiated, the user must specify the +/// append cursor (using the `set_cursor` method) before appending new blocks. +/// +/// 3. 
Avoiding overriding valid data blocks mistakenly. +/// As the underlying storage is used in a ring buffer style, old +/// blocks must be overridden to accommodate new blocks. The user must ensure +/// that the underlying storage is big enough to avoid overriding any useful +/// data. +pub struct BlockRing<S> { + storage: S, + // The cursor for appending new blocks + cursor: Mutex<Option<BlockId>>, +} + +impl<S: BlockSet> BlockRing<S> { + /// Creates a new instance. + pub fn new(storage: S) -> Self { + Self { + storage, + cursor: Mutex::new(None), + } + } + + /// Set the cursor for appending new blocks. + /// + /// # Panics + /// + /// Calling the `append` method without setting the append cursor first + /// via this method `set_cursor` causes panic. + pub fn set_cursor(&self, new_cursor: BlockId) { + *self.cursor.lock() = Some(new_cursor); + } + + // Return a reference to the underlying storage. + pub fn storage(&self) -> &S { + &self.storage + } +} + +impl<S: BlockSet> BlockLog for BlockRing<S> { + fn read(&self, pos: BlockId, buf: BufMut) -> Result<()> { + let pos = pos % self.storage.nblocks(); + self.storage.read(pos, buf) + } + + fn append(&self, buf: BufRef) -> Result<BlockId> { + let cursor = self + .cursor + .lock() + .expect("cursor must be set before appending new blocks"); + let pos = cursor % self.storage.nblocks(); + let new_cursor = cursor + buf.nblocks(); + self.storage.write(pos, buf)?; + self.set_cursor(new_cursor); + Ok(cursor) + } + + fn flush(&self) -> Result<()> { + self.storage.flush() + } + + fn nblocks(&self) -> usize { + self.cursor.lock().unwrap_or(0) + } +} + +#[cfg(test)] +mod tests { + use super::BlockRing; + use crate::layers::bio::{BlockLog, Buf, MemDisk}; + + #[test] + fn block_ring() { + let num_blocks = 16; + let disk = MemDisk::create(num_blocks).unwrap(); + let block_ring = BlockRing::new(disk); + block_ring.set_cursor(num_blocks); + assert_eq!(block_ring.nblocks(), num_blocks); + + let mut append_buf = Buf::alloc(1).unwrap(); + append_buf.as_mut_slice().fill(1); + let pos = block_ring.append(append_buf.as_ref()).unwrap(); + assert_eq!(pos, num_blocks); + assert_eq!(block_ring.nblocks(), num_blocks + 1); + + let mut read_buf = Buf::alloc(1).unwrap(); + block_ring + .read(pos % num_blocks, read_buf.as_mut()) + .unwrap(); + assert_eq!(read_buf.as_slice(), append_buf.as_slice()); + } +} diff --git a/kernel/comps/mlsdisk/src/layers/0-bio/block_set.rs b/kernel/comps/mlsdisk/src/layers/0-bio/block_set.rs new file mode 100644 index 00000000..ed0604d6 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/0-bio/block_set.rs @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::ops::Range; + +use inherit_methods_macro::inherit_methods; + +use super::{Buf, BufMut, BufRef}; +use crate::{error::Errno, os::Mutex, prelude::*}; + +/// A fixed set of data blocks that can support random reads and writes. +/// +/// # Thread safety +/// +/// `BlockSet` is a data structure of interior mutability. +/// It is ok to perform I/O on a `BlockSet` concurrently in multiple threads. +/// `BlockSet` promises the atomicity of reading and writing individual blocks. +pub trait BlockSet: Sync + Send { + /// Read one or multiple blocks at a specified position. + fn read(&self, pos: BlockId, buf: BufMut) -> Result<()>; + + /// Read a slice of bytes at a specified byte offset. 
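// Illustrative note (not part of the patch): the default implementations
// below map byte ranges onto whole blocks. With BLOCK_SIZE = 4096 (the value
// the tests in this patch assume), reading 8 bytes at byte offset 4094 gives
//   start_pos = 4094 / 4096               = 0
//   end_pos   = (4094 + 8).div_ceil(4096) = 2
// so blocks 0..2 are read into a temporary `Buf` and bytes 4094..4102 are
// copied out of it; unaligned access therefore costs whole-block I/O plus a copy.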
+ fn read_slice(&self, offset: usize, buf: &mut [u8]) -> Result<()> { + let start_pos = offset / BLOCK_SIZE; + let end_pos = (offset + buf.len()).div_ceil(BLOCK_SIZE); + if end_pos > self.nblocks() { + return_errno_with_msg!(Errno::InvalidArgs, "read_slice position is out of range"); + } + + let nblocks = end_pos - start_pos; + let mut blocks = Buf::alloc(nblocks)?; + self.read(start_pos, blocks.as_mut())?; + + let offset = offset % BLOCK_SIZE; + buf.copy_from_slice(&blocks.as_slice()[offset..offset + buf.len()]); + Ok(()) + } + + /// Write one or multiple blocks at a specified position. + fn write(&self, pos: BlockId, buf: BufRef) -> Result<()>; + + /// Write a slice of bytes at a specified byte offset. + fn write_slice(&self, offset: usize, buf: &[u8]) -> Result<()> { + let start_pos = offset / BLOCK_SIZE; + let end_pos = (offset + buf.len()).div_ceil(BLOCK_SIZE); + if end_pos > self.nblocks() { + return_errno_with_msg!(Errno::InvalidArgs, "write_slice position is out of range"); + } + let nblocks = end_pos - start_pos; + let mut blocks = Buf::alloc(nblocks)?; + + // Maybe we should read the first block partially. + let start_offset = offset % BLOCK_SIZE; + if start_offset != 0 { + let mut start_block = Buf::alloc(1)?; + self.read(start_pos, start_block.as_mut())?; + blocks.as_mut_slice()[..start_offset] + .copy_from_slice(&start_block.as_slice()[..start_offset]); + } + + // Copy the input buffer to the write buffer. + let end_offset = start_offset + buf.len(); + blocks.as_mut_slice()[start_offset..end_offset].copy_from_slice(buf); + + // Maybe we should read the last block partially. + if end_offset % BLOCK_SIZE != 0 { + let mut end_block = Buf::alloc(1)?; + self.read(end_pos, end_block.as_mut())?; + blocks.as_mut_slice()[end_offset..] + .copy_from_slice(&end_block.as_slice()[end_offset % BLOCK_SIZE..]); + } + + // Write blocks. + self.write(start_pos, blocks.as_ref())?; + Ok(()) + } + + /// Get a subset of the blocks in the block set. + fn subset(&self, range: Range<BlockId>) -> Result<Self> + where + Self: Sized; + + /// Ensure that blocks are persisted to the disk. + fn flush(&self) -> Result<()>; + + /// Returns the number of blocks. + fn nblocks(&self) -> usize; +} + +macro_rules! impl_blockset_for { + ($typ:ty,$from:tt,$subset_fn:expr) => { + #[inherit_methods(from = $from)] + impl<T: BlockSet> BlockSet for $typ { + fn read(&self, pos: BlockId, buf: BufMut) -> Result<()>; + fn read_slice(&self, offset: usize, buf: &mut [u8]) -> Result<()>; + fn write(&self, pos: BlockId, buf: BufRef) -> Result<()>; + fn write_slice(&self, offset: usize, buf: &[u8]) -> Result<()>; + fn flush(&self) -> Result<()>; + fn nblocks(&self) -> usize; + fn subset(&self, range: Range<BlockId>) -> Result<Self> { + let closure = $subset_fn; + closure(self, range) + } + } + }; +} + +impl_blockset_for!(&T, "(**self)", |_this, _range| { + return_errno_with_msg!(Errno::NotFound, "cannot return `Self` by `subset` of `&T`"); +}); + +impl_blockset_for!(&mut T, "(**self)", |_this, _range| { + return_errno_with_msg!( + Errno::NotFound, + "cannot return `Self` by `subset` of `&mut T`" + ); +}); + +impl_blockset_for!(Box<T>, "(**self)", |this: &T, range| { + this.subset(range).map(|v| Box::new(v)) +}); + +impl_blockset_for!(Arc<T>, "(**self)", |this: &Arc<T>, range| { + (**this).subset(range).map(|v| Arc::new(v)) +}); + +/// A disk that impl `BlockSet`. +/// +/// The `region` is the accessible subset. 
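+///
+/// Cloning a `MemDisk` (or taking a `subset` of it) shares the same underlying
+/// memory; only the accessible `region` differs. A sketch of the subset
+/// semantics (error handling elided):
+///
+/// ```ignore
+/// let disk = MemDisk::create(64)?;
+/// let subset = disk.subset(32..64)?;
+/// // Block 0 of `subset` aliases block 32 of `disk`.
+/// assert_eq!(subset.nblocks(), 32);
+/// ```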
+#[derive(Clone)] +pub struct MemDisk { + disk: Arc<Mutex<Buf>>, + region: Range<BlockId>, +} + +impl MemDisk { + /// Create a `MemDisk` with the number of blocks. + pub fn create(num_blocks: usize) -> Result<Self> { + let blocks = Buf::alloc(num_blocks)?; + Ok(Self { + disk: Arc::new(Mutex::new(blocks)), + region: Range { + start: 0, + end: num_blocks, + }, + }) + } +} + +impl BlockSet for MemDisk { + fn read(&self, pos: BlockId, mut buf: BufMut) -> Result<()> { + if pos + buf.nblocks() > self.region.end { + return_errno_with_msg!(Errno::InvalidArgs, "read position is out of range"); + } + let offset = (self.region.start + pos) * BLOCK_SIZE; + let buf_len = buf.as_slice().len(); + + let disk = self.disk.lock(); + buf.as_mut_slice() + .copy_from_slice(&disk.as_slice()[offset..offset + buf_len]); + Ok(()) + } + + fn write(&self, pos: BlockId, buf: BufRef) -> Result<()> { + if pos + buf.nblocks() > self.region.end { + return_errno_with_msg!(Errno::InvalidArgs, "write position is out of range"); + } + let offset = (self.region.start + pos) * BLOCK_SIZE; + let buf_len = buf.as_slice().len(); + + let mut disk = self.disk.lock(); + disk.as_mut_slice()[offset..offset + buf_len].copy_from_slice(buf.as_slice()); + Ok(()) + } + + fn subset(&self, range: Range<BlockId>) -> Result<Self> { + if self.region.start + range.end > self.region.end { + return_errno_with_msg!(Errno::InvalidArgs, "subset is out of range"); + } + + Ok(MemDisk { + disk: self.disk.clone(), + region: Range { + start: self.region.start + range.start, + end: self.region.start + range.end, + }, + }) + } + + fn flush(&self) -> Result<()> { + Ok(()) + } + + fn nblocks(&self) -> usize { + self.region.len() + } +} + +#[cfg(test)] +mod tests { + use core::ops::Range; + + use crate::layers::bio::{BlockSet, Buf, MemDisk}; + + #[test] + fn mem_disk() { + let num_blocks = 64; + let disk = MemDisk::create(num_blocks).unwrap(); + assert_eq!(disk.nblocks(), 64); + + let mut buf = Buf::alloc(1).unwrap(); + buf.as_mut_slice().fill(1); + disk.write(32, buf.as_ref()).unwrap(); + + let range = Range { start: 32, end: 64 }; + let subset = disk.subset(range).unwrap(); + assert_eq!(subset.nblocks(), 32); + + buf.as_mut_slice().fill(0); + subset.read(0, buf.as_mut()).unwrap(); + assert_eq!(buf.as_ref().as_slice(), [1u8; 4096]); + + subset.write_slice(4096 - 4, &[2u8; 8]).unwrap(); + let mut buf = [0u8; 16]; + subset.read_slice(4096 - 8, &mut buf).unwrap(); + assert_eq!(buf, [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0]); + } +} diff --git a/kernel/comps/mlsdisk/src/layers/0-bio/mod.rs b/kernel/comps/mlsdisk/src/layers/0-bio/mod.rs new file mode 100644 index 00000000..9c6dbece --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/0-bio/mod.rs @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! The layer of untrusted block I/O. 
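+//!
+//! This layer provides two core abstractions: `BlockSet` for random-access
+//! reads and writes over a fixed range of blocks, and `BlockLog` for
+//! append-only block I/O, together with in-memory implementations
+//! (`MemDisk`, `MemLog`) and the `BlockRing` adapter.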
+
+use static_assertions::assert_eq_size;
+
+mod block_buf;
+mod block_log;
+mod block_ring;
+mod block_set;
+
+pub use self::{
+    block_buf::{Buf, BufMut, BufRef},
+    block_log::{BlockLog, MemLog},
+    block_ring::BlockRing,
+    block_set::{BlockSet, MemDisk},
+};
+
+pub type BlockId = usize;
+pub const BLOCK_SIZE: usize = 0x1000;
+pub const BID_SIZE: usize = core::mem::size_of::<BlockId>();
+
+// This definition of BlockId assumes the target architecture is 64-bit
+assert_eq_size!(usize, u64);
diff --git a/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_blob.rs b/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_blob.rs
new file mode 100644
index 00000000..c35bbe9e
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_blob.rs
@@ -0,0 +1,358 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use ostd_pod::Pod;
+
+use super::{Iv, Key, Mac, VersionId};
+use crate::{
+    layers::bio::{BlockSet, Buf, BLOCK_SIZE},
+    os::{Aead, Mutex},
+    prelude::*,
+};
+
+/// A cryptographically-protected blob of user data.
+///
+/// `CryptoBlob<B>` allows a variable length of user data to be securely
+/// written to and read from a fixed, pre-allocated block set
+/// (represented by `B: BlockSet`) on disk. Obviously, the length of user data
+/// must be smaller than that of the block set.
+///
+/// # On-disk format
+///
+/// The on-disk format of `CryptoBlob` is shown below.
+///
+/// ```
+/// ┌─────────┬─────────┬─────────┬──────────────────────────────┐
+/// │VersionId│   MAC   │ Length  │       Encrypted Payload      │
+/// │  (8B)   │  (16B)  │  (8B)   │        (Length bytes)        │
+/// └─────────┴─────────┴─────────┴──────────────────────────────┘
+/// ```
+///
+/// The version ID increments by one each time the `CryptoBlob` is updated.
+/// The MAC protects the integrity of the length and the encrypted payload.
+///
+/// # Security
+///
+/// To ensure the confidentiality and integrity of user data, `CryptoBlob`
+/// takes several measures:
+///
+/// 1. Each instance of `CryptoBlob` is associated with a randomly-generated,
+/// unique encryption key.
+/// 2. Each instance of `CryptoBlob` maintains a version ID, which is
+/// automatically incremented by one upon each write.
+/// 3. The user data written to a `CryptoBlob` is protected with authenticated
+/// encryption before being persisted to the disk.
+/// The encryption takes the current version ID as the IV and generates a MAC
+/// as the output.
+/// 4. To read user data from a `CryptoBlob`, it first decrypts
+/// the untrusted on-disk data with the encryption key associated with this object
+/// and validates its integrity. Optionally, the user can check the version ID
+/// of the decrypted user data to see whether the version ID is up-to-date.
+///
+pub struct CryptoBlob<B> {
+    block_set: B,
+    key: Key,
+    header: Mutex<Option<Header>>,
+}
+
+#[repr(C)]
+#[derive(Copy, Clone, Pod)]
+struct Header {
+    version: VersionId,
+    mac: Mac,
+    payload_len: usize,
+}
+
+impl<B: BlockSet> CryptoBlob<B> {
+    /// The size of the header of a crypto blob in bytes.
+    pub const HEADER_NBYTES: usize = core::mem::size_of::<Header>();
+
+    /// Opens an existing `CryptoBlob`.
+    ///
+    /// The capacity of this `CryptoBlob` object is determined by the size
+    /// of `block_set: B`.
+    pub fn open(key: Key, block_set: B) -> Self {
+        Self {
+            block_set,
+            key,
+            header: Mutex::new(None),
+        }
+    }
+
+    /// Creates a new `CryptoBlob`.
+    ///
+    /// The encryption key of a `CryptoBlob` is generated randomly so that
+    /// no two `CryptoBlob` instances shall ever use the same key.
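+    ///
+    /// A usage sketch (error handling elided): the newly-created blob already
+    /// holds version ID 0, so it can be written to or read from right away.
+    ///
+    /// ```ignore
+    /// let disk = MemDisk::create(2)?;
+    /// let blob = CryptoBlob::create(disk, b"hello")?;
+    /// assert_eq!(blob.version_id(), Some(0));
+    /// ```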
+    pub fn create(block_set: B, init_data: &[u8]) -> Result<Self> {
+        let capacity = block_set.nblocks() * BLOCK_SIZE - Self::HEADER_NBYTES;
+        if init_data.len() > capacity {
+            return_errno_with_msg!(OutOfDisk, "init_data is too large");
+        }
+        let nblocks = (Self::HEADER_NBYTES + init_data.len()).div_ceil(BLOCK_SIZE);
+        let mut block_buf = Buf::alloc(nblocks)?;
+
+        // Encrypt init_data.
+        let aead = Aead::new();
+        let key = Key::random();
+        let version: VersionId = 0;
+        let mut iv = Iv::new_zeroed();
+        iv.as_bytes_mut()[..version.as_bytes().len()].copy_from_slice(version.as_bytes());
+        let output = &mut block_buf.as_mut_slice()
+            [Self::HEADER_NBYTES..Self::HEADER_NBYTES + init_data.len()];
+        let mac = aead.encrypt(init_data, &key, &iv, &[], output)?;
+
+        // Store header.
+        let header = Header {
+            version,
+            mac,
+            payload_len: init_data.len(),
+        };
+        block_buf.as_mut_slice()[..Self::HEADER_NBYTES].copy_from_slice(header.as_bytes());
+
+        // Write to `BlockSet`.
+        block_set.write(0, block_buf.as_ref())?;
+        Ok(Self {
+            block_set,
+            key,
+            header: Mutex::new(Some(header)),
+        })
+    }
+
+    /// Writes the buffer to the disk as the latest version of the content of
+    /// this `CryptoBlob`.
+    ///
+    /// The size of the buffer must not be greater than the capacity of this
+    /// `CryptoBlob`.
+    ///
+    /// Each successful write increments the version ID by one. If
+    /// there is no valid version ID, an `Error` will be returned.
+    /// A valid version ID can be obtained either by a successful call to
+    /// `read`, or by calling `recover_from` with another valid `CryptoBlob`.
+    ///
+    /// # Security
+    ///
+    /// This content is guaranteed to be confidential as long as the key is not
+    /// known to an attacker.
+    pub fn write(&mut self, buf: &[u8]) -> Result<VersionId> {
+        if buf.len() > self.capacity() {
+            return_errno_with_msg!(OutOfDisk, "write data is too large");
+        }
+        let nblocks = (Self::HEADER_NBYTES + buf.len()).div_ceil(BLOCK_SIZE);
+        let mut block_buf = Buf::alloc(nblocks)?;
+
+        // Encrypt payload.
+        let aead = Aead::new();
+        let version = match self.version_id() {
+            Some(version) => version + 1,
+            None => return_errno_with_msg!(NotFound, "write with no valid version ID"),
+        };
+        let mut iv = Iv::new_zeroed();
+        iv.as_bytes_mut()[..version.as_bytes().len()].copy_from_slice(version.as_bytes());
+        let output =
+            &mut block_buf.as_mut_slice()[Self::HEADER_NBYTES..Self::HEADER_NBYTES + buf.len()];
+        let mac = aead.encrypt(buf, &self.key, &iv, &[], output)?;
+
+        // Store header.
+        let header = Header {
+            version,
+            mac,
+            payload_len: buf.len(),
+        };
+        block_buf.as_mut_slice()[..Self::HEADER_NBYTES].copy_from_slice(header.as_bytes());
+
+        // Write to `BlockSet`.
+        self.block_set.write(0, block_buf.as_ref())?;
+        *self.header.lock() = Some(header);
+        Ok(version)
+    }
+
+    /// Reads the content of the `CryptoBlob` from the disk into the buffer.
+    ///
+    /// The given buffer must have a length that is no less than the size of
+    /// the plaintext content of this `CryptoBlob`.
+    ///
+    /// # Security
+    ///
+    /// This content, including its length, is guaranteed to be authentic.
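+    ///
+    /// Returns the number of payload bytes that were read. A usage sketch
+    /// (error handling elided):
+    ///
+    /// ```ignore
+    /// let mut buf = [0u8; 1024];
+    /// let payload_len = blob.read(&mut buf)?;
+    /// let payload = &buf[..payload_len];
+    /// ```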
+ pub fn read(&self, buf: &mut [u8]) -> Result<usize> { + let header = match *self.header.lock() { + Some(header) => header, + None => { + let mut header = Header::new_zeroed(); + self.block_set.read_slice(0, header.as_bytes_mut())?; + header + } + }; + if header.payload_len > self.capacity() { + return_errno_with_msg!(OutOfDisk, "payload_len is greater than the capacity"); + } + if header.payload_len > buf.len() { + return_errno_with_msg!(OutOfDisk, "read_buf is too small"); + } + let nblock = (Self::HEADER_NBYTES + header.payload_len).div_ceil(BLOCK_SIZE); + let mut block_buf = Buf::alloc(nblock)?; + self.block_set.read(0, block_buf.as_mut())?; + + // Decrypt payload. + let aead = Aead::new(); + let version = header.version; + let mut iv = Iv::new_zeroed(); + iv.as_bytes_mut()[..version.as_bytes().len()].copy_from_slice(version.as_bytes()); + let input = + &block_buf.as_slice()[Self::HEADER_NBYTES..Self::HEADER_NBYTES + header.payload_len]; + let output = &mut buf[..header.payload_len]; + aead.decrypt(input, &self.key, &iv, &[], &header.mac, output)?; + *self.header.lock() = Some(header); + Ok(header.payload_len) + } + + /// Returns the key associated with this `CryptoBlob`. + pub fn key(&self) -> &Key { + &self.key + } + + /// Returns the current version ID. + /// + /// # Security + /// + /// It is valid after a successful call to `create`, `read` or `write`. + /// User could also get a version ID from another valid `CryptoBlob`, + /// (usually a backup), through method `recover_from`. + pub fn version_id(&self) -> Option<VersionId> { + self.header.lock().map(|header| header.version) + } + + /// Recover from another `CryptoBlob`. + /// + /// If `CryptoBlob` doesn't have a valid version ID, e.g., payload decryption + /// failed when `read`, user could call this method to recover version ID and + /// payload from another `CryptoBlob` (usually a backup). + pub fn recover_from(&mut self, other: &CryptoBlob<B>) -> Result<()> { + if self.capacity() != other.capacity() { + return_errno_with_msg!(InvalidArgs, "capacity not aligned, recover failed"); + } + if self.header.lock().is_some() { + return_errno_with_msg!(InvalidArgs, "no need to recover"); + } + let nblocks = self.block_set.nblocks(); + // Read version ID and payload from another `CryptoBlob`. + let mut read_buf = Buf::alloc(nblocks)?; + let payload_len = other.read(read_buf.as_mut_slice())?; + let version = other.version_id().unwrap(); + + // Encrypt payload. + let aead = Aead::new(); + let mut iv = Iv::new_zeroed(); + iv.as_bytes_mut()[..version.as_bytes().len()].copy_from_slice(version.as_bytes()); + let input = &read_buf.as_slice()[..payload_len]; + let mut write_buf = Buf::alloc(nblocks)?; + let output = + &mut write_buf.as_mut_slice()[Self::HEADER_NBYTES..Self::HEADER_NBYTES + payload_len]; + let mac = aead.encrypt(input, self.key(), &iv, &[], output)?; + + // Store header. + let header = Header { + version, + mac, + payload_len, + }; + write_buf.as_mut_slice()[..Self::HEADER_NBYTES].copy_from_slice(header.as_bytes()); + + // Write to `BlockSet`. + self.block_set.write(0, write_buf.as_ref())?; + *self.header.lock() = Some(header); + Ok(()) + } + + /// Returns the current MAC of encrypted payload. + /// + /// # Security + /// + /// It is valid after a successful call to `create`, `read` or `write`. + pub fn current_mac(&self) -> Option<Mac> { + self.header.lock().map(|header| header.mac) + } + + /// Returns the capacity of this `CryptoBlob` in bytes. 
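+    ///
+    /// That is, the total size of the underlying block set minus the
+    /// `HEADER_NBYTES` reserved for the header.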
+ pub fn capacity(&self) -> usize { + self.block_set.nblocks() * BLOCK_SIZE - Self::HEADER_NBYTES + } + + /// Returns the number of blocks occupied by the underlying `BlockSet`. + pub fn nblocks(&self) -> usize { + self.block_set.nblocks() + } +} + +#[cfg(test)] +mod tests { + use super::CryptoBlob; + use crate::layers::bio::{BlockSet, MemDisk, BLOCK_SIZE}; + + #[test] + fn create() { + let disk = MemDisk::create(2).unwrap(); + let init_data = [1u8; BLOCK_SIZE]; + let blob = CryptoBlob::create(disk, &init_data).unwrap(); + + println!("blob key: {:?}", blob.key()); + assert_eq!(blob.version_id(), Some(0)); + assert_eq!(blob.nblocks(), 2); + assert_eq!( + blob.capacity(), + 2 * BLOCK_SIZE - CryptoBlob::<MemDisk>::HEADER_NBYTES + ); + } + + #[test] + fn open_and_read() { + let disk = MemDisk::create(4).unwrap(); + let key = { + let subset = disk.subset(0..2).unwrap(); + let init_data = [1u8; 1024]; + let blob = CryptoBlob::create(subset, &init_data).unwrap(); + blob.key + }; + + let subset = disk.subset(0..2).unwrap(); + let blob = CryptoBlob::open(key, subset); + assert_eq!(blob.version_id(), None); + assert_eq!(blob.nblocks(), 2); + let mut buf = [0u8; BLOCK_SIZE]; + let payload_len = blob.read(&mut buf).unwrap(); + assert_eq!(buf[..payload_len], [1u8; 1024]); + } + + #[test] + fn write() { + let disk = MemDisk::create(2).unwrap(); + let init_data = [0u8; BLOCK_SIZE]; + let mut blob = CryptoBlob::create(disk, &init_data).unwrap(); + + let write_buf = [1u8; 1024]; + blob.write(&write_buf).unwrap(); + let mut read_buf = [0u8; 1024]; + blob.read(&mut read_buf).unwrap(); + assert_eq!(read_buf, [1u8; 1024]); + assert_eq!(blob.version_id(), Some(1)); + } + + #[test] + fn recover_from() { + let disk = MemDisk::create(2).unwrap(); + let init_data = [1u8; 1024]; + let subset0 = disk.subset(0..1).unwrap(); + let mut blob0 = CryptoBlob::create(subset0, &init_data).unwrap(); + assert_eq!(blob0.version_id(), Some(0)); + blob0.write(&init_data).unwrap(); + assert_eq!(blob0.version_id(), Some(1)); + + let subset1 = disk.subset(1..2).unwrap(); + let mut blob1 = CryptoBlob::open(blob0.key, subset1); + assert_eq!(blob1.version_id(), None); + blob1.recover_from(&blob0).unwrap(); + let mut read_buf = [0u8; BLOCK_SIZE]; + let payload_len = blob1.read(&mut read_buf).unwrap(); + assert_eq!(read_buf[..payload_len], [1u8; 1024]); + assert_eq!(blob1.version_id(), Some(1)); + } +} diff --git a/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_chain.rs b/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_chain.rs new file mode 100644 index 00000000..eb3feb4d --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_chain.rs @@ -0,0 +1,401 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::ops::Range; + +use lending_iterator::prelude::*; +use ostd_pod::Pod; + +use super::{Iv, Key, Mac}; +use crate::{ + layers::bio::{BlockId, BlockLog, Buf, BLOCK_SIZE}, + os::Aead, + prelude::*, +}; + +/// A cryptographically-protected chain of blocks. +/// +/// `CryptoChain<L>` allows writing and reading a sequence of +/// consecutive blocks securely to and from an untrusted storage of data log +/// `L: BlockLog`. +/// The target use case of `CryptoChain` is to implement secure journals, +/// where old data are scanned and new data are appended. +/// +/// # On-disk format +/// +/// The on-disk format of each block is shown below. 
+///
+/// ```text
+/// ┌─────────────────────┬───────┬──────────┬──────────┬──────────┬─────────┐
+/// │  Encrypted payload  │  Gap  │  Length  │  PreMac  │ CurrMac  │   IV    │
+/// │(Length <= 4KB - 48B)│       │   (4B)   │  (16B)   │  (16B)   │  (12B)  │
+/// └─────────────────────┴───────┴──────────┴──────────┴──────────┴─────────┘
+///
+/// ◄─────────────────────────── Block size (4KB) ──────────────────────────►
+/// ```
+///
+/// Each block begins with the encrypted user payload. The size of the payload
+/// must be smaller than the block size, as each block ends with a footer
+/// (in plaintext).
+/// The footer consists of four parts: the length of the payload (in bytes),
+/// the MAC of the previous block, the MAC of the current block, and the IV used
+/// for encrypting the current block.
+/// The MAC of a block protects the encrypted payload, its length, and the MAC
+/// of the previous block.
+///
+/// # Security
+///
+/// Each `CryptoChain` is assigned a randomly-generated encryption key.
+/// Each block is encrypted using this key and a randomly-generated IV.
+/// This setup ensures the confidentiality of the payload and guarantees that
+/// even identical payloads result in different ciphertexts.
+///
+/// `CryptoChain` is called a "chain" of blocks because each block
+/// not only stores its own MAC, but also the MAC of its previous block.
+/// This effectively forms a "chain" (much like a blockchain),
+/// ensuring the ordering and consecutiveness of the sequence of blocks.
+///
+/// Due to this chain structure, the integrity of a `CryptoChain` can be ensured
+/// by verifying the MAC of the last block. Once the integrity of the last block
+/// is verified, the integrity of all previous blocks can also be verified.
+pub struct CryptoChain<L> {
+    block_log: L,
+    key: Key,
+    block_range: Range<BlockId>,
+    block_macs: Vec<Mac>,
+}
+
+#[repr(C)]
+#[derive(Copy, Clone, Pod)]
+struct Footer {
+    len: u32,
+    pre_mac: Mac,
+    this_mac: Mac,
+    this_iv: Iv,
+}
+
+impl<L: BlockLog> CryptoChain<L> {
+    /// The available payload size in each chained block is smaller than
+    /// the block size.
+    pub const AVAIL_BLOCK_SIZE: usize = BLOCK_SIZE - core::mem::size_of::<Footer>();
+
+    /// Constructs a new `CryptoChain` using `block_log: L` as the storage.
+    pub fn new(block_log: L) -> Self {
+        Self {
+            block_log,
+            block_range: 0..0,
+            key: Key::random(),
+            block_macs: Vec::new(),
+        }
+    }
+
+    /// Recovers an existing `CryptoChain` backed by `block_log: L`,
+    /// starting from its `from` block.
+    pub fn recover(key: Key, block_log: L, from: BlockId) -> Recovery<L> {
+        Recovery::new(block_log, key, from)
+    }
+
+    /// Reads a block at a specified position.
+    ///
+    /// The length of the given buffer should not be smaller than the
+    /// `payload_len` stored in the `Footer`.
+    ///
+    /// # Security
+    ///
+    /// The authenticity of the block is guaranteed.
+    pub fn read(&self, pos: BlockId, buf: &mut [u8]) -> Result<usize> {
+        if !self.block_range().contains(&pos) {
+            return_errno_with_msg!(NotFound, "read position is out of range");
+        }
+
+        // Read block and get footer.
+        let mut block_buf = Buf::alloc(1)?;
+        self.block_log.read(pos, block_buf.as_mut())?;
+        let footer: Footer = Pod::from_bytes(&block_buf.as_slice()[Self::AVAIL_BLOCK_SIZE..]);
+
+        let payload_len = footer.len as usize;
+        if payload_len > Self::AVAIL_BLOCK_SIZE || payload_len > buf.len() {
+            return_errno_with_msg!(OutOfDisk, "wrong payload_len or the read_buf is too small");
+        }
+
+        // Check the footer MAC to ensure the ordering and consecutiveness of blocks.
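+        // The expected MAC comes from `self.block_macs`, which is filled in by
+        // `append` or by a successful recovery; a mismatch means the on-disk
+        // block at `pos` is not the block this chain appended there.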
+ let this_mac = self.block_macs.get(pos - self.block_range.start).unwrap(); + if footer.this_mac.as_bytes() != this_mac.as_bytes() { + return_errno_with_msg!(NotFound, "check footer MAC failed"); + } + + // Decrypt payload. + let aead = Aead::new(); + aead.decrypt( + &block_buf.as_slice()[..payload_len], + self.key(), + &footer.this_iv, + &footer.pre_mac, + &footer.this_mac, + &mut buf[..payload_len], + )?; + Ok(payload_len) + } + + /// Append a block at the end. + /// + /// The length of the given buffer must not be larger than `AVAIL_BLOCK_SIZE`. + /// + /// # Security + /// + /// The confidentiality of the block is guaranteed. + pub fn append(&mut self, buf: &[u8]) -> Result<()> { + if buf.len() > Self::AVAIL_BLOCK_SIZE { + return_errno_with_msg!(OutOfDisk, "append data is too large"); + } + let mut block_buf = Buf::alloc(1)?; + + // Encrypt payload. + let aead = Aead::new(); + let this_iv = Iv::random(); + let pre_mac = self.block_macs.last().copied().unwrap_or_default(); + let output = &mut block_buf.as_mut_slice()[..buf.len()]; + let this_mac = aead.encrypt(buf, self.key(), &this_iv, &pre_mac, output)?; + + // Store footer. + let footer = Footer { + len: buf.len() as _, + pre_mac, + this_mac, + this_iv, + }; + let buf = &mut block_buf.as_mut_slice()[Self::AVAIL_BLOCK_SIZE..]; + buf.copy_from_slice(footer.as_bytes()); + + self.block_log.append(block_buf.as_ref())?; + self.block_range.end += 1; + self.block_macs.push(this_mac); + Ok(()) + } + + /// Ensures the persistence of data. + pub fn flush(&self) -> Result<()> { + self.block_log.flush() + } + + /// Trim the blocks before a specified position (exclusive). + /// + /// The purpose of this method is to free some memory used for keeping the + /// MACs of accessible blocks. After trimming, the range of accessible + /// blocks is shrunk accordingly. + pub fn trim(&mut self, before_block: BlockId) { + // We must ensure the invariance that there is at least one valid block + // after trimming. + debug_assert!(before_block < self.block_range.end); + + if before_block <= self.block_range.start { + return; + } + + let num_blocks_trimmed = before_block - self.block_range.start; + self.block_range.start = before_block; + self.block_macs.drain(..num_blocks_trimmed); + } + + /// Returns the range of blocks that are accessible through the `CryptoChain`. + pub fn block_range(&self) -> &Range<BlockId> { + &self.block_range + } + + /// Returns the underlying block log. + pub fn inner_log(&self) -> &L { + &self.block_log + } + + /// Returns the encryption key of the `CryptoChain`. + pub fn key(&self) -> &Key { + &self.key + } +} + +/// `Recovery<L>` represents an instance `CryptoChain<L>` being recovered. +/// +/// An object `Recovery<L>` attempts to recover as many valid blocks of +/// a `CryptoChain` as possible. A block is valid if and only if its real MAC +/// is equal to the MAC value recorded in its successor. +/// +/// For the last block, which does not have a successor block, the user +/// can obtain its MAC from `Recovery<L>` and verify the MAC by comparing it +/// with an expected value from another trusted source. +pub struct Recovery<L> { + block_log: L, + key: Key, + block_range: Range<BlockId>, + block_macs: Vec<Mac>, + read_buf: Buf, + payload: Buf, +} + +impl<L: BlockLog> Recovery<L> { + /// Construct a new `Recovery` from the `first_block` of + /// `block_log: L`, using a cryptographic `key`. 
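+    ///
+    /// A typical recovery loop (sketch; valid blocks are yielded one by one
+    /// until the first corrupted or inconsistent block is encountered):
+    ///
+    /// ```ignore
+    /// let mut recovery = CryptoChain::recover(key, block_log, from);
+    /// while let Some(payload) = recovery.next() {
+    ///     // Consume `payload`, a `&[u8]` holding one decrypted block payload.
+    /// }
+    /// let chain = recovery.open();
+    /// ```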
+    pub fn new(block_log: L, key: Key, first_block: BlockId) -> Self {
+        Self {
+            block_log,
+            key,
+            block_range: first_block..first_block,
+            block_macs: Vec::new(),
+            read_buf: Buf::alloc(1).unwrap(),
+            payload: Buf::alloc(1).unwrap(),
+        }
+    }
+
+    /// Returns the number of valid blocks.
+    ///
+    /// Each successful call to `next` increments the number of valid blocks.
+    pub fn num_blocks(&self) -> usize {
+        self.block_range.len()
+    }
+
+    /// Returns the range of valid blocks.
+    ///
+    /// Each successful call to `next` increments the upper bound by one.
+    pub fn block_range(&self) -> &Range<BlockId> {
+        &self.block_range
+    }
+
+    /// Returns the MACs of valid blocks.
+    ///
+    /// Each successful call to `next` pushes the MAC of the new valid block.
+    pub fn block_macs(&self) -> &[Mac] {
+        &self.block_macs
+    }
+
+    /// Opens a `CryptoChain<L>` from the recovery object.
+    ///
+    /// The user should call `next` to retrieve as many valid blocks as
+    /// possible before opening the chain.
+    pub fn open(self) -> CryptoChain<L> {
+        CryptoChain {
+            block_log: self.block_log,
+            key: self.key,
+            block_range: self.block_range,
+            block_macs: self.block_macs,
+        }
+    }
+}
+
+#[gat]
+impl<L: BlockLog> LendingIterator for Recovery<L> {
+    type Item<'a> = &'a [u8];
+
+    fn next(&mut self) -> Option<Self::Item<'_>> {
+        let next_block_id = self.block_range.end;
+        self.block_log
+            .read(next_block_id, self.read_buf.as_mut())
+            .ok()?;
+
+        // Deserialize footer.
+        let footer: Footer =
+            Pod::from_bytes(&self.read_buf.as_slice()[CryptoChain::<L>::AVAIL_BLOCK_SIZE..]);
+        let payload_len = footer.len as usize;
+        if payload_len > CryptoChain::<L>::AVAIL_BLOCK_SIZE {
+            return None;
+        }
+
+        // Decrypt payload.
+        let aead = Aead::new();
+        aead.decrypt(
+            &self.read_buf.as_slice()[..payload_len],
+            &self.key,
+            &footer.this_iv,
+            &footer.pre_mac,
+            &footer.this_mac,
+            &mut self.payload.as_mut_slice()[..payload_len],
+        )
+        .ok()?;
+
+        // Crypto blocks are chained: each block stores not only
+        // its own MAC, but also the MAC of its previous block.
+        // So we need to check whether the two MAC values are the same.
+        // There is one exception: the `pre_mac` of the first block
+        // is NOT checked.
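+        // In other words: once at least one block has been recovered, its MAC
+        // must equal the `pre_mac` recorded in the block examined now;
+        // otherwise the chain is broken and recovery stops here.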
+ if self + .block_macs() + .last() + .is_some_and(|mac| mac.as_bytes() != footer.pre_mac.as_bytes()) + { + return None; + } + + self.block_range.end += 1; + self.block_macs.push(footer.this_mac); + Some(&self.payload.as_slice()[..payload_len]) + } +} + +#[cfg(test)] +mod tests { + use lending_iterator::LendingIterator; + + use super::CryptoChain; + use crate::layers::bio::{BlockLog, BlockRing, BlockSet, MemDisk}; + + #[test] + fn new() { + let disk = MemDisk::create(16).unwrap(); + let block_ring = BlockRing::new(disk); + block_ring.set_cursor(0); + let chain = CryptoChain::new(block_ring); + + assert_eq!(chain.block_log.nblocks(), 0); + assert_eq!(chain.block_range.start, 0); + assert_eq!(chain.block_range.end, 0); + assert_eq!(chain.block_macs.len(), 0); + } + + #[test] + fn append_trim_and_read() { + let disk = MemDisk::create(16).unwrap(); + let block_ring = BlockRing::new(disk); + block_ring.set_cursor(0); + let mut chain = CryptoChain::new(block_ring); + + let data = [1u8; 1024]; + chain.append(&data[..256]).unwrap(); + chain.append(&data[..512]).unwrap(); + assert_eq!(chain.block_range.end, 2); + assert_eq!(chain.block_macs.len(), 2); + + chain.trim(1); + + assert_eq!(chain.block_range.start, 1); + assert_eq!(chain.block_range.end, 2); + assert_eq!(chain.block_macs.len(), 1); + + let mut buf = [0u8; 1024]; + let len = chain.read(1, &mut buf).unwrap(); + assert_eq!(len, 512); + assert_eq!(buf[..512], [1u8; 512]); + } + + #[test] + fn recover() { + let disk = MemDisk::create(16).unwrap(); + let key = { + let sub_disk = disk.subset(0..8).unwrap(); + let block_ring = BlockRing::new(sub_disk); + block_ring.set_cursor(0); + let data = [1u8; 1024]; + let mut chain = CryptoChain::new(block_ring); + for _ in 0..4 { + chain.append(&data).unwrap(); + } + chain.flush().unwrap(); + chain.key + }; + + let sub_disk = disk.subset(0..8).unwrap(); + let block_ring = BlockRing::new(sub_disk); + let mut recover = CryptoChain::recover(key, block_ring, 2); + while let Some(payload) = recover.next() { + assert_eq!(payload.len(), 1024); + } + let chain = recover.open(); + assert_eq!(chain.block_range(), &(2..4)); + assert_eq!(chain.block_macs.len(), 2); + } +} diff --git a/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_log.rs b/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_log.rs new file mode 100644 index 00000000..b2f46930 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/1-crypto/crypto_log.rs @@ -0,0 +1,1192 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::vec; +use core::{any::Any, mem::size_of}; + +use ostd_pod::Pod; +use serde::{Deserialize, Serialize}; +use static_assertions::const_assert; + +use super::{Iv, Key, Mac}; +use crate::{ + layers::bio::{BlockId, BlockLog, Buf, BufMut, BufRef, BLOCK_SIZE}, + os::{Aead, HashMap, Mutex, RwLock}, + prelude::*, +}; + +/// A cryptographically-protected log of user data blocks. +/// +/// `CryptoLog<L>`, which is backed by an untrusted block log (`L`), +/// serves as a secure log file that supports random reads and append-only +/// writes of data blocks. `CryptoLog<L>` encrypts the data blocks and +/// protects them with a Merkle Hash Tree (MHT), which itself is also encrypted. +/// +/// # Security +/// +/// Each instance of `CryptoLog<L>` is assigned a randomly-generated root key +/// upon its creation. The root key is used to encrypt the root MHT block only. +/// Each new version of the root MHT block is encrypted with the same key, but +/// different random IVs. This arrangement ensures the confidentiality of +/// the root block. 
+/// +/// After flushing a `CryptoLog<L>`, a new root MHT (as well as other MHT nodes) +/// shall be appended to the backend block log (`L`). +/// The metadata of the root MHT, including its position, encryption +/// key, IV, and MAC, must be kept by the user of `CryptoLog<L>` so that +/// he or she can use the metadata to re-open the `CryptoLog`. +/// The information contained in the metadata is sufficient to verify the +/// integrity and freshness of the root MHT node, and thus the whole `CryptoLog`. +/// +/// Other MHT nodes as well as data nodes are encrypted with randomly-generated, +/// unique keys. Their metadata, including its position, encryption key, IV, and +/// MAC, are kept securely in their parent MHT nodes, which are also encrypted. +/// Thus, the confidentiality and integrity of non-root nodes are protected. +/// +/// # Performance +/// +/// Thanks to its append-only nature, `CryptoLog<L>` avoids MHT's high +/// performance overheads under the workload of random writes +/// due to "cascades of updates". +/// +/// Behind the scene, `CryptoLog<L>` keeps a cache for nodes so that frequently +/// or lately accessed nodes can be found in the cache, avoiding the I/O +/// and decryption cost incurred when re-reading these nodes. +/// The cache is also used for buffering new data so that multiple writes to +/// individual nodes can be merged into a large write to the underlying block log. +/// Therefore, `CryptoLog<L>` is efficient for both reads and writes. +/// +/// # Disk space +/// +/// One consequence of using an append-only block log (`L`) as the backend is +/// that `CryptoLog<L>` cannot do in-place updates to existing MHT nodes. +/// This means the new version of MHT nodes are appended to the underlying block +/// log and the invalid blocks occupied by old versions are not reclaimed. +/// +/// But lucky for us, this block reclamation problem is not an issue in practice. +/// This is because a `CryptoLog<L>` is created for one of the following two +/// use cases. +/// +/// 1. Write-once-then-read-many. In this use case, all the content of a +/// `CryptoLog` is written in a single run. +/// Writing in a single run won't trigger any updates to MHT nodes and thus +/// no waste of disk space. +/// After the writing is done, the `CryptoLog` becomes read-only. +/// +/// 2. Write-many-then-read-once. In this use case, the content of a +/// `CryptoLog` may be written in many runs. But the number of `CryptoLog` +/// under such workloads is limited and their lengths are also limited. +/// So the disk space wasted by such `CryptoLog` is bounded. +/// And after such `CryptoLog`s are done writing, they will be read once and +/// then discarded. +pub struct CryptoLog<L> { + mht: RwLock<Mht<L>>, +} + +type Lbid = BlockId; // Logical block position, in terms of user +type Pbid = BlockId; // Physical block position, in terms of underlying log +type Height = u8; // The height of the MHT + +/// A Merkle-Hash Tree (MHT). +struct Mht<L> { + root: Option<(RootMhtMeta, Arc<MhtNode>)>, + root_key: Key, + data_buf: AppendDataBuf<L>, + storage: Arc<MhtStorage<L>>, +} + +/// Storage medium for MHT, including both in-memory and persistent. +struct MhtStorage<L> { + block_log: L, + node_cache: Arc<dyn NodeCache>, + crypt_buf: Mutex<CryptBuf>, +} + +/// The metadata of the root MHT node of a `CryptoLog`. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RootMhtMeta { + pub pos: Pbid, + pub mac: Mac, + pub iv: Iv, +} + +/// The Merkle-Hash Tree (MHT) node (internal). 
+/// It contains a header for node metadata and a bunch of entries for managing children nodes. +#[repr(C)] +#[derive(Clone, Copy, Pod)] +struct MhtNode { + header: MhtNodeHeader, + entries: [MhtNodeEntry; MHT_NBRANCHES], +} +const_assert!(size_of::<MhtNode>() <= BLOCK_SIZE); + +/// The header contains metadata of the current MHT node. +#[repr(C)] +#[derive(Clone, Copy, Debug, Pod)] +struct MhtNodeHeader { + // The height of the MHT whose root is this node + height: Height, + // The total number of valid data nodes covered by this node + num_data_nodes: u32, + // The number of valid entries (children) within this node + num_valid_entries: u16, +} + +/// The entry of the MHT node, which contains the +/// metadata of the child MHT/data node. +#[repr(C)] +#[derive(Clone, Copy, Debug, Pod)] +struct MhtNodeEntry { + pos: Pbid, + key: Key, + mac: Mac, +} + +// Number of branches of one MHT node. (102 for now) +const MHT_NBRANCHES: usize = (BLOCK_SIZE - size_of::<MhtNodeHeader>()) / size_of::<MhtNodeEntry>(); + +/// The data node (leaf). It contains a block of data. +#[repr(C)] +#[derive(Clone, Copy, Pod)] +struct DataNode([u8; BLOCK_SIZE]); + +/// Builder for MHT. +struct TreeBuilder<'a, L> { + previous_build: Option<PreviousBuild<'a, L>>, + storage: &'a MhtStorage<L>, +} + +/// Builder for one specific level of MHT. +struct LevelBuilder { + level_entries: Vec<MhtNodeEntry>, + total_data_nodes: usize, + height: Height, + previous_incomplete_node: Option<Arc<MhtNode>>, +} + +/// It encloses necessary information of the previous build of MHT. +struct PreviousBuild<'a, L> { + root: Arc<MhtNode>, + height: Height, + // Each level at most have one incomplete node at end + internal_incomplete_nodes: HashMap<Height, Arc<MhtNode>>, + storage: &'a MhtStorage<L>, +} + +/// The node cache used by `CryptoLog`. User-defined node cache +/// can achieve TX-awareness. +pub trait NodeCache: Send + Sync { + /// Gets an owned value from cache corresponding to the position. + fn get(&self, pos: Pbid) -> Option<Arc<dyn Any + Send + Sync>>; + + /// Puts a position-value pair into cache. If the value of that position + /// already exists, updates it and returns the old value. Otherwise, `None` is returned. + fn put( + &self, + pos: Pbid, + value: Arc<dyn Any + Send + Sync>, + ) -> Option<Arc<dyn Any + Send + Sync>>; +} + +/// Context for a search request. +struct SearchCtx<'a> { + pub pos: Lbid, + pub data_buf: BufMut<'a>, + pub offset: usize, + pub num: usize, + pub is_completed: bool, +} + +/// Prepares buffer for node cryption. +struct CryptBuf { + pub plain: Buf, + pub cipher: Buf, +} + +impl<L: BlockLog> CryptoLog<L> { + /// Creates a new `CryptoLog`. + /// + /// A newly-created instance won't occupy any space on the `block_log` + /// until the first flush, which triggers writing the root MHT node. + pub fn new(block_log: L, root_key: Key, node_cache: Arc<dyn NodeCache>) -> Self { + Self { + mht: RwLock::new(Mht::new(block_log, root_key, node_cache)), + } + } + + /// Opens an existing `CryptoLog` backed by a `block_log`. + /// + /// The given key and the metadata of the root MHT are sufficient to + /// load and verify the root node of the `CryptoLog`. + pub fn open( + block_log: L, + root_key: Key, + root_meta: RootMhtMeta, + node_cache: Arc<dyn NodeCache>, + ) -> Result<Self> { + Ok(Self { + mht: RwLock::new(Mht::open(block_log, root_key, root_meta, node_cache)?), + }) + } + + /// Gets the root key. + pub fn root_key(&self) -> Key { + self.mht.read().root_key + } + + /// Gets the metadata of the root MHT node. 
+ /// + /// Returns `None` if there hasn't been any appends or flush. + pub fn root_meta(&self) -> Option<RootMhtMeta> { + self.mht.read().root_meta() + } + + fn root_node(&self) -> Option<Arc<MhtNode>> { + self.mht.read().root_node().cloned() + } + + /// Gets the number of data nodes (blocks). + pub fn nblocks(&self) -> usize { + self.mht.read().total_data_nodes() + } + + /// Reads one or multiple data blocks at a specified position. + pub fn read(&self, pos: Lbid, buf: BufMut) -> Result<()> { + let mut search_ctx = SearchCtx::new(pos, buf); + self.mht.read().search(&mut search_ctx)?; + + debug_assert!(search_ctx.is_completed); + Ok(()) + } + + /// Appends one or multiple data blocks at the end. + pub fn append(&self, buf: BufRef) -> Result<()> { + let data_nodes: Vec<Arc<DataNode>> = buf + .iter() + .map(|block_buf| { + let data_node = { + let mut node = DataNode::new_uninit(); + node.0.copy_from_slice(block_buf.as_slice()); + Arc::new(node) + }; + data_node + }) + .collect(); + + self.mht.write().append_data_nodes(data_nodes) + } + + /// Ensures that all new data are persisted. + /// + /// Each successful flush triggers writing a new version of the root MHT + /// node to the underlying block log. The metadata of the latest root MHT + /// can be obtained via the `root_meta` method. + pub fn flush(&self) -> Result<()> { + self.mht.write().flush() + } + + pub fn display_mht(&self) { + self.mht.read().display(); + } +} + +impl<L: BlockLog> Mht<L> { + // Buffer capacity for appended data nodes. + const APPEND_BUF_CAPACITY: usize = 2048; + + pub fn new(block_log: L, root_key: Key, node_cache: Arc<dyn NodeCache>) -> Self { + let storage = Arc::new(MhtStorage::new(block_log, node_cache)); + let start_pos = 0 as Lbid; + Self { + root: None, + root_key, + data_buf: AppendDataBuf::new(Self::APPEND_BUF_CAPACITY, start_pos, storage.clone()), + storage, + } + } + + pub fn open( + block_log: L, + root_key: Key, + root_meta: RootMhtMeta, + node_cache: Arc<dyn NodeCache>, + ) -> Result<Self> { + let storage = Arc::new(MhtStorage::new(block_log, node_cache)); + let root_node = storage.root_mht_node(&root_key, &root_meta)?; + let start_pos = root_node.num_data_nodes() as Lbid; + Ok(Self { + root: Some((root_meta, root_node)), + root_key, + data_buf: AppendDataBuf::new(Self::APPEND_BUF_CAPACITY, start_pos, storage.clone()), + storage, + }) + } + + pub fn root_meta(&self) -> Option<RootMhtMeta> { + self.root.as_ref().map(|(root_meta, _)| root_meta.clone()) + } + + fn root_node(&self) -> Option<&Arc<MhtNode>> { + self.root.as_ref().map(|(_, root_node)| root_node) + } + + pub fn total_data_nodes(&self) -> usize { + self.data_buf.num_append() + + self + .root + .as_ref() + .map_or(0, |(_, root_node)| root_node.num_data_nodes()) + } + + pub fn search(&self, search_ctx: &mut SearchCtx<'_>) -> Result<()> { + let root_node = self + .root_node() + .ok_or(Error::with_msg(NotFound, "root MHT node not found"))?; + + if search_ctx.pos + search_ctx.num > self.total_data_nodes() { + return_errno_with_msg!(InvalidArgs, "search out of MHT capacity"); + } + + // Search the append data buffer first + self.data_buf.search_data_nodes(search_ctx)?; + if search_ctx.is_completed { + return Ok(()); + } + + // Search the MHT if needed + self.search_hierarchy(vec![root_node.clone()], root_node.height(), search_ctx) + } + + fn search_hierarchy( + &self, + level_targets: Vec<Arc<MhtNode>>, + mut curr_height: Height, + search_ctx: &mut SearchCtx<'_>, + ) -> Result<()> { + debug_assert!( + !level_targets.is_empty() && curr_height == 
level_targets.first().unwrap().height() + ); + let num_data_nodes = search_ctx.num; + + // Calculate two essential values for searching the current level: + // how many nodes to skip and how many nodes we need + let (nodes_skipped, nodes_needed) = { + let pos = &mut search_ctx.pos; + let next_level_max_num_data_nodes = MhtNode::max_num_data_nodes(curr_height - 1); + let skipped = *pos / next_level_max_num_data_nodes; + *pos -= skipped * next_level_max_num_data_nodes; + let needed = align_up(num_data_nodes + *pos, next_level_max_num_data_nodes) + / next_level_max_num_data_nodes; + (skipped, needed) + }; + + let target_entries = level_targets + .iter() + .flat_map(|node| node.entries.iter()) + .skip(nodes_skipped) + .take(nodes_needed); + + // Search down to the leaves, ready to collect data nodes + if MhtNode::is_lowest_level(curr_height) { + debug_assert_eq!(num_data_nodes, nodes_needed); + for entry in target_entries { + self.storage + .read_data_node(entry, search_ctx.node_buf(search_ctx.offset))?; + search_ctx.offset += 1; + } + search_ctx.is_completed = true; + return Ok(()); + } + + // Prepare target MHT nodes for the lower level + let next_level_targets = { + let mut targets = Vec::with_capacity(nodes_needed); + for entry in target_entries { + let target_node = self.storage.read_mht_node( + entry.pos, + &entry.key, + &entry.mac, + &Iv::new_zeroed(), + )?; + targets.push(target_node); + } + targets + }; + + // Search the lower level + curr_height -= 1; + self.search_hierarchy(next_level_targets, curr_height, search_ctx) + } + + pub fn append_data_nodes(&mut self, data_nodes: Vec<Arc<DataNode>>) -> Result<()> { + self.data_buf.append_data_nodes(data_nodes)?; + if self.data_buf.is_full() { + let data_node_entries = self.data_buf.flush()?; + self.do_build(data_node_entries)?; + } + Ok(()) + } + + pub fn flush(&mut self) -> Result<()> { + let data_node_entries = self.data_buf.flush()?; + self.do_build(data_node_entries)?; + // FIXME: Should we sync the storage here? + // self.storage.flush()?; + Ok(()) + } + + fn do_build(&mut self, data_node_entries: Vec<MhtNodeEntry>) -> Result<()> { + let new_root_node = { + TreeBuilder::new(&self.storage) + .previous_built_root(self.root.as_ref().map(|(_, root_node)| root_node)) + .build(data_node_entries)? 
+ }; + let root_meta = self + .storage + .append_root_mht_node(&self.root_key, &new_root_node)?; + let _ = self.root.insert((root_meta, new_root_node)); + Ok(()) + } + + pub fn display(&self) { + info!("{:?}", MhtDisplayer(self)); + } +} + +impl<L: BlockLog> MhtStorage<L> { + pub fn new(block_log: L, node_cache: Arc<dyn NodeCache>) -> Self { + Self { + block_log, + node_cache, + crypt_buf: Mutex::new(CryptBuf::new()), + } + } + + pub fn flush(&self) -> Result<()> { + self.block_log.flush() + } + + pub fn root_mht_node(&self, root_key: &Key, root_meta: &RootMhtMeta) -> Result<Arc<MhtNode>> { + self.read_mht_node(root_meta.pos, root_key, &root_meta.mac, &root_meta.iv) + } + + pub fn append_root_mht_node(&self, root_key: &Key, node: &Arc<MhtNode>) -> Result<RootMhtMeta> { + let mut crypt_buf = self.crypt_buf.lock(); + let iv = Iv::random(); + let mac = Aead::new().encrypt( + node.as_bytes(), + root_key, + &iv, + &[], + crypt_buf.cipher.as_mut_slice(), + )?; + + let pos = self.block_log.append(crypt_buf.cipher.as_ref())?; + self.node_cache.put(pos, node.clone()); + Ok(RootMhtMeta { pos, mac, iv }) + } + + fn append_mht_nodes(&self, nodes: &[Arc<MhtNode>]) -> Result<Vec<MhtNodeEntry>> { + let num_append = nodes.len(); + let mut node_entries = Vec::with_capacity(num_append); + let mut cipher_buf = Buf::alloc(num_append)?; + let mut pos = self.block_log.nblocks() as BlockId; + let start_pos = pos; + for (i, node) in nodes.iter().enumerate() { + let plain = node.as_bytes(); + let cipher = &mut cipher_buf.as_mut_slice()[i * BLOCK_SIZE..(i + 1) * BLOCK_SIZE]; + let key = Key::random(); + let mac = Aead::new().encrypt(plain, &key, &Iv::new_zeroed(), &[], cipher)?; + + node_entries.push(MhtNodeEntry { pos, key, mac }); + self.node_cache.put(pos, node.clone()); + pos += 1; + } + + let append_pos = self.block_log.append(cipher_buf.as_ref())?; + debug_assert_eq!(start_pos, append_pos); + Ok(node_entries) + } + + fn append_data_nodes(&self, nodes: &[Arc<DataNode>]) -> Result<Vec<MhtNodeEntry>> { + let num_append = nodes.len(); + let mut node_entries = Vec::with_capacity(num_append); + if num_append == 0 { + return Ok(node_entries); + } + + let mut cipher_buf = Buf::alloc(num_append)?; + let mut pos = self.block_log.nblocks() as BlockId; + let start_pos = pos; + for (i, node) in nodes.iter().enumerate() { + let cipher = &mut cipher_buf.as_mut_slice()[i * BLOCK_SIZE..(i + 1) * BLOCK_SIZE]; + let key = Key::random(); + let mac = Aead::new().encrypt(&node.0, &key, &Iv::new_zeroed(), &[], cipher)?; + + node_entries.push(MhtNodeEntry { pos, key, mac }); + pos += 1; + } + + let append_pos = self.block_log.append(cipher_buf.as_ref())?; + debug_assert_eq!(start_pos, append_pos); + Ok(node_entries) + } + + fn read_mht_node(&self, pos: Pbid, key: &Key, mac: &Mac, iv: &Iv) -> Result<Arc<MhtNode>> { + if let Some(node) = self.node_cache.get(pos) { + return node.downcast::<MhtNode>().map_err(|_| { + Error::with_msg(InvalidArgs, "cache node downcasts to MHT node failed") + }); + } + + let mht_node = { + let mut crypt_buf = self.crypt_buf.lock(); + self.block_log.read(pos, crypt_buf.cipher.as_mut())?; + let mut node = MhtNode::new_zeroed(); + Aead::new().decrypt( + crypt_buf.cipher.as_slice(), + key, + iv, + &[], + mac, + node.as_bytes_mut(), + )?; + crypt_buf + .plain + .as_mut_slice() + .copy_from_slice(node.as_bytes()); + Arc::new(node) + }; + + self.node_cache.put(pos, mht_node.clone()); + Ok(mht_node) + } + + fn read_data_node(&self, entry: &MhtNodeEntry, node_buf: &mut [u8]) -> Result<()> { + 
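+        // Decrypts one data block into `node_buf`, using the per-node key and
+        // MAC recorded in the parent MHT entry. A zeroed IV is used here, as
+        // every appended node is encrypted under its own freshly random key.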
debug_assert_eq!(node_buf.len(), BLOCK_SIZE); + let mut crypt_buf = self.crypt_buf.lock(); + + self.block_log.read(entry.pos, crypt_buf.cipher.as_mut())?; + Aead::new().decrypt( + crypt_buf.cipher.as_slice(), + &entry.key, + &Iv::new_zeroed(), + &[], + &entry.mac, + node_buf, + ) + } +} + +impl MhtNode { + pub fn height(&self) -> Height { + self.header.height + } + + pub fn num_data_nodes(&self) -> usize { + self.header.num_data_nodes as _ + } + + pub fn num_valid_entries(&self) -> usize { + self.header.num_valid_entries as _ + } + + // Lowest level MHT node's children are data nodes + pub fn is_lowest_level(height: Height) -> bool { + height == 1 + } + + pub fn max_num_data_nodes(height: Height) -> usize { + // Also correct when height equals 0 + MHT_NBRANCHES.pow(height as _) + } + + // A complete node indicates that all children are valid and + // all covered with maximum number of data nodes + pub fn is_incomplete(&self) -> bool { + self.num_data_nodes() != Self::max_num_data_nodes(self.height()) + } + + pub fn num_complete_children(&self) -> usize { + if self.num_data_nodes() % MHT_NBRANCHES == 0 || Self::is_lowest_level(self.height()) { + self.num_valid_entries() + } else { + self.num_valid_entries() - 1 + } + } +} + +impl<'a, L: BlockLog> TreeBuilder<'a, L> { + pub fn new(storage: &'a MhtStorage<L>) -> Self { + Self { + previous_build: None, + storage, + } + } + + pub fn previous_built_root(mut self, previous_built_root: Option<&Arc<MhtNode>>) -> Self { + if previous_built_root.is_none() { + return self; + } + self.previous_build = Some(PreviousBuild::new( + previous_built_root.unwrap(), + self.storage, + )); + self + } + + pub fn build(&self, data_node_entries: Vec<MhtNodeEntry>) -> Result<Arc<MhtNode>> { + let total_data_nodes = data_node_entries.len() + + self + .previous_build + .as_ref() + .map_or(0, |pre| pre.num_data_nodes()); + + self.build_hierarchy( + data_node_entries, + total_data_nodes, + 1 as Height, + self.calc_target_height(total_data_nodes), + ) + } + + fn build_hierarchy( + &self, + level_entries: Vec<MhtNodeEntry>, + total_data_nodes: usize, + mut curr_height: Height, + target_height: Height, + ) -> Result<Arc<MhtNode>> { + // Build the MHT nodes of current level + let mut new_mht_nodes = { + // Previous built incomplete node of same level should participate in the building + let previous_incomplete_node = self + .previous_build + .as_ref() + .and_then(|pre| pre.target_node(curr_height)); + + LevelBuilder::new(level_entries, total_data_nodes, curr_height) + .previous_incomplete_node(previous_incomplete_node) + .build() + }; + + if curr_height == target_height { + // The root MHT node has been built + debug_assert_eq!(new_mht_nodes.len(), 1); + return Ok(new_mht_nodes.pop().unwrap()); + } + + // Prepare MHT node entries for the higher level + let next_level_entries = self.storage.append_mht_nodes(&new_mht_nodes)?; + // Build the higher level + curr_height += 1; + self.build_hierarchy( + next_level_entries, + total_data_nodes, + curr_height, + target_height, + ) + } + + fn calc_target_height(&self, num_data_node_entries: usize) -> Height { + let target_height = num_data_node_entries.ilog(MHT_NBRANCHES); + if MHT_NBRANCHES.pow(target_height) < num_data_node_entries || target_height == 0 { + (target_height + 1) as Height + } else { + target_height as Height + } + } +} + +impl LevelBuilder { + pub fn new(level_entries: Vec<MhtNodeEntry>, total_data_nodes: usize, height: Height) -> Self { + Self { + level_entries, + total_data_nodes, + height, + 
previous_incomplete_node: None, + } + } + + pub fn previous_incomplete_node( + mut self, + previous_incomplete_node: Option<Arc<MhtNode>>, + ) -> Self { + self.previous_incomplete_node = previous_incomplete_node; + self + } + + pub fn build(&self) -> Vec<Arc<MhtNode>> { + let all_level_entries: Vec<&MhtNodeEntry> = + if let Some(pre_node) = self.previous_incomplete_node.as_ref() { + // If there exists a previous built node (same height), + // its complete entries should participate in the building + pre_node + .entries + .iter() + .take(pre_node.num_complete_children()) + .chain(self.level_entries.iter()) + .collect() + } else { + self.level_entries.iter().collect() + }; + + let num_build = align_up(all_level_entries.len(), MHT_NBRANCHES) / MHT_NBRANCHES; + let mut new_mht_nodes = Vec::with_capacity(num_build); + // Each iteration builds a new MHT node + for (nth, entries_per_node) in all_level_entries.chunks(MHT_NBRANCHES).enumerate() { + if nth == num_build - 1 { + let last_new_node = self.build_last_node(entries_per_node); + new_mht_nodes.push(last_new_node); + break; + } + + let mut mht_node = MhtNode::new_zeroed(); + mht_node.header = MhtNodeHeader { + height: self.height, + num_data_nodes: MhtNode::max_num_data_nodes(self.height) as _, + num_valid_entries: MHT_NBRANCHES as _, + }; + for (i, entry) in mht_node.entries.iter_mut().enumerate() { + *entry = *entries_per_node[i]; + } + + new_mht_nodes.push(Arc::new(mht_node)); + } + new_mht_nodes + } + + // Building last MHT node of the level can be complicated, since + // the last node may be incomplete + fn build_last_node(&self, entries: &[&MhtNodeEntry]) -> Arc<MhtNode> { + let num_data_nodes = { + let max_data_nodes = MhtNode::max_num_data_nodes(self.height); + let incomplete_nodes = self.total_data_nodes % max_data_nodes; + if incomplete_nodes == 0 { + max_data_nodes + } else { + incomplete_nodes + } + }; + let num_valid_entries = entries.len(); + + let mut last_mht_node = MhtNode::new_zeroed(); + last_mht_node.header = MhtNodeHeader { + height: self.height, + num_data_nodes: num_data_nodes as _, + num_valid_entries: num_valid_entries as _, + }; + for (i, entry) in last_mht_node.entries.iter_mut().enumerate() { + *entry = if i < num_valid_entries { + *entries[i] + } else { + // Padding invalid entries to the rest + MhtNodeEntry::new_uninit() + }; + } + + Arc::new(last_mht_node) + } +} + +impl<'a, L: BlockLog> PreviousBuild<'a, L> { + pub fn new(previous_built_root: &Arc<MhtNode>, storage: &'a MhtStorage<L>) -> Self { + let mut new_self = Self { + root: previous_built_root.clone(), + height: previous_built_root.height(), + internal_incomplete_nodes: HashMap::new(), + storage, + }; + new_self.collect_incomplete_nodes(); + new_self + } + + pub fn target_node(&self, target_height: Height) -> Option<Arc<MhtNode>> { + if target_height == self.height { + return Some(self.root.clone()); + } + self.internal_incomplete_nodes.get(&target_height).cloned() + } + + pub fn num_data_nodes(&self) -> usize { + self.root.num_data_nodes() + } + + fn collect_incomplete_nodes(&mut self) { + let root_node = &self.root; + if !root_node.is_incomplete() || MhtNode::is_lowest_level(self.height) { + return; + } + + let mut lookup_node = { + let entry = root_node.entries[root_node.num_valid_entries() - 1]; + self.storage + .read_mht_node(entry.pos, &entry.key, &entry.mac, &Iv::new_zeroed()) + .unwrap() + }; + + while lookup_node.is_incomplete() { + let height = lookup_node.height(); + self.internal_incomplete_nodes + .insert(height, lookup_node.clone()); + + if 
MhtNode::is_lowest_level(height) { + break; + } + + // Incomplete nodes only appear in the last node of each level + lookup_node = { + let entry = lookup_node.entries[lookup_node.num_valid_entries() - 1]; + self.storage + .read_mht_node(entry.pos, &entry.key, &entry.mac, &Iv::new_zeroed()) + .unwrap() + } + } + } +} + +impl<'a> SearchCtx<'a> { + pub fn new(pos: Lbid, data_buf: BufMut<'a>) -> Self { + let num = data_buf.nblocks(); + Self { + pos, + data_buf, + offset: 0, + num, + is_completed: false, + } + } + + pub fn node_buf(&mut self, offset: usize) -> &mut [u8] { + &mut self.data_buf.as_mut_slice()[offset * BLOCK_SIZE..(offset + 1) * BLOCK_SIZE] + } +} + +impl CryptBuf { + pub fn new() -> Self { + Self { + plain: Buf::alloc(1).unwrap(), + cipher: Buf::alloc(1).unwrap(), + } + } +} + +/// A buffer that contains appended data. +struct AppendDataBuf<L> { + node_queue: Vec<Arc<DataNode>>, + node_queue_cap: usize, + entry_queue: Vec<MhtNodeEntry>, // Also cache the data node entries + entry_queue_cap: usize, + start_pos: Lbid, + storage: Arc<MhtStorage<L>>, +} + +impl<L: BlockLog> AppendDataBuf<L> { + // Maximum capacity of entries indicates a complete MHT (height equals 3) + const MAX_ENTRY_QUEUE_CAP: usize = MHT_NBRANCHES.pow(3); + + pub fn new(capacity: usize, start_pos: Lbid, storage: Arc<MhtStorage<L>>) -> Self { + let (node_queue_cap, entry_queue_cap) = Self::calc_queue_cap(capacity, start_pos); + Self { + node_queue: Vec::with_capacity(node_queue_cap), + node_queue_cap, + entry_queue: Vec::with_capacity(entry_queue_cap), + start_pos, + entry_queue_cap, + storage, + } + } + + pub fn num_append(&self) -> usize { + self.node_queue.len() + self.entry_queue.len() + } + + pub fn is_full(&self) -> bool { + // Returns whether the data node entry queue is at capacity + self.entry_queue.len() >= self.entry_queue_cap + } + + pub fn append_data_nodes(&mut self, nodes: Vec<Arc<DataNode>>) -> Result<()> { + if self.is_full() { + return_errno_with_msg!(OutOfMemory, "cache out of capacity"); + } + + self.node_queue.extend(nodes); + if self.node_queue.len() >= self.node_queue_cap { + // If node queue is full, flush nodes to the entry queue + self.flush_node_queue()?; + } + Ok(()) + } + + pub fn search_data_nodes(&self, search_ctx: &mut SearchCtx) -> Result<()> { + let start_pos = self.start_pos; + let (pos, num) = (search_ctx.pos, search_ctx.num); + if pos + num <= start_pos { + return Ok(()); + } + + let (mut start_nth, mut end_nth, mut offset) = if pos >= start_pos { + let start = pos - start_pos; + (start, start + num, 0) + } else { + let end = pos + num - start_pos; + let offset = search_ctx.num - end; + search_ctx.num -= end; + (0, end, offset) + }; + debug_assert!(end_nth <= self.num_append()); + + // Read from entry queue first if needed + for entry in self + .entry_queue + .iter() + .skip(start_nth) + .take(end_nth - start_nth) + { + self.storage + .read_data_node(entry, search_ctx.node_buf(offset))?; + start_nth = 0; + end_nth -= 1; + offset += 1; + } + + // Read from node queue if needed + for node in self + .node_queue + .iter() + .skip(start_nth) + .take(end_nth - start_nth) + { + let node_buf = search_ctx.node_buf(offset); + node_buf.copy_from_slice(&node.0); + offset += 1; + } + + if pos >= start_pos { + search_ctx.is_completed = true; + } + Ok(()) + } + + pub fn flush(&mut self) -> Result<Vec<MhtNodeEntry>> { + self.flush_node_queue()?; + debug_assert!(self.node_queue.is_empty()); + + let all_cached_entries: Vec<MhtNodeEntry> = self.entry_queue.drain(..).collect(); + self.start_pos += 
all_cached_entries.len(); + Ok(all_cached_entries) + } + + fn flush_node_queue(&mut self) -> Result<()> { + let new_node_entries = self.storage.append_data_nodes(&self.node_queue)?; + self.entry_queue.extend_from_slice(&new_node_entries); + self.node_queue.clear(); + Ok(()) + } + + fn calc_queue_cap(capacity: usize, append_pos: Lbid) -> (usize, usize) { + // Half for data nodes, half for data node entries + let node_queue_cap = capacity / 2; + let entry_queue_cap = { + let max_cap = Self::MAX_ENTRY_QUEUE_CAP - append_pos; + let remain_cap = (capacity - node_queue_cap) * BLOCK_SIZE / size_of::<MhtNodeEntry>(); + max_cap.min(remain_cap) + }; + (node_queue_cap, entry_queue_cap) + } +} + +impl<L: BlockLog> Debug for CryptoLog<L> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CryptoLog") + .field("mht", &self.mht.read()) + .finish() + } +} + +impl<L: BlockLog> Debug for Mht<L> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Mht") + .field("root_meta", &self.root_meta()) + .field("root_node", &self.root_node()) + .field("root_key", &self.root_key) + .field("total_data_nodes", &self.total_data_nodes()) + .field("buffered_data_nodes", &self.data_buf.num_append()) + .finish() + } +} + +impl Debug for MhtNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MhtNode") + .field("header", &self.header) + .finish() + } +} + +impl Debug for DataNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DataNode") + .field("first 16 bytes", &&self.0[..16]) + .finish() + } +} + +struct MhtDisplayer<'a, L>(&'a Mht<L>); + +impl<L: BlockLog> Debug for MhtDisplayer<'_, L> { + // A heavy implementation to display the whole MHT. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut debug_struct = f.debug_struct("Mht"); + + // Display root MHT node + let root_meta = self.0.root_meta(); + debug_struct.field("\nroot_meta", &root_meta); + if root_meta.is_none() { + return debug_struct.finish(); + } + let root_mht_node = self.0.root_node().unwrap(); + debug_struct.field("\n-> root_mht_node", &root_mht_node); + let mut height = root_mht_node.height(); + if MhtNode::is_lowest_level(height) { + return debug_struct.finish(); + } + + // Display internal MHT nodes hierarchically + let mut level_entries: Vec<MhtNodeEntry> = root_mht_node + .entries + .into_iter() + .take(root_mht_node.num_valid_entries()) + .collect(); + 'outer: loop { + let level_size = level_entries.len(); + for i in 0..level_size { + let entry = &level_entries[i]; + let node = self + .0 + .storage + .read_mht_node(entry.pos, &entry.key, &entry.mac, &Iv::new_zeroed()) + .unwrap(); + debug_struct.field("\n node_entry", entry); + debug_struct.field("\n -> mht_node", &node); + for i in 0..node.num_valid_entries() { + level_entries.push(node.entries[i]); + } + } + level_entries.drain(..level_size); + height -= 1; + if MhtNode::is_lowest_level(height) { + break 'outer; + } + } + debug_struct.finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::layers::bio::MemLog; + + struct NoCache; + impl NodeCache for NoCache { + fn get(&self, _pos: Pbid) -> Option<Arc<dyn Any + Send + Sync>> { + None + } + fn put( + &self, + _pos: Pbid, + _value: Arc<dyn Any + Send + Sync>, + ) -> Option<Arc<dyn Any + Send + Sync>> { + None + } + } + + fn create_crypto_log() -> Result<CryptoLog<MemLog>> { + let mem_log = MemLog::create(64 * 1024)?; + let key = Key::random(); + let cache = Arc::new(NoCache {}); + 
Ok(CryptoLog::new(mem_log, key, cache)) + } + + #[test] + fn crypto_log_fns() -> Result<()> { + let log = create_crypto_log()?; + let append_cnt = MHT_NBRANCHES - 1; + let mut buf = Buf::alloc(1)?; + for i in 0..append_cnt { + buf.as_mut_slice().fill(i as _); + log.append(buf.as_ref())?; + } + log.flush()?; + println!("{:?}", log); + log.display_mht(); + + let content = 5u8; + buf.as_mut_slice().fill(content); + log.append(buf.as_ref())?; + log.flush()?; + log.display_mht(); + log.append(buf.as_ref())?; + log.flush()?; + log.display_mht(); + + let (root_meta, root_node) = (log.root_meta().unwrap(), log.root_node().unwrap()); + assert_eq!(root_meta.pos, 107); + assert_eq!(root_node.height(), 2); + assert_eq!(root_node.num_data_nodes(), append_cnt + 2); + assert_eq!(root_node.num_valid_entries(), 2); + + log.read(5 as BlockId, buf.as_mut())?; + assert_eq!(buf.as_slice(), &[content; BLOCK_SIZE]); + let mut buf = Buf::alloc(2)?; + log.read((MHT_NBRANCHES - 1) as BlockId, buf.as_mut())?; + assert_eq!(buf.as_slice(), &[content; 2 * BLOCK_SIZE]); + Ok(()) + } + + #[test] + fn write_once_read_many() -> Result<()> { + let log = create_crypto_log()?; + let append_cnt = MHT_NBRANCHES * MHT_NBRANCHES; + let batch_cnt = 4; + let mut buf = Buf::alloc(batch_cnt)?; + + for i in 0..(append_cnt / batch_cnt) { + buf.as_mut_slice().fill(i as _); + log.append(buf.as_ref())?; + } + log.flush()?; + log.display_mht(); + + for i in (0..append_cnt).step_by(batch_cnt) { + log.read(i as Lbid, buf.as_mut())?; + assert_eq!(&buf.as_slice()[..128], &[(i / batch_cnt) as u8; 128]); + } + Ok(()) + } + + #[test] + fn write_many_read_once() -> Result<()> { + let log = create_crypto_log()?; + let append_cnt = 2048; + let flush_freq = 125; + let mut buf = Buf::alloc(1)?; + + for i in 0..append_cnt { + buf.as_mut_slice().fill(i as _); + log.append(buf.as_ref())?; + if i % flush_freq == 0 { + log.flush()?; + } + } + log.flush()?; + log.display_mht(); + + for i in (0..append_cnt).rev() { + log.read(i as Lbid, buf.as_mut())?; + assert_eq!(&buf.as_slice()[2048..], &[i as u8; 2048]); + } + Ok(()) + } +} diff --git a/kernel/comps/mlsdisk/src/layers/1-crypto/mod.rs b/kernel/comps/mlsdisk/src/layers/1-crypto/mod.rs new file mode 100644 index 00000000..ce6cff78 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/1-crypto/mod.rs @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! The layer of cryptographical constructs. + +mod crypto_blob; +mod crypto_chain; +mod crypto_log; + +pub use self::{ + crypto_blob::CryptoBlob, + crypto_chain::CryptoChain, + crypto_log::{CryptoLog, NodeCache, RootMhtMeta}, +}; + +pub type Key = crate::os::AeadKey; +pub type Iv = crate::os::AeadIv; +pub type Mac = crate::os::AeadMac; +pub type VersionId = u64; diff --git a/kernel/comps/mlsdisk/src/layers/2-edit/edits.rs b/kernel/comps/mlsdisk/src/layers/2-edit/edits.rs new file mode 100644 index 00000000..18ec2d98 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/2-edit/edits.rs @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::marker::PhantomData; + +use serde::{ser::SerializeSeq, Deserialize, Serialize}; + +use crate::prelude::*; + +/// An edit of `Edit<S>` is an incremental change to a state of `S`. +pub trait Edit<S>: Serialize + for<'de> Deserialize<'de> { + /// Apply this edit to a state. + fn apply_to(&self, state: &mut S); +} + +/// A group of edits to a state. +pub struct EditGroup<E: Edit<S>, S> { + edits: Vec<E>, + _s: PhantomData<S>, +} + +impl<E: Edit<S>, S> EditGroup<E, S> { + /// Creates an empty edit group. 
+ pub fn new() -> Self { + Self { + edits: Vec::new(), + _s: PhantomData, + } + } + + /// Adds an edit to the group. + pub fn push(&mut self, edit: E) { + self.edits.push(edit); + } + + /// Returns an iterator to the contained edits. + pub fn iter(&self) -> impl Iterator<Item = &E> { + self.edits.iter() + } + + /// Clears the edit group by removing all contained edits. + pub fn clear(&mut self) { + self.edits.clear() + } + + /// Returns whether the edit group contains no edits. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the length of the edit group. + pub fn len(&self) -> usize { + self.edits.len() + } +} + +impl<E: Edit<S>, S> Edit<S> for EditGroup<E, S> { + fn apply_to(&self, state: &mut S) { + for edit in &self.edits { + edit.apply_to(state); + } + } +} + +impl<E: Edit<S>, S> Serialize for EditGroup<E, S> { + fn serialize<Se>(&self, serializer: Se) -> core::result::Result<Se::Ok, Se::Error> + where + Se: serde::Serializer, + { + let mut seq = serializer.serialize_seq(Some(self.len()))?; + for edit in &self.edits { + seq.serialize_element(edit)? + } + seq.end() + } +} + +impl<'de, E: Edit<S>, S> Deserialize<'de> for EditGroup<E, S> { + fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + struct EditsVisitor<E: Edit<S>, S> { + _p: PhantomData<(E, S)>, + } + + impl<'a, E: Edit<S>, S> serde::de::Visitor<'a> for EditsVisitor<E, S> { + type Value = EditGroup<E, S>; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + formatter.write_str("an edit group") + } + + fn visit_seq<A>(self, mut seq: A) -> core::result::Result<Self::Value, A::Error> + where + A: serde::de::SeqAccess<'a>, + { + let mut edits = Vec::with_capacity(seq.size_hint().unwrap_or(0)); + while let Some(e) = seq.next_element()? 
{ + edits.push(e); + } + Ok(EditGroup { + edits, + _s: PhantomData, + }) + } + } + + deserializer.deserialize_seq(EditsVisitor { _p: PhantomData }) + } +} + +#[cfg(test)] +mod tests { + use serde::{Deserialize, Serialize}; + + use super::*; + + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct XEdit { + x: i32, + } + + struct XState { + sum: i32, + } + + impl Edit<XState> for XEdit { + fn apply_to(&self, state: &mut XState) { + (*state).sum += self.x; + } + } + + #[test] + fn serde_edit() { + let mut group = EditGroup::<XEdit, XState>::new(); + let mut sum = 0; + for x in 0..10 { + sum += x; + let edit = XEdit { x }; + group.push(edit); + } + let mut state = XState { sum: 0 }; + group.apply_to(&mut state); + assert_eq!(state.sum, sum); + + let mut buf = [0u8; 64]; + let ser = postcard::to_slice(&group, buf.as_mut_slice()).unwrap(); + println!("serialize len: {} data: {:?}", ser.len(), ser); + let de: EditGroup<XEdit, XState> = postcard::from_bytes(buf.as_slice()).unwrap(); + println!("deserialize edits: {:?}", de.edits); + assert_eq!(de.len(), group.len()); + assert_eq!(de.edits.as_slice(), group.edits.as_slice()); + } +} diff --git a/kernel/comps/mlsdisk/src/layers/2-edit/journal.rs b/kernel/comps/mlsdisk/src/layers/2-edit/journal.rs new file mode 100644 index 00000000..38d58a17 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/2-edit/journal.rs @@ -0,0 +1,1081 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::marker::PhantomData; + +use lending_iterator::LendingIterator; +use ostd_pod::Pod; +use serde::{ + de::{VariantAccess, Visitor}, + Deserialize, Serialize, +}; + +use super::{Edit, EditGroup}; +use crate::{ + layers::{ + bio::{BlockRing, BlockSet, Buf}, + crypto::{CryptoBlob, CryptoChain, Key, Mac}, + }, + prelude::*, +}; + +/// The journal of a series of edits to a persistent state. +/// +/// `EditJournal` is designed to cater the needs of a usage scenario +/// where a persistent state is updated with incremental changes and in high +/// frequency. Apparently, writing the latest value of the +/// state to disk upon every update would result in a poor performance. +/// So instead `EditJournal` keeps a journal of these incremental updates, +/// which are called _edits_. Collectively, these edits can represent the latest +/// value of the state. Edits are persisted in batch, thus the write performance +/// is superior. +/// Behind the scene, `EditJournal` leverages a `CryptoChain` to store the edit +/// journal securely. +/// +/// # Compaction +/// +/// As the total number of edits amounts over time, so does the total size of +/// the storage space consumed by the edit journal. To keep the storage +/// consumption at bay, accumulated edits are merged into one snapshot periodically, +/// This process is called compaction. +/// The snapshot is stored in a location independent from the journal, +/// using `CryptoBlob` for security. The MAC of the snapshot is stored in the +/// journal. Each `EditJournal` keeps two copies of the snapshots so that even +/// one of them is corrupted due to unexpected crashes, the other is still valid. +/// +/// # Atomicity +/// +/// Edits are added to an edit journal individually with the `add` method but +/// are committed to the journal atomically via the `commit` method. This is +/// done by buffering newly-added edits into an edit group, which is called +/// the _current edit group_. Upon commit, the current edit group is persisted +/// to disk as a whole. 
It is guaranteed that the recovery process shall never +/// recover a partial edit group. +pub struct EditJournal< + E: Edit<S>, /* Edit */ + S, /* State */ + D, /* BlockSet */ + P, /* Policy */ +> { + state: S, + journal_chain: CryptoChain<BlockRing<D>>, + snapshots: SnapshotManager<S, D>, + compaction_policy: P, + curr_edit_group: Option<EditGroup<E, S>>, + write_buf: WriteBuf<E, S>, +} + +/// The metadata of an edit journal. +/// +/// The metadata is mainly useful when recovering an edit journal after a reboot. +#[repr(C)] +#[derive(Clone, Copy, Pod, Debug)] +pub struct EditJournalMeta { + /// The number of blocks reserved for storing a snapshot `CryptoBlob`. + pub snapshot_area_nblocks: usize, + /// The key of a snapshot `CryptoBlob`. + pub snapshot_area_keys: [Key; 2], + /// The number of blocks reserved for storing the journal `CryptoChain`. + pub journal_area_nblocks: usize, + /// The key of the `CryptoChain`. + pub journal_area_key: Key, +} + +impl EditJournalMeta { + /// Returns the total number of blocks occupied by the edit journal. + pub fn total_nblocks(&self) -> usize { + self.snapshot_area_nblocks * 2 + self.journal_area_nblocks + } +} + +impl<E, S, D, P> EditJournal<E, S, D, P> +where + E: Edit<S>, + S: Serialize + for<'de> Deserialize<'de> + Clone, + D: BlockSet, + P: CompactPolicy<E, S>, +{ + /// Format the disk for storing an edit journal with the specified + /// configurations, e.g., the initial state. + pub fn format( + disk: D, + init_state: S, + state_max_nbytes: usize, + mut compaction_policy: P, + ) -> Result<EditJournal<E, S, D, P>> { + // Create `SnapshotManager` to persist the init state. + let snapshots = SnapshotManager::create(&disk, &init_state, state_max_nbytes)?; + + // Create an empty `CryptoChain`. + let mut journal_chain = { + let chain_set = disk.subset(snapshots.nblocks() * 2..disk.nblocks())?; + let block_ring = BlockRing::new(chain_set); + block_ring.set_cursor(0); + CryptoChain::new(block_ring) + }; + + // Persist the MAC of latest snapshot to `CryptoChain`. + let mac = snapshots.latest_mac(); + let mut write_buf = WriteBuf::new(CryptoChain::<BlockRing<D>>::AVAIL_BLOCK_SIZE); + write_buf.write(&Record::Version(mac))?; + journal_chain.append(write_buf.as_slice())?; + compaction_policy.on_append_journal(1); + write_buf.clear(); + journal_chain.flush()?; + + Ok(Self { + state: init_state, + journal_chain, + snapshots, + compaction_policy, + curr_edit_group: Some(EditGroup::new()), + write_buf, + }) + } + + /// Recover an existing edit journal from the disk with the given + /// configurations. + /// + /// If the recovery process succeeds, the edit journal is returned + /// and the state represented by the edit journal can be obtained + /// via the `state` method. + pub fn recover(disk: D, meta: &EditJournalMeta, compaction: P) -> Result<Self> { + // Recover `SnapshotManager`. + let snapshots = SnapshotManager::<S, D>::recover(&disk, meta)?; + let latest_snapshot_mac = snapshots.latest_mac(); + let latest_snapshot = snapshots.latest_snapshot(); + let mut state = latest_snapshot.state.clone(); + let recover_from = latest_snapshot.recover_from; + + // Recover `CryptoChain`. + let snapshot_area_offset = meta.snapshot_area_nblocks * 2; + let block_log = + disk.subset(snapshot_area_offset..snapshot_area_offset + meta.journal_area_nblocks)?; + let block_ring = BlockRing::new(block_log); + let mut recover = CryptoChain::recover(meta.journal_area_key, block_ring, recover_from); + + // Apply `EditGroup` found in `Recovery`. 
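+        // Note that only the edit groups appearing *after* the `Record::Version`
+        // that matches the latest snapshot's MAC are re-applied below; earlier
+        // groups are already reflected in that snapshot.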
+ let mut should_apply = false; + while let Some(buf) = recover.next() { + let record_slice = RecordSlice::<E, S>::new(buf); + for record in record_slice { + match record { + // After each compaction, the first record should always be + // `Record::Version`, storing the MAC of latest_snapshot. + Record::Version(snapshot_mac) => { + if snapshot_mac.as_bytes() == latest_snapshot_mac.as_bytes() { + should_apply = true; + } + } + Record::Edit(group) => { + if should_apply { + group.apply_to(&mut state); + } + } + } + } + } + + // Set new_cursor of `CryptoChain`, so that new record could be appended + // right after the recovered records. + let journal_chain = recover.open(); + let new_cursor = journal_chain.block_range().end; + journal_chain.inner_log().set_cursor(new_cursor); + + Ok(Self { + state, + journal_chain, + snapshots, + compaction_policy: compaction, + curr_edit_group: Some(EditGroup::new()), + write_buf: WriteBuf::new(CryptoChain::<BlockRing<D>>::AVAIL_BLOCK_SIZE), + }) + } + + /// Returns the state represented by the journal. + pub fn state(&self) -> &S { + &self.state + } + + /// Returns the metadata of the edit journal. + pub fn meta(&self) -> EditJournalMeta { + EditJournalMeta { + snapshot_area_nblocks: self.snapshots.nblocks(), + snapshot_area_keys: self.snapshots.keys(), + journal_area_nblocks: self.journal_chain.inner_log().storage().nblocks(), + journal_area_key: *self.journal_chain.key(), + } + } + + /// Add an edit to the current edit group. + pub fn add(&mut self, edit: E) { + let edit_group = self.curr_edit_group.get_or_insert_with(|| EditGroup::new()); + edit_group.push(edit); + } + + /// Commit the current edit group. + pub fn commit(&mut self) { + let Some(edit_group) = self.curr_edit_group.take() else { + return; + }; + if edit_group.is_empty() { + return; + } + + let record = Record::Edit(edit_group); + self.write(&record); + let edit_group = match record { + Record::Edit(edit_group) => edit_group, + _ => unreachable!(), + }; + edit_group.apply_to(&mut self.state); + self.compaction_policy.on_commit_edits(&edit_group); + } + + fn write(&mut self, record: &Record<E, S>) { + // XXX: the serialized record should be less than write_buf. + let is_first_try_success = self.write_buf.write(record).is_ok(); + if is_first_try_success { + return; + } + + // TODO: sync disk first to ensure data are persisted before + // journal records. + + self.append_write_buf_to_journal(); + + let is_second_try_success = self.write_buf.write(record).is_ok(); + if !is_second_try_success { + panic!("the write buffer must have enough free space"); + } + } + + fn append_write_buf_to_journal(&mut self) { + let write_data = self.write_buf.as_slice(); + if write_data.is_empty() { + return; + } + + self.journal_chain + .append(write_data) + // TODO: how to handle I/O error in journaling? + .expect("we cannot handle I/O error in journaling gracefully"); + self.compaction_policy.on_append_journal(1); + self.write_buf.clear(); + + if self.compaction_policy.should_compact() { + // TODO: how to handle a compaction failure? + let compacted_blocks = self.compact().expect("journal chain compaction failed"); + self.compaction_policy.done_compact(compacted_blocks); + } + } + + /// Ensure that all committed edits are persisted to disk. + pub fn flush(&mut self) -> Result<()> { + self.append_write_buf_to_journal(); + self.journal_chain.flush() + } + + /// Abort the current edit group by removing all its contained edits. 
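+    ///
+    /// Since the current edit group is only written out on `commit`, aborted
+    /// edits never reach the on-disk journal.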
+ pub fn abort(&mut self) { + if let Some(edits) = self.curr_edit_group.as_mut() { + edits.clear(); + } + } + + fn compact(&mut self) -> Result<usize> { + if self.journal_chain.block_range().is_empty() { + return Ok(0); + } + + // Persist current state to latest snapshot. + let latest_snapshot = + Snapshot::create(self.state().clone(), self.journal_chain.block_range().end); + self.snapshots.persist(latest_snapshot)?; + + // Persist the MAC of latest_snapshot. + let mac = self.snapshots.latest_mac(); + self.write_buf.write(&Record::Version(mac))?; + self.journal_chain.append(self.write_buf.as_slice())?; + self.compaction_policy.on_append_journal(1); + self.write_buf.clear(); + + // The latest_snapshot has been persisted, now trim the journal_chain. + // And ensure that there is at least one valid block after trimming. + let old_chain_len = self.journal_chain.block_range().len(); + if old_chain_len > 1 { + self.journal_chain + .trim(self.journal_chain.block_range().end - 1); + } + let new_chain_len = self.journal_chain.block_range().len(); + Ok(old_chain_len - new_chain_len) + } +} + +/// The snapshot to be stored in a `CryptoBlob`, including the persistent state +/// and some metadata. +#[derive(Serialize, Deserialize, Clone)] +struct Snapshot<S> { + state: S, + recover_from: BlockId, +} + +impl<S> Snapshot<S> { + /// Create a new snapshot. + pub fn create(state: S, recover_from: BlockId) -> Arc<Self> { + Arc::new(Self { + state, + recover_from, + }) + } + + /// Return the length of metadata. + pub fn meta_len() -> usize { + core::mem::size_of::<BlockId>() + } +} + +/// The snapshot manager. +/// +/// It keeps two copies of `CryptoBlob`, so that even one of them is corrupted +/// due to unexpected crashes, the other is still valid. +/// +/// The `latest_index` indicates which `CryptoBlob` keeps the latest snapshot. +/// When `persist` a new snapshot, we always choose the older `CryptoBlob` to write, +/// then switch the `latest_index`. And the `VersionId` of two `CryptoBlob`s +/// should be the same or differ by one, since they both start from zero when `create`. +struct SnapshotManager<S, D> { + blobs: [CryptoBlob<D>; 2], + latest_index: usize, + buf: Buf, + snapshot: Arc<Snapshot<S>>, +} + +impl<S, D> SnapshotManager<S, D> +where + S: Serialize + for<'de> Deserialize<'de> + Clone, + D: BlockSet, +{ + /// Consider `DEFAULT_LATEST_INDEX` as the `latest_index`, if the `VersionId` + /// of two `CryptoBlob` are the same, i.e., + /// 1) when `create` a `SnapshotManager`, both `VersionId` are initialized to zero; + /// 2) when `recover` a `SnapshotManager`, one `CryptoBlob` may `recover_from` another, + /// so that their `VersionId` would be the same. + /// + /// This value should only be `0` or `1`. + const DEFAULT_LATEST_INDEX: usize = 0; + + /// Creates a new `SnapshotManager` with specified configurations. + pub fn create(disk: &D, init_state: &S, state_max_nbytes: usize) -> Result<Self> { + // Calculate the minimal blocks needed by `CryptoBlob`, in order to + // store a snapshot (state + metadata). + let blob_bytes = + CryptoBlob::<D>::HEADER_NBYTES + state_max_nbytes + Snapshot::<D>::meta_len(); + let blob_blocks = blob_bytes.div_ceil(BLOCK_SIZE); + if 2 * blob_blocks >= disk.nblocks() { + return_errno_with_msg!(OutOfDisk, "the block_set for journal is too small"); + }; + let mut buf = Buf::alloc(blob_blocks)?; + + // Serialize snapshot (state + metadata). 
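+        // The initial snapshot carries `recover_from = 0`, so a later recovery
+        // starts scanning the journal chain from its very beginning.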
+ let snapshot = Snapshot::create(init_state.clone(), 0); + let serialized = postcard::to_slice(snapshot.as_ref(), buf.as_mut_slice()) + .map_err(|_| Error::with_msg(OutOfDisk, "serialize snapshot failed"))?; + + // Persist snapshot to `CryptoBlob`. + let block_set0 = disk.subset(0..blob_blocks)?; + let block_set1 = disk.subset(blob_blocks..blob_blocks * 2)?; + let blobs = [ + CryptoBlob::create(block_set0, serialized)?, + CryptoBlob::create(block_set1, serialized)?, + ]; + Ok(Self { + blobs, + latest_index: Self::DEFAULT_LATEST_INDEX, + buf, + snapshot, + }) + } + + /// Try to recover old `SnapshotManager` with specified disk and metadata. + pub fn recover(disk: &D, meta: &EditJournalMeta) -> Result<Self> { + // Open two CryptoBlob. + let mut blob0 = CryptoBlob::open( + meta.snapshot_area_keys[0], + disk.subset(0..meta.snapshot_area_nblocks)?, + ); + let mut blob1 = CryptoBlob::open( + meta.snapshot_area_keys[1], + disk.subset(meta.snapshot_area_nblocks..meta.snapshot_area_nblocks * 2)?, + ); + + // Try to read the snapshot stored in `CryptoBlob`. + let mut buf = Buf::alloc(meta.snapshot_area_nblocks)?; + let snapshot0_res = match blob0.read(buf.as_mut_slice()) { + Ok(snapshot_len) => { + postcard::from_bytes::<Snapshot<S>>(&buf.as_slice()[..snapshot_len]) + .map_err(|_| Error::with_msg(OutOfDisk, "deserialize snapshot0 failed")) + .map(Arc::new) + } + Err(_) => Err(Error::with_msg(NotFound, "failed to read snapshot0")), + }; + let snapshot1_res = match blob1.read(buf.as_mut_slice()) { + Ok(snapshot_len) => { + postcard::from_bytes::<Snapshot<S>>(&buf.as_slice()[..snapshot_len]) + .map_err(|_| Error::with_msg(OutOfDisk, "deserialize snapshot1 failed")) + .map(Arc::new) + } + Err(_) => Err(Error::with_msg(NotFound, "failed to read snapshot1")), + }; + + // Recover `CryptoBlob` if one of them is corrupted. + let snapshots_res = match (snapshot0_res.is_ok(), snapshot1_res.is_ok()) { + (true, false) => { + blob1.recover_from(&blob0)?; + [&snapshot0_res, &snapshot0_res] + } + (false, true) => { + blob0.recover_from(&blob1)?; + [&snapshot1_res, &snapshot1_res] + } + (true, true) => [&snapshot0_res, &snapshot1_res], + (false, false) => return_errno_with_msg!( + NotFound, + "both snapshots are unable to read, recover failed" + ), + }; + + // Determine the latest snapshot and its index + let version0 = blob0.version_id().unwrap(); + let version1 = blob1.version_id().unwrap(); + let (snapshot_res, latest_index) = match Self::DEFAULT_LATEST_INDEX { + // If both `VersionId` are the same, we consider `DEFAULT_LATEST_INDEX` + // as the `latest_index`. + 0 | 1 if version0 == version1 => ( + snapshots_res[Self::DEFAULT_LATEST_INDEX], + Self::DEFAULT_LATEST_INDEX, + ), + 0 if version0 + 1 == version1 => (snapshots_res[1], 1), + 1 if version1 + 1 == version0 => (snapshots_res[0], 0), + _ => return_errno_with_msg!(InvalidArgs, "invalid latest snapshot index or version id"), + }; + let snapshot = snapshot_res.as_ref().unwrap().clone(); + Ok(Self { + blobs: [blob0, blob1], + latest_index, + buf, + snapshot, + }) + } + + /// Persists the latest snapshot. + pub fn persist(&mut self, latest: Arc<Snapshot<S>>) -> Result<()> { + // Serialize the latest snapshot. + let buf = postcard::to_slice(latest.as_ref(), self.buf.as_mut_slice()) + .map_err(|_| Error::with_msg(OutOfDisk, "serialize current state failed"))?; + + // Persist the latest snapshot to `CryptoBlob`. 
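+        // Always overwrite the *older* of the two blobs, so that a crash during
+        // this write still leaves the previously persisted snapshot intact.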
+ let index = (self.latest_index + 1) % 2; // switch the `latest_index` + self.blobs[index].write(buf)?; + self.latest_index = index; + self.snapshot = latest; + Ok(()) + } + + /// Returns the latest `Snapshot<S>`. + pub fn latest_snapshot(&self) -> Arc<Snapshot<S>> { + self.snapshot.clone() + } + + /// Returns the MAC of latest snapshot. + pub fn latest_mac(&self) -> Mac { + self.blobs[self.latest_index].current_mac().unwrap() + } + + /// Returns the number of blocks reserved for storing a snapshot `CryptoBlob`. + pub fn nblocks(&self) -> usize { + self.blobs[0].nblocks() + } + + /// Returns the keys of two `CryptoBlob`. + pub fn keys(&self) -> [Key; 2] { + [*self.blobs[0].key(), *self.blobs[1].key()] + } +} + +/// A journal record in an edit journal. +enum Record<E: Edit<S>, S> { + /// A record refers to a state snapshot of a specific MAC. + Version(Mac), + /// A record that contains an edit group. + Edit(EditGroup<E, S>), +} + +impl<E: Edit<S>, S> Serialize for Record<E, S> { + fn serialize<Se>(&self, serializer: Se) -> core::result::Result<Se::Ok, Se::Error> + where + Se: serde::Serializer, + { + match *self { + Record::Version(ref mac) => { + serializer.serialize_newtype_variant("Record", 0, "Version", mac) + } + Record::Edit(ref edit) => { + serializer.serialize_newtype_variant("Record", 1, "Edit", edit) + } + } + } +} + +impl<'de, E: Edit<S>, S> Deserialize<'de> for Record<E, S> { + fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + enum Variants { + Version, + Edit, + } + + impl<'de> Deserialize<'de> for Variants { + fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + struct VariantVisitor; + + impl Visitor<'_> for VariantVisitor { + type Value = Variants; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + formatter.write_str("`Version` or `Edit`") + } + + fn visit_u32<E>(self, v: u32) -> core::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + match v { + 0 => Ok(Variants::Version), + 1 => Ok(Variants::Edit), + _ => Err(E::custom("unknown value")), + } + } + } + + deserializer.deserialize_identifier(VariantVisitor) + } + } + + struct RecordVisitor<E: Edit<S>, S> { + _p: PhantomData<(E, S)>, + } + + impl<'a, E: Edit<S>, S> Visitor<'a> for RecordVisitor<E, S> { + type Value = Record<E, S>; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + formatter.write_str("a journal record") + } + + fn visit_enum<A>(self, data: A) -> core::result::Result<Self::Value, A::Error> + where + A: serde::de::EnumAccess<'a>, + { + let (variant, data) = data.variant::<Variants>()?; + let record = match variant { + Variants::Version => { + let mac = data.newtype_variant::<Mac>()?; + Record::Version(mac) + } + Variants::Edit => { + let edit_group = data.newtype_variant::<EditGroup<E, S>>()?; + Record::Edit(edit_group) + } + }; + Ok(record) + } + } + + deserializer.deserialize_enum( + "Record", + &["Version", "Edit"], + RecordVisitor { _p: PhantomData }, + ) + } +} + +/// A buffer for writing journal records into an edit journal. +/// +/// The capacity of `WriteBuf` is equal to the (available) block size of +/// `CryptoChain`. Records that are written to an edit journal are first +/// be inserted into the `WriteBuf`. When the `WriteBuf` is full or almost full, +/// the buffer as a whole will be written to the underlying `CryptoChain`. 
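+///
+/// This flush-on-full behavior is what `EditJournal::write` relies on: when a
+/// `write` into this buffer fails for lack of space, the buffer is appended to
+/// the journal chain and the write is retried.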
+struct WriteBuf<E: Edit<S>, S: Sized> { + buf: Buf, + // The cursor for writing new records. + cursor: usize, + capacity: usize, + phantom: PhantomData<(E, S)>, +} + +impl<E: Edit<S>, S: Sized> WriteBuf<E, S> { + /// Creates a new instance. + pub fn new(capacity: usize) -> Self { + debug_assert!(capacity <= BLOCK_SIZE); + Self { + buf: Buf::alloc(1).unwrap(), + cursor: 0, + capacity, + phantom: PhantomData, + } + } + + /// Writes a record into the buffer. + pub fn write(&mut self, record: &Record<E, S>) -> Result<()> { + // Write the record at the beginning of the avail buffer + match postcard::to_slice(record, self.avail_buf()) { + Ok(serial_record) => { + self.cursor += serial_record.len(); + Ok(()) + } + Err(e) => { + if e != postcard::Error::SerializeBufferFull { + panic!( + "Errors (except SerializeBufferFull) are not expected: {}", + e + ); + } + return_errno_with_msg!(OutOfDisk, "no space for new Record in WriteBuf"); + } + } + } + + /// Clear all records in the buffer. + pub fn clear(&mut self) { + self.cursor = 0; + } + + /// Returns a slice containing the data in the write buffer. + pub fn as_slice(&self) -> &[u8] { + &self.buf.as_slice()[..self.cursor] + } + + fn avail_len(&self) -> usize { + self.capacity - self.cursor + } + + fn avail_buf(&mut self) -> &mut [u8] { + &mut self.buf.as_mut_slice()[self.cursor..self.capacity] + } +} + +/// A byte slice containing serialized edit records. +/// +/// The slice allows deserializing and iterates the contained edit records. +struct RecordSlice<'a, E, S> { + buf: &'a [u8], + phantom: PhantomData<(E, S)>, + any_error: bool, +} + +impl<'a, E: Edit<S>, S> RecordSlice<'a, E, S> { + /// Create a new slice of edit records in serialized form. + pub fn new(buf: &'a [u8]) -> Self { + Self { + buf, + phantom: PhantomData, + any_error: false, + } + } + + /// Returns if any error occurs while deserializing the records. + pub fn any_error(&self) -> bool { + self.any_error + } +} + +impl<E: Edit<S>, S> Iterator for RecordSlice<'_, E, S> { + type Item = Record<E, S>; + + fn next(&mut self) -> Option<Record<E, S>> { + match postcard::take_from_bytes::<Record<E, S>>(self.buf) { + Ok((record, left)) => { + self.buf = left; + Some(record) + } + Err(_) => { + if !self.buf.is_empty() { + self.any_error = true; + } + None + } + } + } +} + +/// A compaction policy, which decides when is the good timing for compacting +/// the edits in an edit journal. +pub trait CompactPolicy<E: Edit<S>, S> { + /// Called when an edit group is committed. + /// + /// As more edits are accumulated, the compaction policy is more likely to + /// decide that now is the time to compact. + fn on_commit_edits(&mut self, edits: &EditGroup<E, S>); + + /// Called when some edits are appended to `CryptoChain`. + /// + /// The `appended_blocks` indicates how many blocks of journal area are + /// occupied by those edits. + fn on_append_journal(&mut self, appended_blocks: usize); + + /// Returns whether now is a good timing for compaction. + fn should_compact(&self) -> bool; + + /// Reset the state, as if no edits have ever been added. + /// + /// The `compacted_blocks` indicates how many blocks are reclaimed during + /// this compaction. + fn done_compact(&mut self, compacted_blocks: usize); +} + +/// A never-do-compaction policy. Mostly useful for testing. 
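+///
+/// Its `should_compact` always returns `false`, so the journal chain is never
+/// compacted or trimmed.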
+pub struct NeverCompactPolicy; + +impl<E: Edit<S>, S> CompactPolicy<E, S> for NeverCompactPolicy { + fn on_commit_edits(&mut self, _edits: &EditGroup<E, S>) {} + + fn on_append_journal(&mut self, _appended_nblocks: usize) {} + + fn should_compact(&self) -> bool { + false + } + + fn done_compact(&mut self, _compacted_blocks: usize) {} +} + +/// A compaction policy, triggered when there's no-space left for new edits. +pub struct DefaultCompactPolicy { + used_blocks: usize, + total_blocks: usize, +} + +impl DefaultCompactPolicy { + /// Constructs a `DefaultCompactPolicy`. + /// + /// It is initialized via the total number of blocks of `EditJournal` and state. + pub fn new<D: BlockSet>(disk_nblocks: usize, state_max_nbytes: usize) -> Self { + // Calculate the blocks used by `Snapshot`s. + let snapshot_bytes = + CryptoBlob::<D>::HEADER_NBYTES + state_max_nbytes + Snapshot::<D>::meta_len(); + let snapshot_blocks = snapshot_bytes.div_ceil(BLOCK_SIZE); + debug_assert!( + snapshot_blocks * 2 < disk_nblocks, + "the number of blocks of journal area are too small" + ); + + Self { + used_blocks: 0, + total_blocks: disk_nblocks - snapshot_blocks * 2, + } + } + + /// Constructs a `DefaultCompactPolicy` from `EditJournalMeta`. + pub fn from_meta(meta: &EditJournalMeta) -> Self { + Self { + used_blocks: 0, + total_blocks: meta.journal_area_nblocks, + } + } +} + +impl<E: Edit<S>, S> CompactPolicy<E, S> for DefaultCompactPolicy { + fn on_commit_edits(&mut self, _edits: &EditGroup<E, S>) {} + + fn on_append_journal(&mut self, nblocks: usize) { + self.used_blocks += nblocks; + } + + fn should_compact(&self) -> bool { + self.used_blocks >= self.total_blocks + } + + fn done_compact(&mut self, compacted_blocks: usize) { + debug_assert!(self.used_blocks >= compacted_blocks); + self.used_blocks -= compacted_blocks; + } +} + +#[cfg(test)] +mod tests { + use ostd_pod::Pod; + use serde::{Deserialize, Serialize}; + + use super::{ + CompactPolicy, DefaultCompactPolicy, Edit, EditGroup, EditJournal, Record, RecordSlice, + WriteBuf, + }; + use crate::{ + layers::{ + bio::{BlockSet, MemDisk, BLOCK_SIZE}, + crypto::Mac, + }, + prelude::*, + }; + + #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] + struct XEdit { + x: i32, + } + + #[derive(Serialize, Deserialize, Clone, Debug)] + struct XState { + sum: i32, + } + + impl Edit<XState> for XEdit { + fn apply_to(&self, state: &mut XState) { + (*state).sum += self.x; + } + } + + /// A threshold-based compact policy. The `threshold` is the upper limit + /// of the number of `CryptoChain::append`. + /// + /// # Safety + /// + /// The `EditJournal` must have enough space to persist the threshold + /// of appended blocks, to avoid overlapping. + struct ThresholdPolicy { + appended: usize, + threshold: usize, + } + + impl ThresholdPolicy { + pub fn new(threshold: usize) -> Self { + Self { + appended: 0, + threshold, + } + } + } + + impl CompactPolicy<XEdit, XState> for ThresholdPolicy { + fn on_commit_edits(&mut self, _edits: &EditGroup<XEdit, XState>) {} + + fn on_append_journal(&mut self, nblocks: usize) { + self.appended += nblocks; + } + + fn should_compact(&self) -> bool { + self.appended >= self.threshold + } + + fn done_compact(&mut self, _compacted_blocks: usize) { + self.appended = 0; + } + } + + #[test] + fn serde_record() { + let mut buf = [0u8; 64]; + let mut offset = 0; + + // Add `Record::Edit` to buffer. 
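+        // Serializing a `Record` (rather than a bare `EditGroup`) mirrors what
+        // `EditJournal::write` puts into its `WriteBuf`.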
+ let mut group = EditGroup::<XEdit, XState>::new(); + for x in 0..10 { + let edit = XEdit { x }; + group.push(edit); + } + let mut state = XState { sum: 0 }; + group.apply_to(&mut state); + let group_len = group.len(); + let edit = Record::<XEdit, XState>::Edit(group); + let ser = postcard::to_slice(&edit, &mut buf).unwrap(); + println!("serialized edit_group len: {} data: {:?}", ser.len(), ser); + offset += ser.len(); + + // Add `Record::Version` to buffer. + let mac = Mac::random(); + let version = Record::<XEdit, XState>::Version(mac); + let ser = postcard::to_slice(&version, &mut buf[offset..]).unwrap(); + println!("serialize edit_group len: {} data: {:?}", ser.len(), ser); + offset += ser.len(); + + // Deserialize all `Record`. + let record_slice = RecordSlice::<XEdit, XState>::new(&buf[..offset]); + for record in record_slice { + match record { + Record::Version(m) => { + println!("slice version_mac: {:?}", m); + assert_eq!(m.as_bytes(), mac.as_bytes()); + } + Record::Edit(group) => { + println!("slice edit_group len: {}", group.len()); + assert_eq!(group.len(), group_len); + } + } + } + } + + #[test] + fn write_buf() { + let mut write_buf = WriteBuf::<XEdit, XState>::new(BLOCK_SIZE); + + assert_eq!(write_buf.cursor, 0); + assert_eq!(write_buf.capacity, BLOCK_SIZE); + let mut version = 0; + while write_buf.write(&Record::Version(Mac::random())).is_ok() { + version += 1; + let mut group = EditGroup::new(); + for x in 0..version { + let edit = XEdit { x: x as i32 }; + group.push(edit); + } + if write_buf.write(&Record::Edit(group)).is_err() { + break; + } + } + assert_ne!(write_buf.cursor, 0); + + let record_slice = + RecordSlice::<XEdit, XState>::new(&write_buf.buf.as_slice()[..write_buf.cursor]); + let mut version = 0; + for record in record_slice { + match record { + Record::Version(m) => { + println!("slice version_mac: {:?}", m); + version += 1; + } + Record::Edit(group) => { + println!("slice edit_group len: {}", group.len()); + assert_eq!(group.len(), version as usize); + } + } + } + write_buf.clear(); + assert_eq!(write_buf.cursor, 0); + } + + /// A test case for `EditJournal`. + /// + /// The `threshold` is used to control the compact frequency, see `ThresholdPolicy`. + /// The `commit_times` is used to control the number of `EditGroup` committed. + /// In addition, the `WriteBuf` will append to `CryptoChain` every two commits. + fn append_and_recover(threshold: usize, commit_times: usize) { + let disk = MemDisk::create(64).unwrap(); + let mut journal = EditJournal::format( + disk.subset(0..16).unwrap(), + XState { sum: 0 }, + core::mem::size_of::<XState>() * 2, + ThresholdPolicy::new(threshold), + ) + .unwrap(); + let meta = journal.meta(); + assert_eq!(meta.snapshot_area_nblocks, 1); + assert_eq!(meta.journal_area_nblocks, 14); + { + println!("journaling started"); + // The `WriteBuf` could hold two `EditGroup` in this test, + // so we would lose those commit states in `WriteBuf`. 
+ for _ in 0..commit_times { + for x in 0..1000 { + let edit = XEdit { x }; + journal.add(edit); + } + journal.commit(); + println!("state: {}", journal.state().sum); + } + }; + + journal.flush().unwrap(); + + let journal_disk = disk.subset(0..32).unwrap(); + let threshold_policy = ThresholdPolicy::new(1_000); + let recover = EditJournal::recover(journal_disk, &meta, threshold_policy).unwrap(); + println!("recover state: {}", recover.state().sum); + println!( + "journal chain block range {:?}", + recover.journal_chain.block_range() + ); + let append_times = (commit_times - 1) / 2; + println!("append times: {}", append_times); + assert_eq!( + recover.state().sum as usize, + (0 + 999) * 1000 / 2 * commit_times + ); + let compact_times = append_times / threshold; + println!("compact times: {}", compact_times); + } + + #[test] + fn edit_journal() { + // No compact. + append_and_recover(5, 1); + append_and_recover(5, 10); + + // Compact once. + append_and_recover(5, 11); + append_and_recover(5, 20); + + // Compact twice. + append_and_recover(5, 21); + + // Compact many times. + append_and_recover(5, 1000); + } + + /// A test case for `DefaultCompactPolicy`. + /// + /// The `commit_times` is used to control the number of `EditGroup` committed. + fn default_compact_policy_when_commit(commit_times: usize) { + let disk = MemDisk::create(16).unwrap(); + + let journal_disk = disk.subset(0..12).unwrap(); + let state_max_nbytes = core::mem::size_of::<XState>() * 2; + let compact_policy = + DefaultCompactPolicy::new::<MemDisk>(journal_disk.nblocks(), state_max_nbytes); + let mut journal: EditJournal<XEdit, XState, MemDisk, DefaultCompactPolicy> = + EditJournal::format( + journal_disk, + XState { sum: 0 }, + state_max_nbytes, + compact_policy, + ) + .unwrap(); + let meta = journal.meta(); + assert_eq!(meta.snapshot_area_nblocks, 1); + assert_eq!(meta.journal_area_nblocks, 10); + { + println!("journaling started"); + // The `WriteBuf` could hold two `EditGroup` in this test. + for _ in 0..commit_times { + for x in 0..1000 { + let edit = XEdit { x }; + journal.add(edit); + } + journal.commit(); + println!("state: {}", journal.state().sum); + } + }; + + journal.flush().unwrap(); + + let journal_disk = disk.subset(0..12).unwrap(); + let compact_policy = DefaultCompactPolicy::from_meta(&meta); + let recover: EditJournal<XEdit, XState, MemDisk, DefaultCompactPolicy> = + EditJournal::recover(journal_disk, &meta, compact_policy).unwrap(); + println!("recover state: {}", recover.state().sum); + assert_eq!( + recover.state().sum as usize, + (0 + 999) * 1000 / 2 * commit_times + ); + } + + #[test] + fn default_compact_policy() { + default_compact_policy_when_commit(0); + default_compact_policy_when_commit(10); + default_compact_policy_when_commit(100); + default_compact_policy_when_commit(1000); + } +} diff --git a/kernel/comps/mlsdisk/src/layers/2-edit/mod.rs b/kernel/comps/mlsdisk/src/layers/2-edit/mod.rs new file mode 100644 index 00000000..5837f3d0 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/2-edit/mod.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! The layer of edit journal. 
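+//!
+//! # Example
+//!
+//! A minimal usage sketch. The `Counter` state and `Add` edit below are made-up
+//! illustration types (not part of this crate); the journal APIs are the ones
+//! defined in this module.
+//!
+//! ```
+//! use serde::{Deserialize, Serialize};
+//!
+//! #[derive(Serialize, Deserialize, Clone)]
+//! struct Counter {
+//!     sum: i32,
+//! }
+//!
+//! #[derive(Serialize, Deserialize)]
+//! struct Add {
+//!     x: i32,
+//! }
+//!
+//! impl Edit<Counter> for Add {
+//!     fn apply_to(&self, state: &mut Counter) {
+//!         state.sum += self.x;
+//!     }
+//! }
+//!
+//! fn demo<D: BlockSet>(disk: &D) -> Result<()> {
+//!     // Format a fresh journal with an initial state.
+//!     let state_max_nbytes = core::mem::size_of::<Counter>() * 2;
+//!     let mut journal: EditJournal<Add, Counter, _, _> = EditJournal::format(
+//!         disk.subset(0..disk.nblocks())?,
+//!         Counter { sum: 0 },
+//!         state_max_nbytes,
+//!         NeverCompactPolicy,
+//!     )?;
+//!     let meta = journal.meta();
+//!
+//!     // Edits are buffered with `add` and persisted atomically with `commit`.
+//!     journal.add(Add { x: 1 });
+//!     journal.add(Add { x: 2 });
+//!     journal.commit();
+//!     journal.flush()?;
+//!
+//!     // After a reboot, the state can be rebuilt from the journal.
+//!     let recovered: EditJournal<Add, Counter, _, _> =
+//!         EditJournal::recover(disk.subset(0..disk.nblocks())?, &meta, NeverCompactPolicy)?;
+//!     assert_eq!(recovered.state().sum, 3);
+//!     Ok(())
+//! }
+//! ```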
+ +mod edits; +mod journal; + +pub use self::{ + edits::{Edit, EditGroup}, + journal::{ + CompactPolicy, DefaultCompactPolicy, EditJournal, EditJournalMeta, NeverCompactPolicy, + }, +}; diff --git a/kernel/comps/mlsdisk/src/layers/3-log/chunk.rs b/kernel/comps/mlsdisk/src/layers/3-log/chunk.rs new file mode 100644 index 00000000..ea4f6a58 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/3-log/chunk.rs @@ -0,0 +1,480 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Chunk-based storage management. +//! +//! A chunk is a group of consecutive blocks. +//! As the size of a chunk is much greater than that of a block, +//! the number of chunks is naturally far smaller than that of blocks. +//! This makes it possible to keep all metadata for chunks in memory. +//! Thus, managing chunks is more efficient than managing blocks. +//! +//! The primary API provided by this module is the chunk allocator, +//! `ChunkAlloc`, which tracks whether chunks are free or not. +//! +//! # Examples +//! +//! Chunk allocators are used within transactions. +//! +//! ``` +//! fn alloc_chunks(chunk_alloc: &ChunkAlloc, num_chunks: usize) -> Option<Vec<ChunkId>> { +//! let mut tx = chunk_alloc.new_tx(); +//! let res: Option<Vec<ChunkId>> = tx.context(|| { +//! let mut chunk_ids = Vec::new(); +//! for _ in 0..num_chunks { +//! chunk_ids.push(chunk_alloc.alloc()?); +//! } +//! Some(chunk_ids) +//! }); +//! if res.is_some() { +//! tx.commit().ok()?; +//! } else { +//! tx.abort(); +//! } +//! res +//! } +//! ``` +//! +//! The above example showcases the power of transaction atomicity: +//! if anything goes wrong (e.g., allocation failures) during the transaction, +//! then the transaction can be aborted and all changes made to `chunk_alloc` +//! during the transaction will be rolled back automatically. +use serde::{Deserialize, Serialize}; + +use crate::{ + layers::edit::Edit, + os::{HashMap, Mutex}, + prelude::*, + tx::{CurrentTx, TxData, TxProvider}, + util::BitMap, +}; + +/// The ID of a chunk. +pub type ChunkId = usize; + +/// Number of blocks in a chunk. +pub const CHUNK_NBLOCKS: usize = 1024; +/// The chunk size is a multiple of the block size. +pub const CHUNK_SIZE: usize = CHUNK_NBLOCKS * BLOCK_SIZE; + +/// A chunk allocator tracks which chunks are free. +#[derive(Clone)] +pub struct ChunkAlloc { + state: Arc<Mutex<ChunkAllocState>>, + tx_provider: Arc<TxProvider>, +} + +impl ChunkAlloc { + /// Creates a chunk allocator that manages a specified number of + /// chunks (`capacity`). Initially, all chunks are free. + pub fn new(capacity: usize, tx_provider: Arc<TxProvider>) -> Self { + let state = ChunkAllocState::new(capacity); + Self::from_parts(state, tx_provider) + } + + /// Constructs a `ChunkAlloc` from its parts.
+ pub(super) fn from_parts(mut state: ChunkAllocState, tx_provider: Arc<TxProvider>) -> Self { + state.in_journal = false; + let new_self = Self { + state: Arc::new(Mutex::new(state)), + tx_provider, + }; + + // TX data + new_self + .tx_provider + .register_data_initializer(Box::new(ChunkAllocEdit::new)); + + // Commit handler + new_self.tx_provider.register_commit_handler({ + let state = new_self.state.clone(); + move |current: CurrentTx<'_>| { + let state = state.clone(); + current.data_with(move |edit: &ChunkAllocEdit| { + if edit.edit_table.is_empty() { + return; + } + + let mut state = state.lock(); + edit.apply_to(&mut state); + }); + } + }); + + // Abort handler + new_self.tx_provider.register_abort_handler({ + let state = new_self.state.clone(); + move |current: CurrentTx<'_>| { + let state = state.clone(); + current.data_with(move |edit: &ChunkAllocEdit| { + let mut state = state.lock(); + for chunk_id in edit.iter_allocated_chunks() { + state.dealloc(chunk_id); + } + }); + } + }); + + new_self + } + + /// Creates a new transaction for the chunk allocator. + pub fn new_tx(&self) -> CurrentTx<'_> { + self.tx_provider.new_tx() + } + + /// Allocates a chunk, returning its ID. + pub fn alloc(&self) -> Option<ChunkId> { + let chunk_id = { + let mut state = self.state.lock(); + state.alloc()? // Update global state immediately + }; + + let mut current_tx = self.tx_provider.current(); + current_tx.data_mut_with(|edit: &mut ChunkAllocEdit| { + edit.alloc(chunk_id); + }); + + Some(chunk_id) + } + + /// Allocates `count` number of chunks. Returns IDs of newly-allocated + /// chunks, returns `None` if any allocation fails. + pub fn alloc_batch(&self, count: usize) -> Option<Vec<ChunkId>> { + let chunk_ids = { + let mut ids = Vec::with_capacity(count); + let mut state = self.state.lock(); + for _ in 0..count { + match state.alloc() { + Some(id) => ids.push(id), + None => { + ids.iter().for_each(|id| state.dealloc(*id)); + return None; + } + } + } + ids.sort_unstable(); + ids + }; + + let mut current_tx = self.tx_provider.current(); + current_tx.data_mut_with(|edit: &mut ChunkAllocEdit| { + for chunk_id in &chunk_ids { + edit.alloc(*chunk_id); + } + }); + + Some(chunk_ids) + } + + /// Deallocates the chunk of a given ID. + /// + /// # Panic + /// + /// Deallocating a free chunk causes panic. + pub fn dealloc(&self, chunk_id: ChunkId) { + let mut current_tx = self.tx_provider.current(); + current_tx.data_mut_with(|edit: &mut ChunkAllocEdit| { + let should_dealloc_now = edit.dealloc(chunk_id); + + if should_dealloc_now { + let mut state = self.state.lock(); + state.dealloc(chunk_id); + } + }); + } + + /// Deallocates the set of chunks of given IDs. + /// + /// # Panic + /// + /// Deallocating a free chunk causes panic. + pub fn dealloc_batch<I>(&self, chunk_ids: I) + where + I: Iterator<Item = ChunkId>, + { + let mut current_tx = self.tx_provider.current(); + current_tx.data_mut_with(|edit: &mut ChunkAllocEdit| { + let mut state = self.state.lock(); + for chunk_id in chunk_ids { + let should_dealloc_now = edit.dealloc(chunk_id); + + if should_dealloc_now { + state.dealloc(chunk_id); + } + } + }); + } + + /// Returns the capacity of the allocator, which is the number of chunks. + pub fn capacity(&self) -> usize { + self.state.lock().capacity() + } + + /// Returns the number of free chunks. 
+ pub fn free_count(&self) -> usize { + self.state.lock().free_count() + } +} + +impl Debug for ChunkAlloc { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = self.state.lock(); + f.debug_struct("ChunkAlloc") + .field("bitmap_free_count", &state.free_count) + .field("bitmap_next_free", &state.next_free) + .finish() + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Persistent State +//////////////////////////////////////////////////////////////////////////////// + +/// The persistent state of a chunk allocator. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChunkAllocState { + // A bitmap where each bit indicates whether a corresponding chunk + // has been allocated. + alloc_map: BitMap, + // The number of free chunks. + free_count: usize, + // The next free chunk Id. Used to narrow the scope of + // searching for free chunk IDs. + next_free: usize, + /// Whether the state is in the journal or not. + in_journal: bool, +} +// TODO: Separate persistent and volatile state of `ChunkAlloc` + +impl ChunkAllocState { + /// Creates a persistent state for managing chunks of the specified number. + /// Initially, all chunks are free. + pub fn new(capacity: usize) -> Self { + Self { + alloc_map: BitMap::repeat(false, capacity), + free_count: capacity, + next_free: 0, + in_journal: false, + } + } + + /// Creates a persistent state in the journal. The state in the journal and + /// the state that `RawLogStore` manages act differently on allocation and + /// edits' appliance. + pub fn new_in_journal(capacity: usize) -> Self { + Self { + alloc_map: BitMap::repeat(false, capacity), + free_count: capacity, + next_free: 0, + in_journal: true, + } + } + + /// Allocates a chunk, returning its ID. + pub fn alloc(&mut self) -> Option<ChunkId> { + let mut next_free = self.next_free; + if next_free == self.alloc_map.len() { + next_free = 0; + } + + let free_chunk_id = { + if let Some(chunk_id) = self.alloc_map.first_zero(next_free) { + chunk_id + } else { + self.alloc_map + .first_zero(0) + .expect("there must exists a zero") + } + }; + + self.alloc_map.set(free_chunk_id, true); + self.free_count -= 1; + self.next_free = free_chunk_id + 1; + + Some(free_chunk_id) + } + + /// Deallocates the chunk of a given ID. + /// + /// # Panic + /// + /// Deallocating a free chunk causes panic. + pub fn dealloc(&mut self, chunk_id: ChunkId) { + debug_assert!(self.alloc_map[chunk_id]); + self.alloc_map.set(chunk_id, false); + self.free_count += 1; + } + + /// Returns the total number of chunks. + pub fn capacity(&self) -> usize { + self.alloc_map.len() + } + + /// Returns the number of free chunks. + pub fn free_count(&self) -> usize { + self.free_count + } + + /// Returns whether a specific chunk is allocated. + pub fn is_chunk_allocated(&self, chunk_id: ChunkId) -> bool { + self.alloc_map[chunk_id] + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Persistent Edit +//////////////////////////////////////////////////////////////////////////////// + +/// A persistent edit to the state of a chunk allocator. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChunkAllocEdit { + edit_table: HashMap<ChunkId, ChunkEdit>, +} + +/// The smallest unit of a persistent edit to the +/// state of a chunk allocator, which is +/// a chunk being either allocated or deallocated. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +enum ChunkEdit { + Alloc, + Dealloc, +} + +impl ChunkAllocEdit { + /// Creates a new empty edit table. + pub fn new() -> Self { + Self { + edit_table: HashMap::new(), + } + } + + /// Records a chunk allocation in the edit. + pub fn alloc(&mut self, chunk_id: ChunkId) { + let old_edit = self.edit_table.insert(chunk_id, ChunkEdit::Alloc); + + // There must be a logical error if an edit has been recorded + // for the chunk. If the chunk edit is `ChunkEdit::Alloc`, then + // it is double allocations. If the chunk edit is `ChunkEdit::Dealloc`, + // then such deallocations can only take effect after the edit is + // committed. Thus, it is impossible to allocate the chunk again now. + assert!(old_edit.is_none()); + } + + /// Records a chunk deallocation in the edit. + /// + /// The return value indicates whether the chunk being deallocated + /// is previously recorded in the edit as being allocated. + /// If so, the chunk can be deallocated in the `ChunkAllocState`. + pub fn dealloc(&mut self, chunk_id: ChunkId) -> bool { + match self.edit_table.get(&chunk_id) { + None => { + self.edit_table.insert(chunk_id, ChunkEdit::Dealloc); + false + } + Some(&ChunkEdit::Alloc) => { + self.edit_table.remove(&chunk_id); + true + } + Some(&ChunkEdit::Dealloc) => { + panic!("a chunk must not be deallocated twice"); + } + } + } + + /// Returns an iterator over all allocated chunks. + pub fn iter_allocated_chunks(&self) -> impl Iterator<Item = ChunkId> + '_ { + self.edit_table.iter().filter_map(|(id, edit)| { + if *edit == ChunkEdit::Alloc { + Some(*id) + } else { + None + } + }) + } + + pub fn is_empty(&self) -> bool { + self.edit_table.is_empty() + } +} + +impl Edit<ChunkAllocState> for ChunkAllocEdit { + fn apply_to(&self, state: &mut ChunkAllocState) { + let mut to_be_deallocated = Vec::new(); + for (&chunk_id, chunk_edit) in &self.edit_table { + match chunk_edit { + ChunkEdit::Alloc => { + if state.in_journal { + let _allocated_id = state.alloc().unwrap(); + } + // Except journal, nothing needs to be done + } + ChunkEdit::Dealloc => { + to_be_deallocated.push(chunk_id); + } + } + } + for chunk_id in to_be_deallocated { + state.dealloc(chunk_id); + } + } +} + +impl TxData for ChunkAllocEdit {} + +#[cfg(test)] +mod tests { + use super::*; + + fn new_chunk_alloc() -> ChunkAlloc { + let cap = 1024_usize; + let tx_provider = TxProvider::new(); + let chunk_alloc = ChunkAlloc::new(cap, tx_provider); + assert_eq!(chunk_alloc.capacity(), cap); + assert_eq!(chunk_alloc.free_count(), cap); + chunk_alloc + } + + fn do_alloc_dealloc_tx(chunk_alloc: &ChunkAlloc, alloc_cnt: usize, dealloc_cnt: usize) -> Tx { + debug_assert!(alloc_cnt <= chunk_alloc.capacity() && dealloc_cnt <= alloc_cnt); + let mut tx = chunk_alloc.new_tx(); + tx.context(|| { + let chunk_id = chunk_alloc.alloc().unwrap(); + let chunk_ids = chunk_alloc.alloc_batch(alloc_cnt - 1).unwrap(); + let allocated_chunk_ids: Vec<ChunkId> = core::iter::once(chunk_id) + .chain(chunk_ids.into_iter()) + .collect(); + + chunk_alloc.dealloc(allocated_chunk_ids[0]); + chunk_alloc.dealloc_batch( + allocated_chunk_ids[alloc_cnt - dealloc_cnt + 1..alloc_cnt] + .iter() + .cloned(), + ); + }); + tx + } + + #[test] + fn chunk_alloc_dealloc_tx_commit() -> Result<()> { + let chunk_alloc = new_chunk_alloc(); + let cap = chunk_alloc.capacity(); + let (alloc_cnt, dealloc_cnt) = (cap, cap); + + let mut tx = do_alloc_dealloc_tx(&chunk_alloc, alloc_cnt, dealloc_cnt); + tx.commit()?; + 
assert_eq!(chunk_alloc.free_count(), cap - alloc_cnt + dealloc_cnt); + Ok(()) + } + + #[test] + fn chunk_alloc_dealloc_tx_abort() -> Result<()> { + let chunk_alloc = new_chunk_alloc(); + let cap = chunk_alloc.capacity(); + let (alloc_cnt, dealloc_cnt) = (cap / 2, cap / 4); + + let mut tx = do_alloc_dealloc_tx(&chunk_alloc, alloc_cnt, dealloc_cnt); + tx.abort(); + assert_eq!(chunk_alloc.free_count(), cap); + Ok(()) + } +} diff --git a/kernel/comps/mlsdisk/src/layers/3-log/mod.rs b/kernel/comps/mlsdisk/src/layers/3-log/mod.rs new file mode 100644 index 00000000..eb13c3f0 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/3-log/mod.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! The layer of transactional logging. +//! +//! `TxLogStore` is a transactional, log-oriented file system. +//! It supports creating, deleting, listing, reading, and writing `TxLog`s. +//! Each `TxLog` is an append-only log, and assigned an unique `TxLogId`. +//! All `TxLogStore`'s APIs should be called within transactions (`TX`). + +mod chunk; +mod raw_log; +mod tx_log; + +pub use self::tx_log::{TxLog, TxLogId, TxLogStore}; diff --git a/kernel/comps/mlsdisk/src/layers/3-log/raw_log.rs b/kernel/comps/mlsdisk/src/layers/3-log/raw_log.rs new file mode 100644 index 00000000..36660404 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/3-log/raw_log.rs @@ -0,0 +1,1235 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A store of raw (untrusted) logs. +//! +//! `RawLogStore<D>` allows creating, deleting, reading and writing +//! `RawLog<D>`. Each raw log is uniquely identified by its ID (`RawLogId`). +//! Writing to a raw log is append only. +//! +//! `RawLogStore<D>` stores raw logs on a disk of `D: BlockSet`. +//! Internally, `RawLogStore<D>` manages the disk space with `ChunkAlloc` +//! so that the disk space can be allocated and deallocated in the units of +//! chunk. An allocated chunk belongs to exactly one raw log. And one raw log +//! may be backed by multiple chunks. The raw log is represented externally +//! as a `BlockLog`. +//! +//! # Examples +//! +//! Raw logs are manipulated and accessed within transactions. +//! +//! ``` +//! fn concat_logs<D>( +//! log_store: &RawLogStore<D>, +//! log_ids: &[RawLogId] +//! ) -> Result<RawLogId> { +//! let mut tx = log_store.new_tx(); +//! let res: Result<_> = tx.context(|| { +//! let mut buf = Buf::alloc(1)?; +//! let output_log = log_store.create_log()?; +//! for log_id in log_ids { +//! let input_log = log_store.open_log(log_id, false)?; +//! let input_len = input_log.nblocks(); +//! let mut pos = 0 as BlockId; +//! while pos < input_len { +//! input_log.read(pos, buf.as_mut())?; +//! output_log.append(buf.as_ref())?; +//! } +//! } +//! Ok(output_log.id()) +//! }); +//! if res.is_ok() { +//! tx.commit()?; +//! } else { +//! tx.abort(); +//! } +//! res +//! } +//! ``` +//! +//! If any error occurs (e.g., failures to open, read, or write a log) during +//! the transaction, then all prior changes to raw logs shall have no +//! effects. On the other hand, if the commit operation succeeds, then +//! all changes made in the transaction shall take effect as a whole. +//! +//! # Expected behaviors +//! +//! We provide detailed descriptions about the expected behaviors of raw log +//! APIs under transactions. +//! +//! 1. The local changes made (e.g., creations, deletions, writes) in a TX are +//! immediately visible to the TX, but not other TX until the TX is committed. +//! For example, a newly-created log within TX A is immediately usable within TX, +//! 
+//! but becomes visible to other TXs only after A is committed.
+//! As another example, when a log is deleted within a TX, the TX can no
+//! longer open the log, but other concurrent TXs can still open it.
+//!
+//! 2. If a TX is aborted, then all the local changes made in the TX will be
+//! discarded.
+//!
+//! 3. At any given time, a log can have at most one writer TX.
+//! A TX becomes the writer of a log when the log is opened with the write
+//! permission in the TX. And it stops being the writer TX of the log only when
+//! the TX is terminated (not when the log is closed within the TX).
+//! This single-writer rule avoids potential conflicts between concurrent
+//! writes to the same log.
+//!
+//! 4. Log creation does not conflict with log deletion, read, or write, as
+//! every newly-created log is assigned a unique ID automatically.
+//!
+//! 5. Deleting a log does not affect any opened instance of the log in this TX
+//! or other TXs (similar to deleting a file in a UNIX-style system).
+//! Only after the deleting TX is committed and the last opened
+//! instance of the log is closed will the log be deleted and its disk space
+//! be freed.
+//!
+//! 6. The TX commitment will not fail due to conflicts between concurrent
+//! operations in different TXs.
+
+use core::sync::atomic::{AtomicUsize, Ordering};
+
+use serde::{Deserialize, Serialize};
+
+use super::chunk::{ChunkAlloc, ChunkId, CHUNK_NBLOCKS};
+use crate::{
+    layers::{
+        bio::{BlockLog, BlockSet, BufMut, BufRef},
+        edit::Edit,
+    },
+    os::{HashMap, HashSet, Mutex, MutexGuard},
+    prelude::*,
+    tx::{CurrentTx, TxData, TxProvider},
+    util::LazyDelete,
+};
+
+/// The ID of a raw log.
+pub type RawLogId = u64;
+
+/// A store of raw logs.
+pub struct RawLogStore<D> {
+    state: Arc<Mutex<State>>,
+    disk: D,
+    chunk_alloc: ChunkAlloc, // Mapping: ChunkId * CHUNK_NBLOCKS = disk position (BlockId)
+    tx_provider: Arc<TxProvider>,
+    weak_self: Weak<Self>,
+}
+
+impl<D: BlockSet> RawLogStore<D> {
+    /// Creates a new store of raw logs,
+    /// given a chunk allocator and an untrusted disk.
+    pub fn new(disk: D, tx_provider: Arc<TxProvider>, chunk_alloc: ChunkAlloc) -> Arc<Self> {
+        Self::from_parts(RawLogStoreState::new(), disk, chunk_alloc, tx_provider)
+    }
+
+    /// Constructs a `RawLogStore` from its parts.
+ pub(super) fn from_parts( + state: RawLogStoreState, + disk: D, + chunk_alloc: ChunkAlloc, + tx_provider: Arc<TxProvider>, + ) -> Arc<Self> { + let new_self = { + // Prepare lazy deletes first from persistent state + let lazy_deletes = { + let mut delete_table = HashMap::new(); + for (&log_id, log_entry) in state.log_table.iter() { + Self::add_lazy_delete(log_id, log_entry, &chunk_alloc, &mut delete_table) + } + delete_table + }; + + Arc::new_cyclic(|weak_self| Self { + state: Arc::new(Mutex::new(State::new(state, lazy_deletes))), + disk, + chunk_alloc, + tx_provider, + weak_self: weak_self.clone(), + }) + }; + + // TX data + new_self + .tx_provider + .register_data_initializer(Box::new(RawLogStoreEdit::new)); + + // Commit handler + new_self.tx_provider.register_commit_handler({ + let state = new_self.state.clone(); + let chunk_alloc = new_self.chunk_alloc.clone(); + move |current: CurrentTx<'_>| { + current.data_with(|edit: &RawLogStoreEdit| { + if edit.edit_table.is_empty() { + return; + } + + let mut state = state.lock(); + state.apply(edit); + + Self::add_lazy_deletes_for_created_logs(&mut state, edit, &chunk_alloc); + }); + let mut state = state.lock(); + Self::do_lazy_deletion(&mut state, ¤t); + } + }); + + new_self + } + + // Adds a lazy delete for the given log. + fn add_lazy_delete( + log_id: RawLogId, + log_entry: &RawLogEntry, + chunk_alloc: &ChunkAlloc, + delete_table: &mut HashMap<u64, Arc<LazyDelete<RawLogEntry>>>, + ) { + let log_entry = log_entry.clone(); + let chunk_alloc = chunk_alloc.clone(); + delete_table.insert( + log_id, + Arc::new(LazyDelete::new(log_entry, move |entry| { + chunk_alloc.dealloc_batch(entry.head.chunks.iter().cloned()) + })), + ); + } + + fn add_lazy_deletes_for_created_logs( + state: &mut State, + edit: &RawLogStoreEdit, + chunk_alloc: &ChunkAlloc, + ) { + for log_id in edit.iter_created_logs() { + let log_entry_opt = state.persistent.find_log(log_id); + if log_entry_opt.is_none() || state.lazy_deletes.contains_key(&log_id) { + continue; + } + + Self::add_lazy_delete( + log_id, + log_entry_opt.as_ref().unwrap(), + chunk_alloc, + &mut state.lazy_deletes, + ) + } + } + + // Do lazy deletions for the deleted logs in the current TX. + fn do_lazy_deletion(state: &mut State, current_tx: &CurrentTx) { + let deleted_logs = current_tx + .data_with(|edit: &RawLogStoreEdit| edit.iter_deleted_logs().collect::<Vec<_>>()); + + for log_id in deleted_logs { + let Some(lazy_delete) = state.lazy_deletes.remove(&log_id) else { + // Other concurrent TXs have deleted the same log + continue; + }; + LazyDelete::delete(&lazy_delete); + } + } + + /// Creates a new transaction for `RawLogStore`. + pub fn new_tx(&self) -> CurrentTx<'_> { + self.tx_provider.new_tx() + } + + /// Syncs all the data managed by `RawLogStore` for persistence. + pub fn sync(&self) -> Result<()> { + // Do nothing, leave the disk sync to `TxLogStore` + Ok(()) + } + + /// Creates a new raw log with a new log ID. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. 
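+    ///
+    /// # Examples
+    ///
+    /// An illustrative sketch (the helper function below is made up for this
+    /// doc and is not part of the original patch), following the style of the
+    /// module-level example: create a log, append one block, and commit.
+    ///
+    /// ```
+    /// fn create_one_block_log<D: BlockSet>(log_store: &RawLogStore<D>) -> Result<RawLogId> {
+    ///     let mut tx = log_store.new_tx();
+    ///     let res: Result<RawLogId> = tx.context(|| {
+    ///         let new_log = log_store.create_log()?;
+    ///         let mut buf = Buf::alloc(1)?;
+    ///         buf.as_mut_slice().fill(0xab);
+    ///         new_log.append(buf.as_ref())?;
+    ///         Ok(new_log.id())
+    ///     });
+    ///     if res.is_ok() {
+    ///         tx.commit()?;
+    ///     } else {
+    ///         tx.abort();
+    ///     }
+    ///     res
+    /// }
+    /// ```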
+ pub fn create_log(&self) -> Result<RawLog<D>> { + let mut state = self.state.lock(); + let new_log_id = state.alloc_log_id(); + state + .add_to_write_set(new_log_id) + .expect("created log can't appear in write set"); + + let mut current_tx = self.tx_provider.current(); + current_tx.data_mut_with(|edit: &mut RawLogStoreEdit| { + edit.create_log(new_log_id); + }); + + Ok(RawLog { + log_id: new_log_id, + log_entry: None, + log_store: self.weak_self.upgrade().unwrap(), + tx_provider: self.tx_provider.clone(), + lazy_delete: None, + append_pos: AtomicUsize::new(0), + can_append: true, + }) + } + + /// Opens the raw log of a given ID. + /// + /// For any log at any time, there can be at most one TX that opens the log + /// in the appendable mode. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn open_log(&self, log_id: u64, can_append: bool) -> Result<RawLog<D>> { + let mut state = self.state.lock(); + // Must check lazy deletes first in case there is concurrent deletion + let lazy_delete = state + .lazy_deletes + .get(&log_id) + .ok_or(Error::with_msg(NotFound, "raw log already been deleted"))? + .clone(); + let mut current_tx = self.tx_provider.current(); + + let log_entry_opt = state.persistent.find_log(log_id); + // The log is already created by other TX + if let Some(log_entry) = log_entry_opt.as_ref() { + if can_append { + // Prevent other TX from opening this log in the append mode + state.add_to_write_set(log_id)?; + + // If the log is open in the append mode, edit must be prepared + current_tx.data_mut_with(|edit: &mut RawLogStoreEdit| { + edit.open_log(log_id, log_entry); + }); + } + } + // The log must has been created by this TX + else { + let is_log_created = + current_tx.data_mut_with(|edit: &mut RawLogStoreEdit| edit.is_log_created(log_id)); + if !is_log_created { + return_errno_with_msg!(NotFound, "raw log not found"); + } + } + + let append_pos: BlockId = log_entry_opt + .as_ref() + .map(|entry| entry.head.num_blocks as _) + .unwrap_or(0); + Ok(RawLog { + log_id, + log_entry: log_entry_opt.map(|entry| Arc::new(Mutex::new(entry.clone()))), + log_store: self.weak_self.upgrade().unwrap(), + tx_provider: self.tx_provider.clone(), + lazy_delete: Some(lazy_delete), + append_pos: AtomicUsize::new(append_pos), + can_append, + }) + } + + /// Deletes the raw log of a given ID. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn delete_log(&self, log_id: RawLogId) -> Result<()> { + let mut current_tx = self.tx_provider.current(); + + // Free tail chunks + let tail_chunks = + current_tx.data_mut_with(|edit: &mut RawLogStoreEdit| edit.delete_log(log_id)); + if let Some(chunks) = tail_chunks { + self.chunk_alloc.dealloc_batch(chunks.iter().cloned()); + } + // Leave freeing head chunks to lazy delete + + self.state.lock().remove_from_write_set(log_id); + Ok(()) + } +} + +impl<D> Debug for RawLogStore<D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = self.state.lock(); + f.debug_struct("RawLogStore") + .field("persistent_log_table", &state.persistent.log_table) + .field("next_free_log_id", &state.next_free_log_id) + .field("write_set", &state.write_set) + .field("chunk_alloc", &self.chunk_alloc) + .finish() + } +} + +/// A raw (untrusted) log. 
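+///
+/// Internally, an opened raw log is viewed as a head plus a tail: the head
+/// covers the blocks recorded in the committed, persistent log entry, while
+/// the tail covers the blocks appended by the current TX. The tail is merged
+/// into the head only when the TX commits. Reads cover the head first and
+/// then continue into the tail; appends always go to the tail.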
+pub struct RawLog<D> { + log_id: RawLogId, + log_entry: Option<Arc<Mutex<RawLogEntry>>>, + log_store: Arc<RawLogStore<D>>, + tx_provider: Arc<TxProvider>, + lazy_delete: Option<Arc<LazyDelete<RawLogEntry>>>, + append_pos: AtomicUsize, + can_append: bool, +} + +/// A reference (handle) to a raw log. +struct RawLogRef<'a, D> { + log_store: &'a RawLogStore<D>, + log_head: Option<RawLogHeadRef<'a>>, + log_tail: Option<RawLogTailRef<'a>>, +} + +/// A head reference (handle) to a raw log. +struct RawLogHeadRef<'a> { + entry: MutexGuard<'a, RawLogEntry>, +} + +/// A tail reference (handle) to a raw log. +struct RawLogTailRef<'a> { + log_id: RawLogId, + current: CurrentTx<'a>, +} + +impl<D: BlockSet> BlockLog for RawLog<D> { + /// Reads one or multiple blocks at a specified position. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + fn read(&self, pos: BlockId, buf: BufMut) -> Result<()> { + let log_ref = self.as_ref(); + log_ref.read(pos, buf) + } + + /// Appends one or multiple blocks at the end. + /// + /// This method must be called within a TX. Otherwise, this method panics. + fn append(&self, buf: BufRef) -> Result<BlockId> { + if !self.can_append { + return_errno_with_msg!(PermissionDenied, "raw log not in append mode"); + } + + let mut log_ref = self.as_ref(); + log_ref.append(buf)?; + + let nblocks = buf.nblocks(); + let pos = self.append_pos.fetch_add(nblocks, Ordering::Release); + Ok(pos) + } + + /// Ensures that blocks are persisted to the disk. + fn flush(&self) -> Result<()> { + // FIXME: Should we sync the disk here? + self.log_store.disk.flush()?; + Ok(()) + } + + /// Returns the number of blocks. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + fn nblocks(&self) -> usize { + let log_ref = self.as_ref(); + log_ref.nblocks() + } +} + +impl<D> RawLog<D> { + /// Gets the unique ID of raw log. + pub fn id(&self) -> RawLogId { + self.log_id + } + + /// Gets the reference (handle) of raw log. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + fn as_ref(&self) -> RawLogRef<'_, D> { + let log_head = self.log_entry.as_ref().map(|entry| RawLogHeadRef { + entry: entry.lock(), + }); + let log_tail = { + // Check if the log exists create or append edit + let has_valid_edit = self.tx_provider.current().data_mut_with( + |store_edit: &mut RawLogStoreEdit| -> bool { + let Some(edit) = store_edit.edit_table.get(&self.log_id) else { + return false; + }; + match edit { + RawLogEdit::Create(_) | RawLogEdit::Append(_) => true, + RawLogEdit::Delete => false, + } + }, + ); + if has_valid_edit { + Some(RawLogTailRef { + log_id: self.log_id, + current: self.tx_provider.current(), + }) + } else { + None + } + }; + + RawLogRef { + log_store: &self.log_store, + log_head, + log_tail, + } + } +} + +impl<D> Drop for RawLog<D> { + fn drop(&mut self) { + if self.can_append { + self.log_store + .state + .lock() + .remove_from_write_set(self.log_id); + } + } +} + +impl<D> Debug for RawLog<D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RawLog") + .field("log_id", &self.log_id) + .field("log_entry", &self.log_entry) + .field("append_pos", &self.append_pos) + .field("can_append", &self.can_append) + .finish() + } +} + +impl<D: BlockSet> RawLogRef<'_, D> { + /// Reads one or multiple blocks at a specified position of the log. + /// First head then tail if necessary. 
+ /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn read(&self, mut pos: BlockId, mut buf: BufMut) -> Result<()> { + let mut nblocks = buf.nblocks(); + let mut buf_slice = buf.as_mut_slice(); + + let head_len = self.head_len(); + let tail_len = self.tail_len(); + let total_len = head_len + tail_len; + + if pos + nblocks > total_len { + return_errno_with_msg!(InvalidArgs, "do not allow short read"); + } + + let disk = &self.log_store.disk; + // Read from the head if possible and necessary + let head_opt = &self.log_head; + if let Some(head) = head_opt + && pos < head_len + { + let num_read = nblocks.min(head_len - pos); + + let read_buf = BufMut::try_from(&mut buf_slice[..num_read * BLOCK_SIZE])?; + head.read(pos, read_buf, &disk)?; + + pos += num_read; + nblocks -= num_read; + buf_slice = &mut buf_slice[num_read * BLOCK_SIZE..]; + } + if nblocks == 0 { + return Ok(()); + } + + // Read from the tail if possible and necessary + let tail_opt = &self.log_tail; + if let Some(tail) = tail_opt + && pos >= head_len + { + let num_read = nblocks.min(total_len - pos); + let read_buf = BufMut::try_from(&mut buf_slice[..num_read * BLOCK_SIZE])?; + + tail.read(pos - head_len, read_buf, &disk)?; + } + Ok(()) + } + + /// Appends one or multiple blocks at the end (to the tail). + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn append(&mut self, buf: BufRef) -> Result<()> { + let append_nblocks = buf.nblocks(); + let log_tail = self + .log_tail + .as_mut() + .expect("raw log must be opened in append mode"); + + // Allocate new chunks if necessary + let new_chunks_opt = { + let chunks_needed = log_tail.calc_needed_chunks(append_nblocks); + if chunks_needed > 0 { + let chunk_ids = self + .log_store + .chunk_alloc + .alloc_batch(chunks_needed) + .ok_or(Error::with_msg(OutOfMemory, "chunk allocation failed"))?; + Some(chunk_ids) + } else { + None + } + }; + + if let Some(new_chunks) = new_chunks_opt { + log_tail.tail_mut_with(|tail: &mut RawLogTail| { + tail.chunks.extend(new_chunks); + }); + } + + log_tail.append(buf, &self.log_store.disk)?; + + // Update tail metadata + log_tail.tail_mut_with(|tail: &mut RawLogTail| { + tail.num_blocks += append_nblocks as u32; + }); + Ok(()) + } + + /// Returns the number of blocks. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. 
+ pub fn nblocks(&self) -> usize { + self.head_len() + self.tail_len() + } + + fn head_len(&self) -> usize { + self.log_head.as_ref().map_or(0, |head| head.len()) + } + + fn tail_len(&self) -> usize { + self.log_tail.as_ref().map_or(0, |tail| tail.len()) + } +} + +impl RawLogHeadRef<'_> { + pub fn len(&self) -> usize { + self.entry.head.num_blocks as _ + } + + pub fn read<D: BlockSet>(&self, offset: BlockId, mut buf: BufMut, disk: &D) -> Result<()> { + let nblocks = buf.nblocks(); + debug_assert!(offset + nblocks <= self.entry.head.num_blocks as _); + + let prepared_blocks = self.prepare_blocks(offset, nblocks); + debug_assert_eq!(prepared_blocks.len(), nblocks); + + // Batch read + // Note that `prepared_blocks` are not always sorted + let mut offset = 0; + for consecutive_blocks in prepared_blocks.chunk_by(|b1, b2| b2.saturating_sub(*b1) == 1) { + let len = consecutive_blocks.len(); + let first_bid = *consecutive_blocks.first().unwrap(); + let buf_slice = + &mut buf.as_mut_slice()[offset * BLOCK_SIZE..(offset + len) * BLOCK_SIZE]; + disk.read(first_bid, BufMut::try_from(buf_slice).unwrap())?; + offset += len; + } + + Ok(()) + } + + /// Collect and prepare a set of consecutive blocks in head for a read request. + pub fn prepare_blocks(&self, mut offset: BlockId, nblocks: usize) -> Vec<BlockId> { + let mut res_blocks = Vec::with_capacity(nblocks); + let chunks = &self.entry.head.chunks; + + while res_blocks.len() != nblocks { + let curr_chunk_idx = offset / CHUNK_NBLOCKS; + let curr_chunk_inner_offset = offset % CHUNK_NBLOCKS; + + res_blocks.push(chunks[curr_chunk_idx] * CHUNK_NBLOCKS + curr_chunk_inner_offset); + offset += 1; + } + + res_blocks + } +} + +impl RawLogTailRef<'_> { + /// Apply given function to the immutable tail. + pub fn tail_with<F, R>(&self, f: F) -> R + where + F: FnOnce(&RawLogTail) -> R, + { + self.current.data_with(|store_edit: &RawLogStoreEdit| -> R { + let edit = store_edit.edit_table.get(&self.log_id).unwrap(); + match edit { + RawLogEdit::Create(create) => f(&create.tail), + RawLogEdit::Append(append) => f(&append.tail), + RawLogEdit::Delete => unreachable!(), + } + }) + } + + /// Apply given function to the mutable tail. 
+ pub fn tail_mut_with<F, R>(&mut self, f: F) -> R + where + F: FnOnce(&mut RawLogTail) -> R, + { + self.current + .data_mut_with(|store_edit: &mut RawLogStoreEdit| -> R { + let edit = store_edit.edit_table.get_mut(&self.log_id).unwrap(); + match edit { + RawLogEdit::Create(create) => f(&mut create.tail), + RawLogEdit::Append(append) => f(&mut append.tail), + RawLogEdit::Delete => unreachable!(), + } + }) + } + + pub fn len(&self) -> usize { + self.tail_with(|tail: &RawLogTail| tail.num_blocks as _) + } + + pub fn read<D: BlockSet>(&self, offset: BlockId, mut buf: BufMut, disk: &D) -> Result<()> { + let nblocks = buf.nblocks(); + let tail_nblocks = self.len(); + debug_assert!(offset + nblocks <= tail_nblocks); + + let prepared_blocks = self.prepare_blocks(offset, nblocks); + debug_assert_eq!(prepared_blocks.len(), nblocks); + + // Batch read + // Note that `prepared_blocks` are not always sorted + let mut offset = 0; + for consecutive_blocks in prepared_blocks.chunk_by(|b1, b2| b2.saturating_sub(*b1) == 1) { + let len = consecutive_blocks.len(); + let first_bid = *consecutive_blocks.first().unwrap(); + let buf_slice = + &mut buf.as_mut_slice()[offset * BLOCK_SIZE..(offset + len) * BLOCK_SIZE]; + disk.read(first_bid, BufMut::try_from(buf_slice).unwrap())?; + offset += len; + } + + Ok(()) + } + + pub fn append<D: BlockSet>(&self, buf: BufRef, disk: &D) -> Result<()> { + let nblocks = buf.nblocks(); + + let prepared_blocks = self.prepare_blocks(self.len() as _, nblocks); + debug_assert_eq!(prepared_blocks.len(), nblocks); + + // Batch write + // Note that `prepared_blocks` are not always sorted + let mut offset = 0; + for consecutive_blocks in prepared_blocks.chunk_by(|b1, b2| b2.saturating_sub(*b1) == 1) { + let len = consecutive_blocks.len(); + let first_bid = *consecutive_blocks.first().unwrap(); + let buf_slice = &buf.as_slice()[offset * BLOCK_SIZE..(offset + len) * BLOCK_SIZE]; + disk.write(first_bid, BufRef::try_from(buf_slice).unwrap())?; + offset += len; + } + + Ok(()) + } + + // Calculate how many new chunks we need for an append request + pub fn calc_needed_chunks(&self, append_nblocks: usize) -> usize { + self.tail_with(|tail: &RawLogTail| { + let avail_blocks = tail.head_last_chunk_free_blocks as usize + + tail.chunks.len() * CHUNK_NBLOCKS + - tail.num_blocks as usize; + if append_nblocks > avail_blocks { + align_up(append_nblocks - avail_blocks, CHUNK_NBLOCKS) / CHUNK_NBLOCKS + } else { + 0 + } + }) + } + + /// Collect and prepare a set of consecutive blocks in tail for a read/append request. 
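+    ///
+    /// For illustration (an assumed layout, not taken from this patch): with
+    /// `CHUNK_NBLOCKS == 1024`, a head whose last chunk is chunk 7 with 2 free
+    /// blocks, and a tail owning chunk 9, offsets `0..5` map to the block IDs
+    /// `[7 * 1024 + 1022, 7 * 1024 + 1023, 9 * 1024, 9 * 1024 + 1, 9 * 1024 + 2]`:
+    /// the first two blocks reuse the free space at the end of the head's last
+    /// chunk, and the remaining ones come from the tail's own chunks.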
+ fn prepare_blocks(&self, mut offset: BlockId, nblocks: usize) -> Vec<BlockId> { + self.tail_with(|tail: &RawLogTail| { + let mut res_blocks = Vec::with_capacity(nblocks); + let head_last_chunk_free_blocks = tail.head_last_chunk_free_blocks as usize; + + // Collect available blocks from the last chunk of the head first if necessary + if offset < head_last_chunk_free_blocks as _ { + for i in offset..head_last_chunk_free_blocks { + let avail_chunk = tail.head_last_chunk_id * CHUNK_NBLOCKS + + (CHUNK_NBLOCKS - head_last_chunk_free_blocks + i); + res_blocks.push(avail_chunk); + + if res_blocks.len() == nblocks { + return res_blocks; + } + } + + offset = 0; + } else { + offset -= head_last_chunk_free_blocks; + } + + // Collect available blocks from the tail first if necessary + let chunks = &tail.chunks; + while res_blocks.len() != nblocks { + let curr_chunk_idx = offset / CHUNK_NBLOCKS; + let curr_chunk_inner_offset = offset % CHUNK_NBLOCKS; + + res_blocks.push(chunks[curr_chunk_idx] * CHUNK_NBLOCKS + curr_chunk_inner_offset); + offset += 1; + } + + res_blocks + }) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Persistent State +//////////////////////////////////////////////////////////////////////////////// + +/// The volatile and persistent state of a `RawLogStore`. +struct State { + persistent: RawLogStoreState, + next_free_log_id: u64, + write_set: HashSet<RawLogId>, + lazy_deletes: HashMap<RawLogId, Arc<LazyDelete<RawLogEntry>>>, +} + +/// The persistent state of a `RawLogStore`. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RawLogStoreState { + log_table: HashMap<RawLogId, RawLogEntry>, +} + +/// A log entry implies the persistent state of the raw log. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct RawLogEntry { + head: RawLogHead, +} + +/// A log head contains chunk metadata of a log's already-persist data. 
+#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct RawLogHead { + pub chunks: Vec<ChunkId>, + pub num_blocks: u32, +} + +impl State { + pub fn new( + persistent: RawLogStoreState, + lazy_deletes: HashMap<RawLogId, Arc<LazyDelete<RawLogEntry>>>, + ) -> Self { + let next_free_log_id = if let Some(max_log_id) = lazy_deletes.keys().max() { + max_log_id + 1 + } else { + 0 + }; + Self { + persistent: persistent.clone(), + next_free_log_id, + write_set: HashSet::new(), + lazy_deletes, + } + } + + pub fn apply(&mut self, edit: &RawLogStoreEdit) { + edit.apply_to(&mut self.persistent); + } + + pub fn alloc_log_id(&mut self) -> u64 { + let new_log_id = self.next_free_log_id; + self.next_free_log_id = self + .next_free_log_id + .checked_add(1) + .expect("64-bit IDs won't be exhausted even though IDs are not recycled"); + new_log_id + } + + pub fn add_to_write_set(&mut self, log_id: RawLogId) -> Result<()> { + let not_exists = self.write_set.insert(log_id); + if !not_exists { + // Obey single-writer rule + return_errno_with_msg!(PermissionDenied, "the raw log has more than one writer"); + } + Ok(()) + } + + pub fn remove_from_write_set(&mut self, log_id: RawLogId) { + let _is_removed = self.write_set.remove(&log_id); + // `_is_removed` may equal to `false` if the log has already been deleted + } +} + +impl RawLogStoreState { + pub fn new() -> Self { + Self { + log_table: HashMap::new(), + } + } + + pub fn create_log(&mut self, new_log_id: u64) { + let new_log_entry = RawLogEntry { + head: RawLogHead::new(), + }; + let already_exists = self.log_table.insert(new_log_id, new_log_entry).is_some(); + debug_assert!(!already_exists); + } + + pub(super) fn find_log(&self, log_id: u64) -> Option<RawLogEntry> { + self.log_table.get(&log_id).cloned() + } + + pub(super) fn append_log(&mut self, log_id: u64, tail: &RawLogTail) { + let log_entry = self.log_table.get_mut(&log_id).unwrap(); + log_entry.head.append(tail); + } + + pub fn delete_log(&mut self, log_id: u64) { + let _ = self.log_table.remove(&log_id); + // Leave chunk deallocation to lazy delete + } +} + +impl RawLogHead { + pub fn new() -> Self { + Self { + chunks: Vec::new(), + num_blocks: 0, + } + } + + pub fn append(&mut self, tail: &RawLogTail) { + // Update head + self.chunks.extend(tail.chunks.iter()); + self.num_blocks += tail.num_blocks; + // No need to update tail + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Persistent Edit +//////////////////////////////////////////////////////////////////////////////// + +/// A persistent edit to the state of `RawLogStore`. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RawLogStoreEdit { + edit_table: HashMap<RawLogId, RawLogEdit>, +} + +/// The basic unit of a persistent edit to the state of `RawLogStore`. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) enum RawLogEdit { + Create(RawLogCreate), + Append(RawLogAppend), + Delete, +} + +/// An edit that implies a log being created. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct RawLogCreate { + tail: RawLogTail, +} + +/// An edit that implies an existing log being appended. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct RawLogAppend { + tail: RawLogTail, +} + +/// A log tail contains chunk metadata of a log's TX-ongoing data. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct RawLogTail { + // The last chunk of the head. 
If it is partially filled + // (head_last_chunk_free_blocks > 0), then the tail should write to the + // free blocks in the last chunk of the head. + head_last_chunk_id: ChunkId, + head_last_chunk_free_blocks: u16, + // The chunks allocated and owned by the tail + chunks: Vec<ChunkId>, + // The total number of blocks in the tail, including the blocks written to + // the last chunk of head and those written to the chunks owned by the tail. + num_blocks: u32, +} + +impl RawLogStoreEdit { + /// Creates a new empty edit table. + pub fn new() -> Self { + Self { + edit_table: HashMap::new(), + } + } + + /// Records a log creation in the edit. + pub fn create_log(&mut self, new_log_id: RawLogId) { + let create_edit = RawLogEdit::Create(RawLogCreate::new()); + let edit_exists = self.edit_table.insert(new_log_id, create_edit); + debug_assert!(edit_exists.is_none()); + } + + /// Records a log being opened in the edit. + pub(super) fn open_log(&mut self, log_id: RawLogId, log_entry: &RawLogEntry) { + match self.edit_table.get(&log_id) { + None => { + // Insert an append edit + let tail = RawLogTail::from_head(&log_entry.head); + let append_edit = RawLogEdit::Append(RawLogAppend { tail }); + let edit_exists = self.edit_table.insert(log_id, append_edit); + debug_assert!(edit_exists.is_none()); + } + Some(edit) => { + // If `edit == create`, unreachable: there can't be a persistent log entry + // when the log is just created in an ongoing TX + if let RawLogEdit::Create(_) = edit { + unreachable!(); + } + // If `edit == append`, do nothing + // If `edit == delete`, panic + if let RawLogEdit::Delete = edit { + panic!("try to open a deleted log!"); + } + } + } + } + + /// Records a log deletion in the edit, returns the tail chunks of the deleted log. + pub fn delete_log(&mut self, log_id: RawLogId) -> Option<Vec<ChunkId>> { + match self.edit_table.insert(log_id, RawLogEdit::Delete) { + None => None, + Some(RawLogEdit::Create(create)) => { + // No need to panic in create + Some(create.tail.chunks.clone()) + } + Some(RawLogEdit::Append(append)) => { + // No need to panic in append (WAL case) + Some(append.tail.chunks.clone()) + } + Some(RawLogEdit::Delete) => panic!("try to delete a deleted log!"), + } + } + + pub fn is_log_created(&self, log_id: RawLogId) -> bool { + match self.edit_table.get(&log_id) { + Some(RawLogEdit::Create(_)) | Some(RawLogEdit::Append(_)) => true, + Some(RawLogEdit::Delete) | None => false, + } + } + + pub fn iter_created_logs(&self) -> impl Iterator<Item = RawLogId> + '_ { + self.edit_table + .iter() + .filter(|(_, edit)| matches!(edit, RawLogEdit::Create(_))) + .map(|(id, _)| *id) + } + + pub fn iter_deleted_logs(&self) -> impl Iterator<Item = RawLogId> + '_ { + self.edit_table + .iter() + .filter(|(_, edit)| matches!(edit, RawLogEdit::Delete)) + .map(|(id, _)| *id) + } + + pub fn is_empty(&self) -> bool { + self.edit_table.is_empty() + } +} + +impl Edit<RawLogStoreState> for RawLogStoreEdit { + fn apply_to(&self, state: &mut RawLogStoreState) { + for (&log_id, log_edit) in self.edit_table.iter() { + match log_edit { + RawLogEdit::Create(create) => { + let RawLogCreate { tail } = create; + state.create_log(log_id); + state.append_log(log_id, tail); + } + RawLogEdit::Append(append) => { + let RawLogAppend { tail } = append; + state.append_log(log_id, tail); + } + RawLogEdit::Delete => { + state.delete_log(log_id); + } + } + } + } +} + +impl RawLogCreate { + pub fn new() -> Self { + Self { + tail: RawLogTail::new(), + } + } +} + +impl RawLogTail { + pub fn new() -> Self { + Self { + 
head_last_chunk_id: 0, + head_last_chunk_free_blocks: 0, + chunks: Vec::new(), + num_blocks: 0, + } + } + + pub fn from_head(head: &RawLogHead) -> Self { + Self { + head_last_chunk_id: *head.chunks.last().unwrap_or(&0), + head_last_chunk_free_blocks: (head.chunks.len() * CHUNK_NBLOCKS + - head.num_blocks as usize) as _, + chunks: Vec::new(), + num_blocks: 0, + } + } +} + +impl TxData for RawLogStoreEdit {} + +#[cfg(test)] +mod tests { + use std::thread::{self, JoinHandle}; + + use super::*; + use crate::layers::{ + bio::{Buf, MemDisk}, + log::chunk::{CHUNK_NBLOCKS, CHUNK_SIZE}, + }; + + fn create_raw_log_store() -> Result<Arc<RawLogStore<MemDisk>>> { + let nchunks = 8; + let nblocks = nchunks * CHUNK_NBLOCKS; + let tx_provider = TxProvider::new(); + let chunk_alloc = ChunkAlloc::new(nchunks, tx_provider.clone()); + let mem_disk = MemDisk::create(nblocks)?; + Ok(RawLogStore::new(mem_disk, tx_provider, chunk_alloc)) + } + + fn find_persistent_log_entry( + log_store: &Arc<RawLogStore<MemDisk>>, + log_id: RawLogId, + ) -> Option<RawLogEntry> { + let state = log_store.state.lock(); + state.persistent.find_log(log_id) + } + + #[test] + fn raw_log_store_fns() -> Result<()> { + let raw_log_store = create_raw_log_store()?; + + // TX 1: create a new log and append contents (committed) + let mut tx = raw_log_store.new_tx(); + let res: Result<RawLogId> = tx.context(|| { + let new_log = raw_log_store.create_log()?; + let mut buf = Buf::alloc(4)?; + buf.as_mut_slice().fill(2u8); + new_log.append(buf.as_ref())?; + assert_eq!(new_log.nblocks(), 4); + Ok(new_log.id()) + }); + let log_id = res?; + tx.commit()?; + + let entry = find_persistent_log_entry(&raw_log_store, log_id).unwrap(); + assert_eq!(entry.head.num_blocks, 4); + + // TX 2: open the log, append contents then read (committed) + let mut tx = raw_log_store.new_tx(); + let res: Result<_> = tx.context(|| { + let log = raw_log_store.open_log(log_id, true)?; + + let mut buf = Buf::alloc(CHUNK_NBLOCKS)?; + buf.as_mut_slice().fill(5u8); + log.append(buf.as_ref())?; + + Ok(()) + }); + res?; + + let res: Result<_> = tx.context(|| { + let log = raw_log_store.open_log(log_id, true)?; + + let mut buf = Buf::alloc(4)?; + log.read(1 as BlockId, buf.as_mut())?; + assert_eq!(&buf.as_slice()[..3 * BLOCK_SIZE], &[2u8; 3 * BLOCK_SIZE]); + assert_eq!(&buf.as_slice()[3 * BLOCK_SIZE..], &[5u8; BLOCK_SIZE]); + + Ok(()) + }); + res?; + tx.commit()?; + + let entry = find_persistent_log_entry(&raw_log_store, log_id).unwrap(); + assert_eq!(entry.head.num_blocks, 1028); + + // TX 3: delete the log (committed) + let mut tx = raw_log_store.new_tx(); + let res: Result<_> = tx.context(|| raw_log_store.delete_log(log_id)); + res?; + tx.commit()?; + + let entry_opt = find_persistent_log_entry(&raw_log_store, log_id); + assert!(entry_opt.is_none()); + + // TX 4: create a new log (aborted) + let mut tx = raw_log_store.new_tx(); + let res: Result<_> = tx.context(|| { + let new_log = raw_log_store.create_log()?; + Ok(new_log.id()) + }); + let new_log_id = res?; + tx.abort(); + + let entry_opt = find_persistent_log_entry(&raw_log_store, new_log_id); + assert!(entry_opt.is_none()); + + Ok(()) + } + + #[test] + fn raw_log_deletion() -> Result<()> { + let raw_log_store = create_raw_log_store()?; + + // Create a new log and append contents + let mut tx = raw_log_store.new_tx(); + let content = 5_u8; + let res: Result<_> = tx.context(|| { + let new_log = raw_log_store.create_log()?; + let mut buf = Buf::alloc(1)?; + buf.as_mut_slice().fill(content); + new_log.append(buf.as_ref())?; + 
Ok(new_log.id()) + }); + let log_id = res?; + tx.commit()?; + + // Concurrently open, read then delete the log + let handlers = (0..16) + .map(|_| { + let raw_log_store = raw_log_store.clone(); + thread::spawn(move || -> Result<()> { + let mut tx = raw_log_store.new_tx(); + println!( + "TX[{:?}] executed on thread[{:?}]", + tx.id(), + crate::os::CurrentThread::id() + ); + let _ = tx.context(|| { + let log = raw_log_store.open_log(log_id, false)?; + let mut buf = Buf::alloc(1)?; + log.read(0 as BlockId, buf.as_mut())?; + assert_eq!(buf.as_slice(), &[content; BLOCK_SIZE]); + raw_log_store.delete_log(log_id) + }); + tx.commit() + }) + }) + .collect::<Vec<JoinHandle<Result<()>>>>(); + for handler in handlers { + handler.join().unwrap()?; + } + + // The log has already been deleted + let mut tx = raw_log_store.new_tx(); + let _ = tx.context(|| { + let res = raw_log_store.open_log(log_id, false).map(|_| ()); + res.expect_err("result must be NotFound"); + }); + tx.commit() + } +} diff --git a/kernel/comps/mlsdisk/src/layers/3-log/tx_log.rs b/kernel/comps/mlsdisk/src/layers/3-log/tx_log.rs new file mode 100644 index 00000000..420fc3cd --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/3-log/tx_log.rs @@ -0,0 +1,1491 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A store of transactional logs. +//! +//! `TxLogStore<D>` supports creating, deleting, listing, reading, +//! and writing `TxLog<D>`s within transactions. Each `TxLog<D>` +//! is uniquely identified by its ID (`TxLogId`). Writing to a TX log +//! is append only. TX logs are categorized into pre-determined buckets. +//! +//! File content of `TxLog<D>` is stored securely using a `CryptoLog<RawLog<D>>`, +//! whose storage space is backed by untrusted log `RawLog<D>`, +//! whose host blocks are managed by `ChunkAlloc`. The whole untrusted +//! host disk that `TxLogSore<D>` used is represented by a `BlockSet`. +//! +//! # Examples +//! +//! TX logs are manipulated and accessed within transactions. +//! +//! ``` +//! fn create_and_read_log<D: BlockSet>( +//! tx_log_store: &TxLogStore<D>, +//! bucket: &str +//! ) -> Result<()> { +//! let content = 5_u8; +//! +//! // TX 1: Create then write a new log +//! let mut tx = tx_log_store.new_tx(); +//! let res: Result<_> = tx.context(|| { +//! let new_log = tx_log_store.create_log(bucket)?; +//! let mut buf = Buf::alloc(1)?; +//! buf.as_mut_slice().fill(content); +//! new_log.append(buf.as_ref()) +//! }); +//! if res.is_err() { +//! tx.abort(); +//! } +//! tx.commit()?; +//! +//! // TX 2: Open then read the created log +//! let mut tx = tx_log_store.new_tx(); +//! let res: Result<_> = tx.context(|| { +//! let log = tx_log_store.open_log_in(bucket)?; +//! let mut buf = Buf::alloc(1)?; +//! log.read(0 as BlockId, buf.as_mut())?; +//! assert_eq!(buf.as_slice()[0], content); +//! Ok(()) +//! }); +//! if res.is_err() { +//! tx.abort(); +//! } +//! tx.commit() +//! } +//! ``` +//! +//! `TxLogStore<D>`'s API is designed to be a limited POSIX FS +//! and must be called within transactions (`Tx`). It mitigates user burden by +//! minimizing the odds of conflicts among TXs: +//! 1) Prohibiting concurrent TXs from opening the same log for +//! writing (no write conflicts); +//! 2) Implementing lazy log deletion to avoid interference with +//! other TXs utilizing the log (no deletion conflicts); +//! 3) Identifying logs by system-generated IDs (no name conflicts). 
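+//!
+//! As an additional illustrative sketch (the function names are made up, and
+//! `disk` and `root_key` are assumed to be supplied by the caller), a
+//! `TxLogStore` is formatted once and then recovered on later boots with the
+//! same root key:
+//!
+//! ```
+//! fn first_boot<D: BlockSet + 'static>(disk: D, root_key: Key) -> Result<TxLogStore<D>> {
+//!     // Formatting lays out the superblock, the chunk area, and an empty journal
+//!     let store = TxLogStore::format(disk, root_key)?;
+//!     store.sync()?;
+//!     Ok(store)
+//! }
+//!
+//! fn later_boot<D: BlockSet + 'static>(disk: D, root_key: Key) -> Result<TxLogStore<D>> {
+//!     // Recovery decrypts the superblock and replays the journal to rebuild state
+//!     TxLogStore::recover(disk, root_key)
+//! }
+//! ```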
+use core::{ + any::Any, + sync::atomic::{AtomicBool, Ordering}, +}; + +use lru::LruCache; +use ostd_pod::Pod; +use serde::{Deserialize, Serialize}; + +use self::journaling::{AllEdit, AllState, Journal, JournalCompactPolicy}; +use super::{ + chunk::{ChunkAlloc, ChunkAllocEdit, ChunkAllocState}, + raw_log::{RawLog, RawLogId, RawLogStore, RawLogStoreEdit, RawLogStoreState}, +}; +use crate::{ + layers::{ + bio::{BlockId, BlockSet, Buf, BufMut, BufRef}, + crypto::{CryptoLog, NodeCache, RootMhtMeta}, + edit::{CompactPolicy, Edit, EditJournal, EditJournalMeta}, + log::chunk::CHUNK_NBLOCKS, + }, + os::{AeadKey as Key, HashMap, HashSet, Mutex, Skcipher, SkcipherIv, SkcipherKey}, + prelude::*, + tx::{CurrentTx, TxData, TxId, TxProvider}, + util::LazyDelete, +}; + +/// The ID of a TX log. +pub type TxLogId = RawLogId; +/// The name of a TX log bucket. +type BucketName = String; + +/// A store of transactional logs. +/// +/// Disk layout: +/// ```text +/// ------------------------------------------------------ +/// | Superblock | RawLogStore region | Journal region | +/// ------------------------------------------------------ +/// ``` +#[derive(Clone)] +pub struct TxLogStore<D> { + state: Arc<Mutex<State>>, + raw_log_store: Arc<RawLogStore<D>>, + journal: Arc<Mutex<Journal<D>>>, + superblock: Superblock, + root_key: Key, + raw_disk: D, + tx_provider: Arc<TxProvider>, +} + +/// Superblock of `TxLogStore`. +#[repr(C)] +#[derive(Clone, Copy, Pod, Debug)] +pub struct Superblock { + journal_area_meta: EditJournalMeta, + chunk_area_nblocks: usize, + magic: u64, +} +const MAGIC_NUMBER: u64 = 0x1130_0821; + +impl<D: BlockSet + 'static> TxLogStore<D> { + /// Formats the disk to create a new instance of `TxLogStore`, + /// with the given root key. + pub fn format(disk: D, root_key: Key) -> Result<Self> { + let total_nblocks = disk.nblocks(); + let (log_store_nblocks, journal_nblocks) = + Self::calc_store_and_journal_nblocks(total_nblocks); + let nchunks = log_store_nblocks / CHUNK_NBLOCKS; + + let log_store_area = disk.subset(1..1 + log_store_nblocks)?; + let journal_area = + disk.subset(1 + log_store_nblocks..1 + log_store_nblocks + journal_nblocks)?; + + let tx_provider = TxProvider::new(); + + let journal = { + let all_state = AllState { + chunk_alloc: ChunkAllocState::new_in_journal(nchunks), + raw_log_store: RawLogStoreState::new(), + tx_log_store: TxLogStoreState::new(), + }; + let state_max_nbytes = 1048576; // TBD + let compaction_policy = + JournalCompactPolicy::new::<D>(journal_area.nblocks(), state_max_nbytes); + Arc::new(Mutex::new(Journal::format( + journal_area, + all_state, + state_max_nbytes, + compaction_policy, + )?)) + }; + Self::register_commit_handler_for_journal(&journal, &tx_provider); + + let chunk_alloc = ChunkAlloc::new(nchunks, tx_provider.clone()); + let raw_log_store = RawLogStore::new(log_store_area, tx_provider.clone(), chunk_alloc); + let tx_log_store_state = TxLogStoreState::new(); + + let superblock = Superblock { + journal_area_meta: journal.lock().meta(), + chunk_area_nblocks: log_store_nblocks, + magic: MAGIC_NUMBER, + }; + superblock.persist(&disk.subset(0..1)?, &root_key)?; + + Ok(Self::from_parts( + tx_log_store_state, + raw_log_store, + journal, + superblock, + root_key, + disk, + tx_provider, + )) + } + + /// Calculate the number of blocks required for the store and the journal. 
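+    ///
+    /// Roughly 90% of the blocks after the superblock go to the log store
+    /// (chunk) area, rounded down to a whole number of chunks, and the rest go
+    /// to the journal. For example (illustrative numbers only), with
+    /// `total_nblocks = 8193` and `CHUNK_NBLOCKS = 1024`:
+    /// `align_down(8192 * 9 / 10, 1024) = 7168` blocks for the log store and
+    /// `8193 - 1 - 7168 = 1024` blocks for the journal.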
+ fn calc_store_and_journal_nblocks(total_nblocks: usize) -> (usize, usize) { + let log_store_nblocks = { + let nblocks = (total_nblocks - 1) * 9 / 10; + align_down(nblocks, CHUNK_NBLOCKS) + }; + let journal_nblocks = total_nblocks - 1 - log_store_nblocks; + debug_assert!(1 + log_store_nblocks + journal_nblocks <= total_nblocks); + (log_store_nblocks, journal_nblocks) // TBD + } + + fn register_commit_handler_for_journal( + journal: &Arc<Mutex<Journal<D>>>, + tx_provider: &Arc<TxProvider>, + ) { + let journal = journal.clone(); + tx_provider.register_commit_handler({ + move |current: CurrentTx<'_>| { + let mut journal = journal.lock(); + current.data_with(|tx_log_edit: &TxLogStoreEdit| { + if tx_log_edit.is_empty() { + return; + } + journal.add(AllEdit::from_tx_log_edit(tx_log_edit)); + }); + current.data_with(|raw_log_edit: &RawLogStoreEdit| { + if raw_log_edit.is_empty() { + return; + } + journal.add(AllEdit::from_raw_log_edit(raw_log_edit)); + }); + current.data_with(|chunk_edit: &ChunkAllocEdit| { + if chunk_edit.is_empty() { + return; + } + journal.add(AllEdit::from_chunk_edit(chunk_edit)); + }); + journal.commit(); + } + }); + } + + /// Recovers an existing `TxLogStore` from a disk using the given key. + pub fn recover(disk: D, root_key: Key) -> Result<Self> { + let superblock = Superblock::open(&disk.subset(0..1)?, &root_key)?; + if disk.nblocks() < superblock.total_nblocks() { + return_errno_with_msg!(OutOfDisk, "given disk lacks space for recovering"); + } + + let tx_provider = TxProvider::new(); + + let journal = { + let journal_area_meta = &superblock.journal_area_meta; + let journal_area = disk.subset( + 1 + superblock.chunk_area_nblocks + ..1 + superblock.chunk_area_nblocks + journal_area_meta.total_nblocks(), + )?; + let compaction_policy = JournalCompactPolicy::from_meta(journal_area_meta); + Arc::new(Mutex::new(Journal::recover( + journal_area, + journal_area_meta, + compaction_policy, + )?)) + }; + Self::register_commit_handler_for_journal(&journal, &tx_provider); + + let AllState { + chunk_alloc, + raw_log_store, + tx_log_store, + } = journal.lock().state().clone(); + + let chunk_alloc = ChunkAlloc::from_parts(chunk_alloc, tx_provider.clone()); + let chunk_area = disk.subset(1..1 + superblock.chunk_area_nblocks)?; + let raw_log_store = + RawLogStore::from_parts(raw_log_store, chunk_area, chunk_alloc, tx_provider.clone()); + let tx_log_store = TxLogStore::from_parts( + tx_log_store, + raw_log_store, + journal, + superblock, + root_key, + disk, + tx_provider, + ); + + Ok(tx_log_store) + } + + /// Constructs a `TxLogStore` from its parts. 
+ fn from_parts( + state: TxLogStoreState, + raw_log_store: Arc<RawLogStore<D>>, + journal: Arc<Mutex<Journal<D>>>, + superblock: Superblock, + root_key: Key, + raw_disk: D, + tx_provider: Arc<TxProvider>, + ) -> Self { + let new_self = { + // Prepare lazy deletes and log caches first from persistent state + let (lazy_deletes, log_caches) = { + let (mut delete_table, mut cache_table) = (HashMap::new(), HashMap::new()); + for log_id in state.list_all_logs() { + Self::add_lazy_delete(log_id, &mut delete_table, &raw_log_store); + cache_table.insert(log_id, Arc::new(CryptoLogCache::new(log_id, &tx_provider))); + } + (delete_table, cache_table) + }; + + Self { + state: Arc::new(Mutex::new(State::new(state, lazy_deletes, log_caches))), + raw_log_store, + journal: journal.clone(), + superblock, + root_key, + raw_disk, + tx_provider: tx_provider.clone(), + } + }; + + // TX data + tx_provider.register_data_initializer(Box::new(TxLogStoreEdit::new)); + tx_provider.register_data_initializer(Box::new(OpenLogTable::<D>::new)); + tx_provider.register_data_initializer(Box::new(OpenLogCache::new)); + + // Precommit handler + tx_provider.register_precommit_handler({ + move |mut current: CurrentTx<'_>| { + // Do I/O in the pre-commit phase. If any I/O error occurred, + // the TX would be aborted. + Self::update_dirty_log_metas(&mut current) + } + }); + + // Commit handler for log store + tx_provider.register_commit_handler({ + let state = new_self.state.clone(); + let raw_log_store = new_self.raw_log_store.clone(); + move |mut current: CurrentTx<'_>| { + current.data_with(|store_edit: &TxLogStoreEdit| { + if store_edit.is_empty() { + return; + } + + let mut state = state.lock(); + state.apply(store_edit); + + Self::add_lazy_deletes_for_created_logs(&mut state, store_edit, &raw_log_store); + }); + + let mut state = state.lock(); + Self::apply_log_caches(&mut state, &mut current); + Self::do_lazy_deletion(&mut state, ¤t); + } + }); + + new_self + } + + /// Record all dirty logs in the current TX. 
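+    ///
+    /// For every log opened and marked dirty in this TX, the underlying
+    /// `CryptoLog` is flushed and the resulting root MHT metadata is recorded
+    /// in the TX edit. Doing this data I/O in the pre-commit phase means an
+    /// I/O failure simply aborts the TX instead of failing the commit.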
+ fn update_dirty_log_metas(current_tx: &mut CurrentTx<'_>) -> Result<()> { + let dirty_logs: Vec<(TxLogId, Arc<TxLogInner<D>>)> = + current_tx.data_with(|open_log_table: &OpenLogTable<D>| { + open_log_table + .open_table + .iter() + .filter_map(|(id, inner_log)| { + if inner_log.is_dirty.load(Ordering::Acquire) { + Some((*id, inner_log.clone())) + } else { + None + } + }) + .collect() + }); + + for (log_id, inner_log) in dirty_logs { + let crypto_log = &inner_log.crypto_log; + crypto_log.flush()?; + + current_tx.data_mut_with(|store_edit: &mut TxLogStoreEdit| { + store_edit.update_log_meta((log_id, crypto_log.root_meta().unwrap())) + }); + } + Ok(()) + } + + fn add_lazy_delete( + log_id: TxLogId, + delete_table: &mut HashMap<TxLogId, Arc<LazyDelete<TxLogId>>>, + raw_log_store: &Arc<RawLogStore<D>>, + ) { + let raw_log_store = raw_log_store.clone(); + delete_table.insert( + log_id, + Arc::new(LazyDelete::new(log_id, move |log_id| { + raw_log_store.delete_log(*log_id).unwrap(); + })), + ); + } + + fn add_lazy_deletes_for_created_logs( + state: &mut State, + edit: &TxLogStoreEdit, + raw_log_store: &Arc<RawLogStore<D>>, + ) { + for &log_id in edit.iter_created_logs() { + if state.lazy_deletes.contains_key(&log_id) { + continue; + } + + Self::add_lazy_delete(log_id, &mut state.lazy_deletes, raw_log_store); + } + } + + fn do_lazy_deletion(state: &mut State, current_tx: &CurrentTx<'_>) { + let deleted_logs = current_tx.data_with(|edit: &TxLogStoreEdit| { + edit.iter_deleted_logs().cloned().collect::<Vec<_>>() + }); + + for deleted_log_id in deleted_logs { + let Some(lazy_delete) = state.lazy_deletes.remove(&deleted_log_id) else { + // Other concurrent TXs have deleted the same log + continue; + }; + LazyDelete::delete(&lazy_delete); + + // Also remove the cache by the way + state.log_caches.remove(&deleted_log_id); + } + } + + fn apply_log_caches(state: &mut State, current_tx: &mut CurrentTx<'_>) { + // Apply per-TX log cache + current_tx.data_mut_with(|open_cache_table: &mut OpenLogCache| { + if open_cache_table.open_table.is_empty() { + return; + } + + // TODO: May need performance improvement + let log_caches = &mut state.log_caches; + for (log_id, open_cache) in open_cache_table.open_table.iter_mut() { + let log_cache = log_caches.get_mut(log_id).unwrap(); + let mut cache_inner = log_cache.inner.lock(); + if cache_inner.lru_cache.is_empty() { + core::mem::swap(&mut cache_inner.lru_cache, &mut open_cache.lru_cache); + return; + } + + open_cache.lru_cache.iter().for_each(|(&pos, node)| { + cache_inner.lru_cache.put(pos, node.clone()); + }); + } + }); + } + + /// Creates a new, empty log in a bucket. + /// + /// On success, the returned `TxLog` is opened in the appendable mode. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. 
+ pub fn create_log(&self, bucket: &str) -> Result<Arc<TxLog<D>>> { + let raw_log = self.raw_log_store.create_log()?; + let log_id = raw_log.id(); + + let log_cache = Arc::new(CryptoLogCache::new(log_id, &self.tx_provider)); + self.state + .lock() + .log_caches + .insert(log_id, log_cache.clone()); + let key = Key::random(); + let crypto_log = CryptoLog::new(raw_log, key, log_cache); + + let mut current_tx = self.tx_provider.current(); + let bucket = bucket.to_string(); + let inner_log = Arc::new(TxLogInner { + log_id, + tx_id: current_tx.id(), + bucket: bucket.clone(), + crypto_log, + lazy_delete: None, + is_dirty: AtomicBool::new(false), + }); + + current_tx.data_mut_with(|store_edit: &mut TxLogStoreEdit| { + store_edit.create_log(log_id, bucket, key); + }); + + current_tx.data_mut_with(|open_log_table: &mut OpenLogTable<D>| { + let _ = open_log_table.open_table.insert(log_id, inner_log.clone()); + }); + + current_tx.data_mut_with(|open_cache_table: &mut OpenLogCache| { + let _ = open_cache_table + .open_table + .insert(log_id, CacheInner::new()); + }); + + Ok(Arc::new(TxLog { + inner_log, + tx_provider: self.tx_provider.clone(), + can_append: true, + })) + } + + /// Opens the log of a given ID. + /// + /// For any log at any time, there can be at most one TX that opens the log + /// in the appendable mode. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn open_log(&self, log_id: TxLogId, can_append: bool) -> Result<Arc<TxLog<D>>> { + let mut current_tx = self.tx_provider.current(); + let inner_log = self.open_inner_log(log_id, can_append, &mut current_tx)?; + let tx_log = TxLog::new(inner_log, self.tx_provider.clone(), can_append); + Ok(Arc::new(tx_log)) + } + + fn open_inner_log( + &self, + log_id: TxLogId, + can_append: bool, + current_tx: &mut CurrentTx<'_>, + ) -> Result<Arc<TxLogInner<D>>> { + // Fast path: the log has been opened in this TX + let opened_log_opt = current_tx.data_with(|open_log_table: &OpenLogTable<D>| { + open_log_table.open_table.get(&log_id).cloned() + }); + if let Some(inner_log) = opened_log_opt { + return Ok(inner_log); + } + + // Slow path: the first time a log is to be opened in a TX + let state = self.state.lock(); + // Must check lazy deletes first in case concurrent deletion + let lazy_delete = state + .lazy_deletes + .get(&log_id) + .ok_or(Error::with_msg(NotFound, "log has been deleted"))? + .clone(); + let log_entry = { + // The log must exist in state... + let log_entry: &TxLogEntry = state.persistent.find_log(log_id)?; + // ...and not be marked deleted by edit + let is_deleted = current_tx + .data_with(|store_edit: &TxLogStoreEdit| store_edit.is_log_deleted(log_id)); + if is_deleted { + return_errno_with_msg!(NotFound, "log has been marked deleted"); + } + log_entry + }; + + // Prepare cache before opening `CryptoLog` + current_tx.data_mut_with(|open_cache_table: &mut OpenLogCache| { + let _ = open_cache_table + .open_table + .insert(log_id, CacheInner::new()); + }); + + let bucket = log_entry.bucket.clone(); + let crypto_log = { + let raw_log = self.raw_log_store.open_log(log_id, can_append)?; + let key = log_entry.key; + let root_meta = log_entry.root_mht.clone(); + let cache = state.log_caches.get(&log_id).unwrap().clone(); + CryptoLog::open(raw_log, key, root_meta, cache)? 
+ }; + + let root_mht = crypto_log.root_meta().unwrap(); + let inner_log = Arc::new(TxLogInner { + log_id, + tx_id: current_tx.id(), + bucket, + crypto_log, + lazy_delete: Some(lazy_delete), + is_dirty: AtomicBool::new(false), + }); + + current_tx.data_mut_with(|open_log_table: &mut OpenLogTable<D>| { + open_log_table.open_table.insert(log_id, inner_log.clone()); + }); + + if can_append { + current_tx.data_mut_with(|store_edit: &mut TxLogStoreEdit| { + store_edit.append_log(log_id, root_mht); + }); + } + Ok(inner_log) + } + + /// Lists the IDs of all logs in a bucket. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn list_logs_in(&self, bucket_name: &str) -> Result<Vec<TxLogId>> { + let state = self.state.lock(); + let mut log_id_set = state.persistent.list_logs_in(bucket_name)?; + let current_tx = self.tx_provider.current(); + current_tx.data_with(|store_edit: &TxLogStoreEdit| { + for (&log_id, log_edit) in &store_edit.edit_table { + match log_edit { + TxLogEdit::Create(create) => { + if create.bucket == bucket_name { + log_id_set.insert(log_id); + } + } + TxLogEdit::Append(_) | TxLogEdit::Move(_) => {} + TxLogEdit::Delete => { + if log_id_set.contains(&log_id) { + log_id_set.remove(&log_id); + } + } + } + } + }); + let log_id_vec = log_id_set.into_iter().collect::<Vec<_>>(); + Ok(log_id_vec) + } + + /// Opens the log with the maximum ID in a bucket. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn open_log_in(&self, bucket: &str) -> Result<Arc<TxLog<D>>> { + let log_ids = self.list_logs_in(bucket)?; + let max_log_id = log_ids + .iter() + .max() + .ok_or(Error::with_msg(NotFound, "tx log not found"))?; + self.open_log(*max_log_id, false) + } + + /// Checks whether the log of a given log ID exists or not. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn contains_log(&self, log_id: TxLogId) -> bool { + let state = self.state.lock(); + let current_tx = self.tx_provider.current(); + self.do_contain_log(log_id, &state, ¤t_tx) + } + + fn do_contain_log(&self, log_id: TxLogId, state: &State, current_tx: &CurrentTx<'_>) -> bool { + if state.persistent.contains_log(log_id) { + current_tx.data_with(|store_edit: &TxLogStoreEdit| !store_edit.is_log_deleted(log_id)) + } else { + current_tx.data_with(|store_edit: &TxLogStoreEdit| store_edit.is_log_created(log_id)) + } + } + + /// Deletes the log of a given ID. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn delete_log(&self, log_id: TxLogId) -> Result<()> { + let mut current_tx = self.tx_provider.current(); + + current_tx.data_mut_with(|open_log_table: &mut OpenLogTable<D>| { + let _ = open_log_table.open_table.remove(&log_id); + }); + + current_tx.data_mut_with(|open_cache_table: &mut OpenLogCache| { + let _ = open_cache_table.open_table.remove(&log_id); + }); + + if !self.do_contain_log(log_id, &self.state.lock(), ¤t_tx) { + return_errno_with_msg!(NotFound, "target deleted log not found"); + } + + current_tx.data_mut_with(|store_edit: &mut TxLogStoreEdit| { + store_edit.delete_log(log_id); + }); + + // Do lazy delete in commit phase + Ok(()) + } + + /// Moves the log of a given ID from one bucket to another. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. 
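+    ///
+    /// # Examples
+    ///
+    /// An illustrative sketch (the function and bucket names are made up for
+    /// this doc):
+    ///
+    /// ```
+    /// fn promote_log<D: BlockSet + 'static>(
+    ///     store: &TxLogStore<D>,
+    ///     log_id: TxLogId,
+    /// ) -> Result<()> {
+    ///     let mut tx = store.new_tx();
+    ///     let res: Result<_> = tx.context(|| store.move_log(log_id, "cold", "hot"));
+    ///     if res.is_err() {
+    ///         tx.abort();
+    ///         return res;
+    ///     }
+    ///     tx.commit()
+    /// }
+    /// ```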
+ pub fn move_log(&self, log_id: TxLogId, from_bucket: &str, to_bucket: &str) -> Result<()> { + let mut current_tx = self.tx_provider.current(); + + current_tx.data_mut_with(|open_log_table: &mut OpenLogTable<D>| { + if let Some(log) = open_log_table.open_table.get(&log_id) { + debug_assert!(log.bucket == from_bucket && !log.is_dirty.load(Ordering::Acquire)) + } + }); + + current_tx.data_mut_with(|store_edit: &mut TxLogStoreEdit| { + store_edit.move_log(log_id, from_bucket, to_bucket); + }); + + Ok(()) + } + + /// Returns the root key. + pub fn root_key(&self) -> &Key { + &self.root_key + } + + /// Creates a new transaction. + pub fn new_tx(&self) -> CurrentTx<'_> { + self.tx_provider.new_tx() + } + + /// Returns the current transaction. + pub fn current_tx(&self) -> CurrentTx<'_> { + self.tx_provider.current() + } + + /// Syncs all the data managed by `TxLogStore` for persistence. + pub fn sync(&self) -> Result<()> { + self.raw_log_store.sync().unwrap(); + self.journal.lock().flush().unwrap(); + + self.raw_disk.flush() + } +} + +impl<D: BlockSet + 'static> Debug for TxLogStore<D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = self.state.lock(); + f.debug_struct("TxLogStore") + .field("persistent_log_table", &state.persistent.log_table) + .field("persistent_bucket_table", &state.persistent.bucket_table) + .field("raw_log_store", &self.raw_log_store) + .field("root_key", &self.root_key) + .finish() + } +} + +impl Superblock { + const SUPERBLOCK_SIZE: usize = core::mem::size_of::<Superblock>(); + + /// Returns the total number of blocks occupied by the `TxLogStore`. + pub fn total_nblocks(&self) -> usize { + self.journal_area_meta.total_nblocks() + self.chunk_area_nblocks + } + + /// Reads the `Superblock` on the disk with the given root key. + pub fn open<D: BlockSet>(disk: &D, root_key: &Key) -> Result<Self> { + let mut cipher = Buf::alloc(1)?; + disk.read(0, cipher.as_mut())?; + let mut plain = Buf::alloc(1)?; + Skcipher::new().decrypt( + cipher.as_slice(), + &Self::derive_skcipher_key(root_key), + &SkcipherIv::new_zeroed(), + plain.as_mut_slice(), + )?; + + let superblock = Superblock::from_bytes(&plain.as_slice()[..Self::SUPERBLOCK_SIZE]); + if superblock.magic != MAGIC_NUMBER { + Err(Error::with_msg(InvalidArgs, "open superblock failed")) + } else { + Ok(superblock) + } + } + + /// Persists the `Superblock` on the disk with the given root key. + pub fn persist<D: BlockSet>(&self, disk: &D, root_key: &Key) -> Result<()> { + let mut plain = Buf::alloc(1)?; + plain.as_mut_slice()[..Self::SUPERBLOCK_SIZE].copy_from_slice(self.as_bytes()); + let mut cipher = Buf::alloc(1)?; + Skcipher::new().encrypt( + plain.as_slice(), + &Self::derive_skcipher_key(root_key), + &SkcipherIv::new_zeroed(), + cipher.as_mut_slice(), + )?; + disk.write(0, cipher.as_ref()) + } + + fn derive_skcipher_key(root_key: &Key) -> SkcipherKey { + SkcipherKey::from_bytes(root_key.as_bytes()) + } +} + +/// A transactional log. +#[derive(Clone)] +pub struct TxLog<D> { + inner_log: Arc<TxLogInner<D>>, + tx_provider: Arc<TxProvider>, + can_append: bool, +} + +/// Inner structures of a transactional log. 
+struct TxLogInner<D> { + log_id: TxLogId, + tx_id: TxId, + bucket: BucketName, + crypto_log: CryptoLog<RawLog<D>>, + lazy_delete: Option<Arc<LazyDelete<TxLogId>>>, + is_dirty: AtomicBool, +} + +impl<D: BlockSet + 'static> TxLog<D> { + fn new(inner_log: Arc<TxLogInner<D>>, tx_provider: Arc<TxProvider>, can_append: bool) -> Self { + Self { + inner_log, + tx_provider, + can_append, + } + } + + /// Returns the log ID. + pub fn id(&self) -> TxLogId { + self.inner_log.log_id + } + + /// Returns the TX ID. + pub fn tx_id(&self) -> TxId { + self.inner_log.tx_id + } + + /// Returns the bucket that this log belongs to. + pub fn bucket(&self) -> &str { + &self.inner_log.bucket + } + + /// Returns whether the log is opened in the appendable mode. + pub fn can_append(&self) -> bool { + self.can_append + } + + /// Reads one or multiple data blocks at a specified position. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn read(&self, pos: BlockId, buf: BufMut) -> Result<()> { + debug_assert_eq!(self.tx_id(), self.tx_provider.current().id()); + + self.inner_log.crypto_log.read(pos, buf) + } + + /// Appends one or multiple data blocks at the end. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn append(&self, buf: BufRef) -> Result<()> { + debug_assert_eq!(self.tx_id(), self.tx_provider.current().id()); + + if !self.can_append { + return_errno_with_msg!(PermissionDenied, "tx log not in append mode"); + } + + self.inner_log.is_dirty.store(true, Ordering::Release); + self.inner_log.crypto_log.append(buf) + } + + /// Returns the length of the log in unit of block. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn nblocks(&self) -> usize { + debug_assert_eq!(self.tx_id(), self.tx_provider.current().id()); + + self.inner_log.crypto_log.nblocks() + } +} + +impl<D: BlockSet + 'static> Debug for TxLog<D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TxLog") + .field("id", &self.inner_log.log_id) + .field("bucket", &self.inner_log.bucket) + .field("crypto_log", &self.inner_log.crypto_log) + .field("is_dirty", &self.inner_log.is_dirty.load(Ordering::Acquire)) + .finish() + } +} + +/// Node cache for `CryptoLog` in a transactional log. 
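+///
+/// Lookups first consult the per-TX cache of the current TX (nodes touched by
+/// the logs opened in that TX) and then fall back to this shared cache. New
+/// nodes are put into the per-TX cache and are merged into the shared cache
+/// when the TX commits.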
+pub struct CryptoLogCache { + inner: Mutex<CacheInner>, + log_id: TxLogId, + tx_provider: Arc<TxProvider>, +} + +pub(super) struct CacheInner { + pub lru_cache: LruCache<BlockId, Arc<dyn Any + Send + Sync>>, +} + +impl CryptoLogCache { + fn new(log_id: TxLogId, tx_provider: &Arc<TxProvider>) -> Self { + Self { + inner: Mutex::new(CacheInner::new()), + log_id, + tx_provider: tx_provider.clone(), + } + } +} + +impl NodeCache for CryptoLogCache { + fn get(&self, pos: BlockId) -> Option<Arc<dyn Any + Send + Sync>> { + let mut current = self.tx_provider.current(); + + let value_opt = current.data_mut_with(|open_cache_table: &mut OpenLogCache| { + open_cache_table + .open_table + .get_mut(&self.log_id) + .and_then(|open_cache| open_cache.lru_cache.get(&pos).cloned()) + }); + if value_opt.is_some() { + return value_opt; + } + + let mut inner = self.inner.lock(); + inner.lru_cache.get(&pos).cloned() + } + + fn put( + &self, + pos: BlockId, + value: Arc<dyn Any + Send + Sync>, + ) -> Option<Arc<dyn Any + Send + Sync>> { + let mut current = self.tx_provider.current(); + + current.data_mut_with(|open_cache_table: &mut OpenLogCache| { + debug_assert!(open_cache_table.open_table.contains_key(&self.log_id)); + let open_cache = open_cache_table.open_table.get_mut(&self.log_id).unwrap(); + open_cache.lru_cache.put(pos, value) + }) + } +} + +impl CacheInner { + pub fn new() -> Self { + // TODO: Give the cache a bound then test cache hit rate + Self { + lru_cache: LruCache::unbounded(), + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Persistent State +//////////////////////////////////////////////////////////////////////////////// + +/// The volatile and persistent state of a `TxLogStore`. +struct State { + persistent: TxLogStoreState, + lazy_deletes: HashMap<TxLogId, Arc<LazyDelete<TxLogId>>>, + log_caches: HashMap<TxLogId, Arc<CryptoLogCache>>, +} + +/// The persistent state of a `TxLogStore`. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TxLogStoreState { + log_table: HashMap<TxLogId, TxLogEntry>, + bucket_table: HashMap<BucketName, Bucket>, +} + +/// A log entry implies the persistent state of the tx log. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TxLogEntry { + pub bucket: BucketName, + pub key: Key, + pub root_mht: RootMhtMeta, +} + +/// A bucket contains a set of logs which have the same name. 
+#[derive(Clone, Debug, Serialize, Deserialize)] +struct Bucket { + log_ids: HashSet<TxLogId>, +} + +impl State { + pub fn new( + persistent: TxLogStoreState, + lazy_deletes: HashMap<TxLogId, Arc<LazyDelete<TxLogId>>>, + log_caches: HashMap<TxLogId, Arc<CryptoLogCache>>, + ) -> Self { + Self { + persistent, + lazy_deletes, + log_caches, + } + } + + pub fn apply(&mut self, edit: &TxLogStoreEdit) { + edit.apply_to(&mut self.persistent); + } +} + +impl TxLogStoreState { + pub fn new() -> Self { + Self { + log_table: HashMap::new(), + bucket_table: HashMap::new(), + } + } + + pub fn create_log( + &mut self, + new_log_id: TxLogId, + bucket: BucketName, + key: Key, + root_mht: RootMhtMeta, + ) { + let already_exists = self.log_table.insert( + new_log_id, + TxLogEntry { + bucket: bucket.clone(), + key, + root_mht, + }, + ); + debug_assert!(already_exists.is_none()); + + match self.bucket_table.get_mut(&bucket) { + Some(bucket) => { + bucket.log_ids.insert(new_log_id); + } + None => { + self.bucket_table.insert( + bucket, + Bucket { + log_ids: HashSet::from([new_log_id]), + }, + ); + } + } + } + + pub fn find_log(&self, log_id: TxLogId) -> Result<&TxLogEntry> { + self.log_table + .get(&log_id) + .ok_or(Error::with_msg(NotFound, "log entry not found")) + } + + pub fn list_logs_in(&self, bucket: &str) -> Result<HashSet<TxLogId>> { + let bucket = self + .bucket_table + .get(bucket) + .ok_or(Error::with_msg(NotFound, "bucket not found"))?; + Ok(bucket.log_ids.clone()) + } + + pub fn list_all_logs(&self) -> impl Iterator<Item = TxLogId> + '_ { + self.log_table.iter().map(|(id, _)| *id) + } + + pub fn contains_log(&self, log_id: TxLogId) -> bool { + self.log_table.contains_key(&log_id) + } + + pub fn append_log(&mut self, log_id: TxLogId, root_mht: RootMhtMeta) { + let entry = self.log_table.get_mut(&log_id).unwrap(); + entry.root_mht = root_mht; + } + + pub fn delete_log(&mut self, log_id: TxLogId) { + // Do not check the result because concurrent TXs + // may decide to delete the same logs + if let Some(entry) = self.log_table.remove(&log_id) { + self.bucket_table + .get_mut(&entry.bucket) + .map(|bucket| bucket.log_ids.remove(&log_id)); + } + } + + pub fn move_log(&mut self, log_id: TxLogId, from_bucket: &str, to_bucket: &str) { + let entry = self.log_table.get_mut(&log_id).unwrap(); + debug_assert_eq!(entry.bucket, from_bucket); + let to_bucket = to_bucket.to_string(); + entry.bucket = to_bucket.clone(); + + self.bucket_table + .get_mut(from_bucket) + .map(|bucket| bucket.log_ids.remove(&log_id)) + .expect("`from_bucket` '{from_bucket:?}' must exist"); + + if let Some(bucket) = self.bucket_table.get_mut(&to_bucket) { + bucket.log_ids.insert(log_id); + } else { + let _ = self.bucket_table.insert( + to_bucket, + Bucket { + log_ids: HashSet::from([log_id]), + }, + ); + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Persistent Edit +//////////////////////////////////////////////////////////////////////////////// + +/// A persistent edit to the state of `TxLogStore`. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TxLogStoreEdit { + edit_table: HashMap<TxLogId, TxLogEdit>, +} + +/// Used for per-TX data, track open logs in memory +pub(super) struct OpenLogTable<D> { + open_table: HashMap<TxLogId, Arc<TxLogInner<D>>>, +} + +/// Used for per-TX data, track open log caches in memory +pub(super) struct OpenLogCache { + open_table: HashMap<TxLogId, CacheInner>, +} + +/// The basic unit of a persistent edit to the state of `TxLogStore`. 
+#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) enum TxLogEdit { + Create(TxLogCreate), + Append(TxLogAppend), + Delete, + Move(TxLogMove), +} + +/// An edit that implies a log being created. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct TxLogCreate { + bucket: BucketName, + key: Key, + root_mht: Option<RootMhtMeta>, +} + +/// An edit that implies an existing log being appended. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct TxLogAppend { + root_mht: RootMhtMeta, +} + +/// An edit that implies a log being moved from one bucket to another. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct TxLogMove { + from: BucketName, + to: BucketName, +} + +impl TxLogStoreEdit { + pub fn new() -> Self { + Self { + edit_table: HashMap::new(), + } + } + + pub fn create_log(&mut self, log_id: TxLogId, bucket: BucketName, key: Key) { + let already_created = self.edit_table.insert( + log_id, + TxLogEdit::Create(TxLogCreate { + bucket, + key, + root_mht: None, + }), + ); + debug_assert!(already_created.is_none()); + } + + pub fn append_log(&mut self, log_id: TxLogId, root_mht: RootMhtMeta) { + let already_existed = self + .edit_table + .insert(log_id, TxLogEdit::Append(TxLogAppend { root_mht })); + debug_assert!(already_existed.is_none()); + } + + pub fn delete_log(&mut self, log_id: TxLogId) { + match self.edit_table.get_mut(&log_id) { + None => { + let _ = self.edit_table.insert(log_id, TxLogEdit::Delete); + } + Some(TxLogEdit::Create(_)) | Some(TxLogEdit::Move(_)) => { + let _ = self.edit_table.insert(log_id, TxLogEdit::Delete); + } + Some(TxLogEdit::Append(_)) => { + panic!( + "append edit is added at very late stage, after which logs won't get deleted" + ); + } + Some(TxLogEdit::Delete) => { + panic!("can't delete a deleted log"); + } + } + } + + pub fn move_log(&mut self, log_id: TxLogId, from_bucket: &str, to_bucket: &str) { + let move_edit = TxLogEdit::Move(TxLogMove { + from: from_bucket.to_string(), + to: to_bucket.to_string(), + }); + let edit_existed = self.edit_table.insert(log_id, move_edit); + debug_assert!(edit_existed.is_none()); + } + + pub fn is_log_created(&self, log_id: TxLogId) -> bool { + match self.edit_table.get(&log_id) { + Some(TxLogEdit::Create(_)) | Some(TxLogEdit::Append(_)) | Some(TxLogEdit::Move(_)) => { + true + } + None | Some(TxLogEdit::Delete) => false, + } + } + + pub fn is_log_deleted(&self, log_id: TxLogId) -> bool { + matches!(self.edit_table.get(&log_id), Some(TxLogEdit::Delete)) + } + + pub fn iter_created_logs(&self) -> impl Iterator<Item = &TxLogId> { + self.edit_table + .iter() + .filter(|(_, edit)| matches!(edit, TxLogEdit::Create(_))) + .map(|(id, _)| id) + } + + pub fn iter_deleted_logs(&self) -> impl Iterator<Item = &TxLogId> { + self.edit_table + .iter() + .filter(|(_, edit)| matches!(edit, TxLogEdit::Delete)) + .map(|(id, _)| id) + } + + pub fn update_log_meta(&mut self, meta: (TxLogId, RootMhtMeta)) { + // For newly-created logs and existing logs + // that are appended, update `RootMhtMeta` + match self.edit_table.get_mut(&meta.0) { + None | Some(TxLogEdit::Delete) | Some(TxLogEdit::Move(_)) => { + unreachable!(); + } + Some(TxLogEdit::Create(create)) => { + let _ = create.root_mht.insert(meta.1); + } + Some(TxLogEdit::Append(append)) => { + append.root_mht = meta.1; + } + } + } + + pub fn is_empty(&self) -> bool { + self.edit_table.is_empty() + } +} + +impl Edit<TxLogStoreState> for TxLogStoreEdit { + fn apply_to(&self, state: &mut TxLogStoreState) { + for (&log_id, log_edit) in 
&self.edit_table { + match log_edit { + TxLogEdit::Create(create_edit) => { + let TxLogCreate { + bucket, + key, + root_mht, + .. + } = create_edit; + state.create_log( + log_id, + bucket.clone(), + *key, + root_mht.clone().expect("root mht not found in created log"), + ); + } + TxLogEdit::Append(append_edit) => { + let TxLogAppend { root_mht, .. } = append_edit; + state.append_log(log_id, root_mht.clone()); + } + TxLogEdit::Delete => { + state.delete_log(log_id); + } + TxLogEdit::Move(move_edit) => { + state.move_log(log_id, &move_edit.from, &move_edit.to) + } + } + } + } +} + +impl TxData for TxLogStoreEdit {} + +impl<D> OpenLogTable<D> { + pub fn new() -> Self { + Self { + open_table: HashMap::new(), + } + } +} + +impl OpenLogCache { + pub fn new() -> Self { + Self { + open_table: HashMap::new(), + } + } +} + +impl<D: Any + Send + Sync + 'static> TxData for OpenLogTable<D> {} +impl TxData for OpenLogCache {} + +//////////////////////////////////////////////////////////////////////////////// +// Journaling +//////////////////////////////////////////////////////////////////////////////// + +mod journaling { + use super::*; + use crate::layers::edit::DefaultCompactPolicy; + + pub type Journal<D> = EditJournal<AllEdit, AllState, D, JournalCompactPolicy>; + pub type JournalCompactPolicy = DefaultCompactPolicy; + + #[derive(Clone, Debug, Serialize, Deserialize)] + pub struct AllState { + pub chunk_alloc: ChunkAllocState, + pub raw_log_store: RawLogStoreState, + pub tx_log_store: TxLogStoreState, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct AllEdit { + pub chunk_edit: ChunkAllocEdit, + pub raw_log_edit: RawLogStoreEdit, + pub tx_log_edit: TxLogStoreEdit, + } + + impl Edit<AllState> for AllEdit { + fn apply_to(&self, state: &mut AllState) { + if !self.tx_log_edit.is_empty() { + self.tx_log_edit.apply_to(&mut state.tx_log_store); + } + if !self.raw_log_edit.is_empty() { + self.raw_log_edit.apply_to(&mut state.raw_log_store); + } + if !self.chunk_edit.is_empty() { + self.chunk_edit.apply_to(&mut state.chunk_alloc); + } + } + } + + impl AllEdit { + pub fn from_chunk_edit(chunk_edit: &ChunkAllocEdit) -> Self { + Self { + chunk_edit: chunk_edit.clone(), + raw_log_edit: RawLogStoreEdit::new(), + tx_log_edit: TxLogStoreEdit::new(), + } + } + + pub fn from_raw_log_edit(raw_log_edit: &RawLogStoreEdit) -> Self { + Self { + chunk_edit: ChunkAllocEdit::new(), + raw_log_edit: raw_log_edit.clone(), + tx_log_edit: TxLogStoreEdit::new(), + } + } + + pub fn from_tx_log_edit(tx_log_edit: &TxLogStoreEdit) -> Self { + Self { + chunk_edit: ChunkAllocEdit::new(), + raw_log_edit: RawLogStoreEdit::new(), + tx_log_edit: tx_log_edit.clone(), + } + } + } +} + +#[cfg(test)] +mod tests { + use std::thread::{self, JoinHandle}; + + use super::*; + use crate::layers::bio::{Buf, MemDisk}; + + #[test] + fn tx_log_store_fns() -> Result<()> { + let nblocks = 4 * CHUNK_NBLOCKS; + let mem_disk = MemDisk::create(nblocks)?; + let disk = mem_disk.clone(); + let root_key = Key::random(); + let tx_log_store = TxLogStore::format(mem_disk, root_key.clone())?; + let bucket = "TEST"; + let content = 5_u8; + + // TX 1: create a new log and append contents (committed) + let mut tx = tx_log_store.new_tx(); + let res: Result<TxLogId> = tx.context(|| { + let new_log = tx_log_store.create_log(bucket)?; + let log_id = new_log.id(); + assert_eq!(log_id, 0); + assert_eq!(new_log.tx_id(), tx_log_store.current_tx().id()); + assert_eq!(new_log.can_append(), true); + let mut buf = Buf::alloc(1)?; + buf.as_mut_slice().fill(content); + 
new_log.append(buf.as_ref())?; + + assert_eq!(new_log.nblocks(), 1); + assert_eq!(new_log.bucket(), bucket); + Ok(log_id) + }); + let log_id = res?; + tx.commit()?; + + // TX 2: open the log then read (committed) + let mut tx = tx_log_store.new_tx(); + let res: Result<_> = tx.context(|| { + let log_list = tx_log_store.list_logs_in(bucket)?; + assert_eq!(log_list, vec![log_id]); + assert_eq!(tx_log_store.contains_log(log_id), true); + assert_eq!(tx_log_store.contains_log(1), false); + + let log = tx_log_store.open_log(0, false)?; + assert_eq!(log.id(), log_id); + assert_eq!(log.tx_id(), tx_log_store.current_tx().id()); + let mut buf = Buf::alloc(1)?; + log.read(0, buf.as_mut())?; + assert_eq!(buf.as_slice()[0], content); + + let log = tx_log_store.open_log_in(bucket)?; + assert_eq!(log.id(), log_id); + log.read(0 as BlockId, buf.as_mut())?; + assert_eq!(buf.as_slice()[0], content); + Ok(()) + }); + res?; + tx.commit()?; + + // Recover the tx log store + tx_log_store.sync()?; + drop(tx_log_store); + let recovered_store = TxLogStore::recover(disk, root_key)?; + + // TX 3: create a new log from recovered_store (aborted) + let tx_log_store = recovered_store.clone(); + let handler = thread::spawn(move || -> Result<TxLogId> { + let mut tx = tx_log_store.new_tx(); + let res: Result<_> = tx.context(|| { + let new_log = tx_log_store.create_log(bucket)?; + assert_eq!(tx_log_store.list_logs_in(bucket)?.len(), 2); + Ok(new_log.id()) + }); + tx.abort(); + res + }); + let new_log_id = handler.join().unwrap()?; + + recovered_store + .state + .lock() + .persistent + .find_log(new_log_id) + .expect_err("log not found"); + + Ok(()) + } + + #[test] + fn tx_log_deletion() -> Result<()> { + let tx_log_store = TxLogStore::format(MemDisk::create(4 * CHUNK_NBLOCKS)?, Key::random())?; + + let mut tx = tx_log_store.new_tx(); + let content = 5_u8; + let res: Result<_> = tx.context(|| { + let new_log = tx_log_store.create_log("TEST")?; + let mut buf = Buf::alloc(1)?; + buf.as_mut_slice().fill(content); + new_log.append(buf.as_ref())?; + Ok(new_log.id()) + }); + let log_id = res?; + tx.commit()?; + + let handlers = (0..16) + .map(|_| { + let tx_log_store = tx_log_store.clone(); + thread::spawn(move || -> Result<()> { + let mut tx = tx_log_store.new_tx(); + println!( + "TX[{:?}] executed on thread[{:?}]", + tx.id(), + crate::os::CurrentThread::id() + ); + let _ = tx.context(|| { + let log = tx_log_store.open_log(log_id, false)?; + assert_eq!(log.id(), log_id); + assert_eq!(log.tx_id(), tx_log_store.current_tx().id()); + let mut buf = Buf::alloc(1)?; + log.read(0 as BlockId, buf.as_mut())?; + assert_eq!(buf.as_slice(), &[content; BLOCK_SIZE]); + tx_log_store.delete_log(log_id) + }); + tx.commit() + }) + }) + .collect::<Vec<JoinHandle<Result<()>>>>(); + for handler in handlers { + handler.join().unwrap()?; + } + + let mut tx = tx_log_store.new_tx(); + let _ = tx.context(|| { + let res = tx_log_store.open_log(log_id, false).map(|_| ()); + res.expect_err("result must be NotFound"); + }); + tx.commit() + } +} diff --git a/kernel/comps/mlsdisk/src/layers/4-lsm/compaction.rs b/kernel/comps/mlsdisk/src/layers/4-lsm/compaction.rs new file mode 100644 index 00000000..8e9073ea --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/4-lsm/compaction.rs @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Compaction in `TxLsmTree`. 
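+//!
+//! A rough sketch of how two sorted runs of records are merged during
+//! compaction (upper-level records take precedence on equal keys; the keys
+//! and values below are made up purely for illustration):
+//!
+//! ```text
+//! upper:   (1, a') (3, c')
+//! lower:   (1, a ) (2, b ) (4, d )
+//! merged:  (1, a') (2, b ) (3, c') (4, d )
+//! ```
+//!
+//! When both values of an equal key are synced, the losing value (here `a`)
+//! is reported as dropped via the TX event listener. The merged output is cut
+//! into new SSTs of at most `SSTABLE_CAPACITY` records each.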
+use core::marker::PhantomData; + +use super::{ + mem_table::ValueEx, sstable::SSTable, tx_lsm_tree::SSTABLE_CAPACITY, LsmLevel, RecordKey, + RecordValue, SyncId, TxEventListener, +}; +use crate::{ + layers::{bio::BlockSet, log::TxLogStore}, + os::{JoinHandle, Mutex}, + prelude::*, +}; + +/// A `Compactor` is currently used for asynchronous compaction +/// and specific compaction algorithm of `TxLsmTree`. +pub(super) struct Compactor<K, V> { + handle: Mutex<Option<JoinHandle<Result<()>>>>, + phantom: PhantomData<(K, V)>, +} + +impl<K: RecordKey<K>, V: RecordValue> Compactor<K, V> { + /// Create a new `Compactor` instance. + pub fn new() -> Self { + Self { + handle: Mutex::new(None), + phantom: PhantomData, + } + } + + /// Record current compaction thread handle. + pub fn record_handle(&self, handle: JoinHandle<Result<()>>) { + let mut handle_opt = self.handle.lock(); + assert!(handle_opt.is_none()); + let _ = handle_opt.insert(handle); + } + + /// Wait until the compaction is finished. + pub fn wait_compaction(&self) -> Result<()> { + if let Some(handle) = self.handle.lock().take() { + handle.join().unwrap() + } else { + Ok(()) + } + } + + /// Core function for compacting overlapped records and building new SSTs. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn compact_records_and_build_ssts<D: BlockSet + 'static>( + upper_records: impl Iterator<Item = (K, ValueEx<V>)>, + lower_records: impl Iterator<Item = (K, ValueEx<V>)>, + tx_log_store: &Arc<TxLogStore<D>>, + event_listener: &Arc<dyn TxEventListener<K, V>>, + to_level: LsmLevel, + sync_id: SyncId, + ) -> Result<Vec<SSTable<K, V>>> { + let mut created_ssts = Vec::new(); + let mut upper_iter = upper_records.peekable(); + let mut lower_iter = lower_records.peekable(); + + loop { + let mut record_cnt = 0; + let records_iter = core::iter::from_fn(|| { + if record_cnt == SSTABLE_CAPACITY { + return None; + } + + record_cnt += 1; + match (upper_iter.peek(), lower_iter.peek()) { + (Some((upper_k, _)), Some((lower_k, _))) => match upper_k.cmp(lower_k) { + core::cmp::Ordering::Less => upper_iter.next(), + core::cmp::Ordering::Greater => lower_iter.next(), + core::cmp::Ordering::Equal => { + let (k, new_v_ex) = upper_iter.next().unwrap(); + let (_, old_v_ex) = lower_iter.next().unwrap(); + let (next_v_ex, dropped_v_opt) = + Self::compact_value_ex(new_v_ex, old_v_ex); + + if let Some(dropped_v) = dropped_v_opt { + event_listener.on_drop_record(&(k, dropped_v)).unwrap(); + } + Some((k, next_v_ex)) + } + }, + (Some(_), None) => upper_iter.next(), + (None, Some(_)) => lower_iter.next(), + (None, None) => None, + } + }); + let mut records_iter = records_iter.peekable(); + if records_iter.peek().is_none() { + break; + } + + let new_log = tx_log_store.create_log(to_level.bucket())?; + let new_sst = SSTable::build(records_iter, sync_id, &new_log, None)?; + + created_ssts.push(new_sst); + } + + Ok(created_ssts) + } + + /// Compact two `ValueEx<V>`s with the same key, returning + /// the compacted value and the dropped value if any. 
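+    ///
+    /// For example (a summary of the match arms below):
+    /// - `Synced(new)` over `Synced(old)` keeps `Synced(new)` and drops `old`;
+    /// - `Unsynced(new)` over `Synced(old)` merges into `SyncedAndUnsynced(old, new)`;
+    /// - `Unsynced(new)` over `Unsynced(old)` keeps `Unsynced(new)` and drops `old`.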
+ fn compact_value_ex(new: ValueEx<V>, old: ValueEx<V>) -> (ValueEx<V>, Option<V>) { + match (new, old) { + (ValueEx::Synced(new_v), ValueEx::Synced(old_v)) => { + (ValueEx::Synced(new_v), Some(old_v)) + } + (ValueEx::Unsynced(new_v), ValueEx::Synced(old_v)) => { + (ValueEx::SyncedAndUnsynced(old_v, new_v), None) + } + (ValueEx::Unsynced(new_v), ValueEx::Unsynced(old_v)) => { + (ValueEx::Unsynced(new_v), Some(old_v)) + } + (ValueEx::Unsynced(new_v), ValueEx::SyncedAndUnsynced(old_sv, old_usv)) => { + (ValueEx::SyncedAndUnsynced(old_sv, new_v), Some(old_usv)) + } + (ValueEx::SyncedAndUnsynced(new_sv, new_usv), ValueEx::Synced(old_sv)) => { + (ValueEx::SyncedAndUnsynced(new_sv, new_usv), Some(old_sv)) + } + _ => { + unreachable!() + } + } + } +} diff --git a/kernel/comps/mlsdisk/src/layers/4-lsm/mem_table.rs b/kernel/comps/mlsdisk/src/layers/4-lsm/mem_table.rs new file mode 100644 index 00000000..b5f77957 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/4-lsm/mem_table.rs @@ -0,0 +1,402 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! MemTable. +use core::ops::Range; + +use super::{tx_lsm_tree::OnDropRecodeFn, AsKV, RangeQueryCtx, RecordKey, RecordValue, SyncId}; +use crate::{ + os::{BTreeMap, Condvar, CvarMutex, Mutex, RwLock, RwLockReadGuard}, + prelude::*, +}; + +/// Manager for an mutable `MemTable` and an immutable `MemTable` +/// in a `TxLsmTree`. +pub(super) struct MemTableManager<K: RecordKey<K>, V> { + mutable: Mutex<MemTable<K, V>>, + immutable: RwLock<MemTable<K, V>>, // Read-only most of the time + cvar: Condvar, + is_full: CvarMutex<bool>, +} + +/// MemTable for LSM-Tree. +/// +/// Manages organized key-value records in memory with a capacity. +/// Each `MemTable` is sync-aware (tagged with current sync ID). +/// Both synced and unsynced records can co-exist. +/// Also supports user-defined callback when a record is dropped. +pub(super) struct MemTable<K: RecordKey<K>, V> { + table: BTreeMap<K, ValueEx<V>>, + size: usize, + cap: usize, + sync_id: SyncId, + unsynced_range: Option<Range<K>>, + on_drop_record: Option<Arc<OnDropRecodeFn<K, V>>>, +} + +/// An extended value which is sync-aware. +/// At most one unsynced and one synced records can coexist at the same time. +#[derive(Clone, Debug)] +pub(super) enum ValueEx<V> { + Synced(V), + Unsynced(V), + SyncedAndUnsynced(V, V), +} + +impl<K: RecordKey<K>, V: RecordValue> MemTableManager<K, V> { + /// Creates a new `MemTableManager` given the current master sync ID, + /// the capacity and the callback when dropping records. + pub fn new( + sync_id: SyncId, + capacity: usize, + on_drop_record_in_memtable: Option<Arc<OnDropRecodeFn<K, V>>>, + ) -> Self { + let mutable = Mutex::new(MemTable::new( + capacity, + sync_id, + on_drop_record_in_memtable.clone(), + )); + let immutable = RwLock::new(MemTable::new(capacity, sync_id, on_drop_record_in_memtable)); + + Self { + mutable, + immutable, + cvar: Condvar::new(), + is_full: CvarMutex::new(false), + } + } + + /// Gets the target value of the given key from the `MemTable`s. + pub fn get(&self, key: &K) -> Option<V> { + if let Some(value) = self.mutable.lock().get(key) { + return Some(*value); + } + + if let Some(value) = self.immutable.read().get(key) { + return Some(*value); + } + + None + } + + /// Gets the range of values from the `MemTable`s. 
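+    ///
+    /// Returns `true` if the range query is fully completed after consulting
+    /// the mutable `MemTable` first and then the immutable one.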
+ pub fn get_range(&self, range_query_ctx: &mut RangeQueryCtx<K, V>) -> bool { + let is_completed = self.mutable.lock().get_range(range_query_ctx); + if is_completed { + return is_completed; + } + + self.immutable.read().get_range(range_query_ctx) + } + + /// Puts a key-value pair into the mutable `MemTable`, and + /// return whether the mutable `MemTable` is full. + pub fn put(&self, key: K, value: V) -> bool { + let mut is_full = self.is_full.lock().unwrap(); + while *is_full { + is_full = self.cvar.wait(is_full).unwrap(); + } + debug_assert!(!*is_full); + + let mut mutable = self.mutable.lock(); + let _ = mutable.put(key, value); + + if mutable.at_capacity() { + *is_full = true; + } + *is_full + } + + /// Sync the mutable `MemTable` with the given sync ID. + pub fn sync(&self, sync_id: SyncId) { + self.mutable.lock().sync(sync_id) + } + + /// Switch two `MemTable`s. Should only be called in a situation that + /// the mutable `MemTable` becomes full and the immutable `MemTable` is + /// ready to be cleared. + pub fn switch(&self) -> Result<()> { + let mut is_full = self.is_full.lock().unwrap(); + debug_assert!(*is_full); + + let mut mutable = self.mutable.lock(); + let sync_id = mutable.sync_id(); + + let mut immutable = self.immutable.write(); + immutable.clear(); + + core::mem::swap(&mut *mutable, &mut *immutable); + + debug_assert!(mutable.is_empty() && immutable.at_capacity()); + // Update sync ID of the switched mutable `MemTable` + mutable.sync(sync_id); + + *is_full = false; + self.cvar.notify_all(); + Ok(()) + } + + /// Gets the immutable `MemTable` instance (read-only). + pub fn immutable_memtable(&self) -> RwLockReadGuard<MemTable<K, V>> { + self.immutable.read() + } +} + +impl<K: RecordKey<K>, V: RecordValue> MemTable<K, V> { + /// Creates a new `MemTable`, given the capacity, the current sync ID, + /// and the callback of dropping record. + pub fn new( + cap: usize, + sync_id: SyncId, + on_drop_record: Option<Arc<OnDropRecodeFn<K, V>>>, + ) -> Self { + Self { + table: BTreeMap::new(), + size: 0, + cap, + sync_id, + unsynced_range: None, + on_drop_record, + } + } + + /// Gets the target value given the key. + pub fn get(&self, key: &K) -> Option<&V> { + let value_ex = self.table.get(key)?; + Some(value_ex.get()) + } + + /// Range query, returns whether the request is completed. + pub fn get_range(&self, range_query_ctx: &mut RangeQueryCtx<K, V>) -> bool { + debug_assert!(!range_query_ctx.is_completed()); + let target_range = range_query_ctx.range_uncompleted().unwrap(); + + for (k, v_ex) in self.table.range(target_range) { + range_query_ctx.complete(*k, *v_ex.get()); + } + + range_query_ctx.is_completed() + } + + /// Puts a new K-V record to the table, drop the old one. + pub fn put(&mut self, key: K, value: V) -> Option<V> { + let dropped_value = if let Some(value_ex) = self.table.get_mut(&key) { + if let Some(dropped) = value_ex.put(value) { + let _ = self + .on_drop_record + .as_ref() + .map(|on_drop_record| on_drop_record(&(key, dropped))); + Some(dropped) + } else { + self.size += 1; + None + } + } else { + let _ = self.table.insert(key, ValueEx::new(value)); + self.size += 1; + None + }; + + if let Some(range) = self.unsynced_range.as_mut() { + if range.is_empty() { + *range = key..key + 1; + } else { + let start = key.min(range.start); + let end = (key + 1).max(range.end); + *range = start..end; + } + } + dropped_value + } + + /// Sync the table, update the sync ID, drop the replaced one. 
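+    ///
+    /// Only entries holding unsynced values are visited; if an unsynced key
+    /// range is being tracked, the scan is narrowed to that range.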
+ pub fn sync(&mut self, sync_id: SyncId) { + debug_assert!(self.sync_id <= sync_id); + if self.sync_id == sync_id { + return; + } + + let filter_unsynced: Box<dyn Iterator<Item = _>> = if let Some(range) = &self.unsynced_range + { + Box::new( + self.table + .range_mut(range.clone()) + .filter(|(_, v_ex)| v_ex.contains_unsynced()), + ) + } else { + Box::new( + self.table + .iter_mut() + .filter(|(_, v_ex)| v_ex.contains_unsynced()), + ) + }; + for (k, v_ex) in filter_unsynced { + if let Some(dropped) = v_ex.sync() { + let _ = self + .on_drop_record + .as_ref() + .map(|on_drop_record| on_drop_record(&(*k, dropped))); + self.size -= 1; + } + } + + self.sync_id = sync_id; + // Insert an empty range upon first sync + let _ = self + .unsynced_range + .get_or_insert_with(|| K::new_uninit()..K::new_uninit()); + } + + /// Return the sync ID of this table. + pub fn sync_id(&self) -> SyncId { + self.sync_id + } + + /// Return an iterator over the table. + pub fn iter(&self) -> impl Iterator<Item = (&K, &ValueEx<V>)> { + self.table.iter() + } + + /// Return the number of records in the table. + pub fn size(&self) -> usize { + self.size + } + + /// Return whether the table is empty. + pub fn is_empty(&self) -> bool { + self.size == 0 + } + + /// Return whether the table is full. + pub fn at_capacity(&self) -> bool { + self.size >= self.cap + } + + /// Clear all records from the table. + pub fn clear(&mut self) { + self.table.clear(); + self.size = 0; + self.unsynced_range = None; + } +} + +impl<V: RecordValue> ValueEx<V> { + /// Creates a new unsynced value. + pub fn new(value: V) -> Self { + Self::Unsynced(value) + } + + /// Gets the most recent value. + pub fn get(&self) -> &V { + match self { + Self::Synced(v) => v, + Self::Unsynced(v) => v, + Self::SyncedAndUnsynced(_, v) => v, + } + } + + /// Puts a new value, return the replaced value if any. + fn put(&mut self, value: V) -> Option<V> { + let existed = core::mem::take(self); + + match existed { + ValueEx::Synced(v) => { + *self = Self::SyncedAndUnsynced(v, value); + None + } + ValueEx::Unsynced(v) => { + *self = Self::Unsynced(value); + Some(v) + } + ValueEx::SyncedAndUnsynced(sv, usv) => { + *self = Self::SyncedAndUnsynced(sv, value); + Some(usv) + } + } + } + + /// Sync the value, return the replaced value if any. + fn sync(&mut self) -> Option<V> { + debug_assert!(self.contains_unsynced()); + let existed = core::mem::take(self); + + match existed { + ValueEx::Unsynced(v) => { + *self = Self::Synced(v); + None + } + ValueEx::SyncedAndUnsynced(sv, usv) => { + *self = Self::Synced(usv); + Some(sv) + } + ValueEx::Synced(_) => unreachable!(), + } + } + + /// Whether the value contains an unsynced value. 
+ pub fn contains_unsynced(&self) -> bool { + match self { + ValueEx::Unsynced(_) | ValueEx::SyncedAndUnsynced(_, _) => true, + ValueEx::Synced(_) => false, + } + } +} + +impl<V: RecordValue> Default for ValueEx<V> { + fn default() -> Self { + Self::Unsynced(V::new_uninit()) + } +} + +impl<K: RecordKey<K>, V: RecordValue> Debug for MemTableManager<K, V> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MemTableManager") + .field("mutable_memtable_size", &self.mutable.lock().size()) + .field("immutable_memtable_size", &self.immutable_memtable().size()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use core::sync::atomic::{AtomicU16, Ordering}; + + use super::*; + + #[test] + fn memtable_fns() -> Result<()> { + impl RecordValue for u16 {} + let drop_count = Arc::new(AtomicU16::new(0)); + let dc = drop_count.clone(); + let drop_fn = move |_: &dyn AsKV<usize, u16>| { + dc.fetch_add(1, Ordering::Relaxed); + }; + let mut table = MemTable::<usize, u16>::new(4, 0, Some(Arc::new(drop_fn))); + + table.put(1, 11); + table.put(2, 12); + table.put(2, 22); + assert_eq!(drop_count.load(Ordering::Relaxed), 1); + assert_eq!(table.size(), 2); + assert_eq!(table.at_capacity(), false); + + table.sync(1); + table.put(2, 32); + assert_eq!(table.size(), 3); + assert_eq!(*table.get(&2).unwrap(), 32); + + table.sync(2); + assert_eq!(drop_count.load(Ordering::Relaxed), 2); + table.put(2, 52); + table.put(3, 13); + assert_eq!(table.at_capacity(), true); + + let mut range_query_ctx = RangeQueryCtx::new(2, 2); + assert_eq!(table.get_range(&mut range_query_ctx), true); + assert_eq!(range_query_ctx.into_results(), vec![(2, 52), (3, 13)]); + + assert_eq!(table.sync_id(), 2); + table.clear(); + assert_eq!(table.is_empty(), true); + Ok(()) + } +} diff --git a/kernel/comps/mlsdisk/src/layers/4-lsm/mod.rs b/kernel/comps/mlsdisk/src/layers/4-lsm/mod.rs new file mode 100644 index 00000000..d6da1aa7 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/4-lsm/mod.rs @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! The layer of transactional Lsm-Tree. +//! +//! This module provides the implementation for `TxLsmTree`. +//! `TxLsmTree` is similar to general-purpose LSM-Tree, supporting `put()`, `get()`, `get_range()` +//! key-value records, which are managed in MemTables and SSTables. +//! +//! `TxLsmTree` is transactional in the sense that +//! 1) it supports `sync()` that guarantees changes are persisted atomically and irreversibly, +//! synchronized records and unsynchronized records can co-existed. +//! 2) its internal data is securely stored in `TxLogStore` (L3) and updated in transactions for consistency, +//! WALs and SSTables are stored and managed in `TxLogStore`. +//! +//! `TxLsmTree` supports piggybacking callbacks during compaction and recovery. +//! +//! # Usage Example +//! +//! Create a `TxLsmTree` then put some records into it. +//! +//! ``` +//! // Prepare an underlying disk (implement `BlockSet`) first +//! let nblocks = 1024; +//! let mem_disk = MemDisk::create(nblocks)?; +//! +//! // Prepare an underlying `TxLogStore` (L3) for storing WALs and SSTs +//! let tx_log_store = Arc::new(TxLogStore::format(mem_disk)?); +//! +//! // Create a `TxLsmTree` with the created `TxLogStore` +//! let tx_lsm_tree: TxLsmTree<BlockId, String, MemDisk> = +//! TxLsmTree::format(tx_log_store, Arc::new(YourFactory), None)?; +//! +//! // Put some key-value records into the tree +//! for i in 0..10 { +//! let k = i as BlockId; +//! let v = i.to_string(); +//! tx_lsm_tree.put(k, v)?; +//! 
}
+//!
+//! // Issue a sync operation to the tree to ensure persistence
+//! tx_lsm_tree.sync()?;
+//!
+//! // Use `get()` (or `get_range()`) to query the tree
+//! let target_value = tx_lsm_tree.get(&5).unwrap();
+//! // Check the previously put value
+//! assert_eq!(target_value, "5");
+//!
+//! // `TxLsmTree` supports user-defined per-TX callbacks
+//! struct YourFactory;
+//! struct YourListener;
+//!
+//! impl<K, V> TxEventListenerFactory<K, V> for YourFactory {
+//!     // Creates a per-TX listener (upon compaction or upon recovery)
+//!     fn new_event_listener(&self, tx_type: TxType) -> Arc<dyn TxEventListener<K, V>> {
+//!         Arc::new(YourListener::new(tx_type))
+//!     }
+//! }
+//!
+//! // Defines callbacks invoked when a record is added or dropped, or
+//! // at some critical points during a TX
+//! impl<K, V> TxEventListener<K, V> for YourListener {
+//!     /* details omitted, see the API for more */
+//! }
+//! ```
+
+mod compaction;
+mod mem_table;
+mod range_query_ctx;
+mod sstable;
+mod tx_lsm_tree;
+mod wal;
+
+pub use self::{
+    range_query_ctx::RangeQueryCtx,
+    tx_lsm_tree::{
+        AsKV, LsmLevel, RecordKey, RecordValue, SyncId, SyncIdStore, TxEventListener,
+        TxEventListenerFactory, TxLsmTree, TxType,
+    },
+};
diff --git a/kernel/comps/mlsdisk/src/layers/4-lsm/range_query_ctx.rs b/kernel/comps/mlsdisk/src/layers/4-lsm/range_query_ctx.rs
new file mode 100644
index 00000000..2a3e1c4e
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/layers/4-lsm/range_query_ctx.rs
@@ -0,0 +1,96 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Context for range query.
+use core::ops::RangeInclusive;
+
+use super::{RecordKey, RecordValue};
+use crate::{prelude::*, util::BitMap};
+
+/// Context for a range query request.
+/// It tracks the completion status of each slot within the range.
+/// A "slot" indicates one specific key-value pair of the query.
+#[derive(Debug)]
+pub struct RangeQueryCtx<K, V> {
+    start: K,
+    num_values: usize,
+    complete_table: BitMap,
+    min_uncompleted: usize,
+    res: Vec<(K, V)>,
+}
+
+impl<K: RecordKey<K>, V: RecordValue> RangeQueryCtx<K, V> {
+    /// Create a new context with the given start key
+    /// and the number of values to query.
+    pub fn new(start: K, num_values: usize) -> Self {
+        Self {
+            start,
+            num_values,
+            complete_table: BitMap::repeat(false, num_values),
+            min_uncompleted: 0,
+            res: Vec::with_capacity(num_values),
+        }
+    }
+
+    /// Gets the uncompleted range within the whole query range, returning
+    /// `None` if all slots are already completed.
+    pub fn range_uncompleted(&self) -> Option<RangeInclusive<K>> {
+        if self.is_completed() {
+            return None;
+        }
+        debug_assert!(self.min_uncompleted < self.num_values);
+
+        let first_uncompleted = self.start + self.min_uncompleted;
+        let last_uncompleted = self.start + self.complete_table.last_zero()?;
+        Some(first_uncompleted..=last_uncompleted)
+    }
+
+    /// Whether the uncompleted range contains the target key.
+    pub fn contains_uncompleted(&self, key: &K) -> bool {
+        let nth = *key - self.start;
+        nth < self.num_values && !self.complete_table[nth]
+    }
+
+    /// Whether the range query context is completed, meaning
+    /// all slots are filled with the corresponding values.
+    pub fn is_completed(&self) -> bool {
+        self.min_uncompleted == self.num_values
+    }
+
+    /// Complete one slot within the range with the specific
+    /// key and the queried value.
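+    ///
+    /// Completing a slot that is already completed is a no-op.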
+    pub fn complete(&mut self, key: K, value: V) {
+        let nth = key - self.start;
+        if self.complete_table[nth] {
+            return;
+        }
+
+        self.res.push((key, value));
+        self.complete_table.set(nth, true);
+        self.update_min_uncompleted(nth);
+    }
+
+    /// Mark the specific slot as completed.
+    pub fn mark_completed(&mut self, key: K) {
+        let nth = key - self.start;
+
+        self.complete_table.set(nth, true);
+        self.update_min_uncompleted(nth);
+    }
+
+    /// Turn the context into final results.
+    pub fn into_results(self) -> Vec<(K, V)> {
+        debug_assert!(self.is_completed());
+        self.res
+    }
+
+    fn update_min_uncompleted(&mut self, completed_nth: usize) {
+        if self.min_uncompleted == completed_nth {
+            if let Some(next_uncompleted) = self.complete_table.first_zero(completed_nth) {
+                self.min_uncompleted = next_uncompleted;
+            } else {
+                // Indicate all slots are completed
+                self.min_uncompleted = self.num_values;
+            }
+        }
+    }
+}
diff --git a/kernel/comps/mlsdisk/src/layers/4-lsm/sstable.rs b/kernel/comps/mlsdisk/src/layers/4-lsm/sstable.rs
new file mode 100644
index 00000000..c613308b
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/layers/4-lsm/sstable.rs
@@ -0,0 +1,779 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Sorted String Table.
+use alloc::vec;
+use core::{marker::PhantomData, mem::size_of, num::NonZeroUsize, ops::RangeInclusive};
+
+use lru::LruCache;
+use ostd_pod::Pod;
+
+use super::{
+    mem_table::ValueEx, tx_lsm_tree::AsKVex, RangeQueryCtx, RecordKey, RecordValue, SyncId,
+    TxEventListener,
+};
+use crate::{
+    layers::{
+        bio::{BlockSet, Buf, BufMut, BufRef, BID_SIZE},
+        log::{TxLog, TxLogId, TxLogStore},
+    },
+    os::Mutex,
+    prelude::*,
+};
+
+/// Sorted String Table (SST) for `TxLsmTree`.
+///
+/// Responsible for storing and managing key-value records on a `TxLog` (L3).
+/// Records are serialized, sorted, and organized on the `TxLog`.
+/// Supports three access modes: point query, range query and whole scan.
+pub(super) struct SSTable<K, V> {
+    id: TxLogId,
+    footer: Footer<K>,
+    cache: Mutex<LruCache<BlockId, Arc<RecordBlock>>>,
+    phantom: PhantomData<(K, V)>,
+}
+
+/// Footer of a `SSTable`, containing its own metadata and the
+/// index entries for locating record blocks.
+#[derive(Debug)]
+struct Footer<K> {
+    meta: FooterMeta,
+    index: Vec<IndexEntry<K>>,
+}
+
+/// Footer metadata to describe a `SSTable`.
+#[repr(C)]
+#[derive(Copy, Clone, Pod, Debug)]
+struct FooterMeta {
+    num_index: u16,
+    index_nblocks: u16,
+    total_records: u32,
+    record_block_size: u32,
+    sync_id: SyncId,
+}
+const FOOTER_META_SIZE: usize = size_of::<FooterMeta>();
+
+/// Index entry to describe a `RecordBlock` in a `SSTable`.
+#[derive(Debug)]
+struct IndexEntry<K> {
+    pos: BlockId,
+    first: K,
+    last: K,
+}
+
+/// A block full of serialized records.
+struct RecordBlock {
+    buf: Vec<u8>,
+}
+const RECORD_BLOCK_NBLOCKS: usize = 32;
+/// The size of a `RecordBlock`, which is a multiple of `BLOCK_SIZE`.
+const RECORD_BLOCK_SIZE: usize = RECORD_BLOCK_NBLOCKS * BLOCK_SIZE;
+
+/// Accessor for a query.
+enum QueryAccessor<K> {
+    Point(K),
+    Range(RangeInclusive<K>),
+}
+
+/// Iterator over `RecordBlock` for query purposes.
+struct BlockQueryIter<'a, K, V> {
+    block: &'a RecordBlock,
+    offset: usize,
+    accessor: &'a QueryAccessor<K>,
+    phantom: PhantomData<(K, V)>,
+}
+
+/// Accessor for a whole table scan.
+struct ScanAccessor<'a, K, V> {
+    all_synced: bool,
+    discard_unsynced: bool,
+    event_listener: Option<&'a Arc<dyn TxEventListener<K, V>>>,
+}
+
+/// Iterator over `RecordBlock` for scan purposes.
+struct BlockScanIter<'a, K, V> { + block: Arc<RecordBlock>, + offset: usize, + accessor: ScanAccessor<'a, K, V>, +} + +/// Iterator over `SSTable`. +pub(super) struct SstIter<'a, K, V, D> { + sst: &'a SSTable<K, V>, + curr_nth_index: usize, + curr_rb_iter: Option<BlockScanIter<'a, K, V>>, + tx_log_store: &'a Arc<TxLogStore<D>>, +} + +/// Format on a `TxLog`: +/// +/// ```text +/// | [Record] | [Record] |...| Footer | +/// |K|flag|V(V)| ... | [Record] |...| [IndexEntry] | FooterMeta | +/// |RECORD_BLOCK_SIZE|RECORD_BLOCK_SIZE|...| | +/// ``` +impl<K: RecordKey<K>, V: RecordValue> SSTable<K, V> { + const K_SIZE: usize = size_of::<K>(); + const V_SIZE: usize = size_of::<V>(); + const FLAG_SIZE: usize = size_of::<RecordFlag>(); + const MIN_RECORD_SIZE: usize = BID_SIZE + Self::FLAG_SIZE + Self::V_SIZE; + const MAX_RECORD_SIZE: usize = BID_SIZE + Self::FLAG_SIZE + 2 * Self::V_SIZE; + const INDEX_ENTRY_SIZE: usize = BID_SIZE + 2 * Self::K_SIZE; + const CACHE_CAP: usize = 1024; + + /// Return the ID of this `SSTable`, which is the same ID + /// to the underlying `TxLog`. + pub fn id(&self) -> TxLogId { + self.id + } + + /// Return the sync ID of this `SSTable`, it may be smaller than the + /// current master sync ID. + pub fn sync_id(&self) -> SyncId { + self.footer.meta.sync_id + } + + /// The range of keys covered by this `SSTable`. + pub fn range(&self) -> RangeInclusive<K> { + RangeInclusive::new( + self.footer.index[0].first, + self.footer.index[self.footer.meta.num_index as usize - 1].last, + ) + } + + /// Whether the target key is within the range, "within the range" doesn't mean + /// the `SSTable` do have this key. + pub fn is_within_range(&self, key: &K) -> bool { + self.range().contains(key) + } + + /// Whether the target range is overlapped with the range of this `SSTable`. + pub fn overlap_with(&self, rhs_range: &RangeInclusive<K>) -> bool { + let lhs_range = self.range(); + !(lhs_range.end() < rhs_range.start() || lhs_range.start() > rhs_range.end()) + } + + // Accessing functions below + + /// Point query. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn access_point<D: BlockSet + 'static>( + &self, + key: &K, + tx_log_store: &Arc<TxLogStore<D>>, + ) -> Result<V> { + debug_assert!(self.range().contains(key)); + let target_rb_pos = self + .footer + .index + .iter() + .find_map(|entry| { + if entry.is_within_range(key) { + Some(entry.pos) + } else { + None + } + }) + .ok_or(Error::with_msg(NotFound, "target key not found in sst"))?; + + let accessor = QueryAccessor::Point(*key); + let target_rb = self.target_record_block(target_rb_pos, tx_log_store)?; + + let mut iter = BlockQueryIter::<'_, K, V> { + block: &target_rb, + offset: 0, + accessor: &accessor, + phantom: PhantomData, + }; + + iter.find_map(|(k, v_opt)| if k == *key { v_opt } else { None }) + .ok_or(Error::with_msg(NotFound, "target value not found in SST")) + } + + /// Range query. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. 
+ pub fn access_range<D: BlockSet + 'static>( + &self, + range_query_ctx: &mut RangeQueryCtx<K, V>, + tx_log_store: &Arc<TxLogStore<D>>, + ) -> Result<()> { + debug_assert!(!range_query_ctx.is_completed()); + let range_uncompleted = range_query_ctx.range_uncompleted().unwrap(); + let target_rbs = self.footer.index.iter().filter_map(|entry| { + if entry.overlap_with(&range_uncompleted) { + Some(entry.pos) + } else { + None + } + }); + + let accessor = QueryAccessor::Range(range_uncompleted.clone()); + for target_rb_pos in target_rbs { + let target_rb = self.target_record_block(target_rb_pos, tx_log_store)?; + + let iter = BlockQueryIter::<'_, K, V> { + block: &target_rb, + offset: 0, + accessor: &accessor, + phantom: PhantomData, + }; + + let targets: Vec<_> = iter + .filter_map(|(k, v_opt)| { + if range_uncompleted.contains(&k) { + Some((k, v_opt.unwrap())) + } else { + None + } + }) + .collect(); + for (target_k, target_v) in targets { + range_query_ctx.complete(target_k, target_v); + } + } + Ok(()) + } + + /// Locate the target record block given its position, it + /// resides in either the cache or the log. + fn target_record_block<D: BlockSet + 'static>( + &self, + target_pos: BlockId, + tx_log_store: &Arc<TxLogStore<D>>, + ) -> Result<Arc<RecordBlock>> { + let mut cache = self.cache.lock(); + if let Some(cached_rb) = cache.get(&target_pos) { + Ok(cached_rb.clone()) + } else { + let mut rb = RecordBlock::from_buf(vec![0; RECORD_BLOCK_SIZE]); + // TODO: Avoid opening the log on every call + let tx_log = tx_log_store.open_log(self.id, false)?; + tx_log.read(target_pos, BufMut::try_from(rb.as_mut_slice()).unwrap())?; + let rb = Arc::new(rb); + cache.put(target_pos, rb.clone()); + Ok(rb) + } + } + + /// Return the iterator over this `SSTable`. + /// The given `event_listener` (optional) is used on dropping records + /// during iteration. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn iter<'a, D: BlockSet + 'static>( + &'a self, + sync_id: SyncId, + discard_unsynced: bool, + tx_log_store: &'a Arc<TxLogStore<D>>, + event_listener: Option<&'a Arc<dyn TxEventListener<K, V>>>, + ) -> SstIter<'a, K, V, D> { + let all_synced = sync_id > self.sync_id(); + let accessor = ScanAccessor { + all_synced, + discard_unsynced, + event_listener, + }; + + let first_rb = self + .target_record_block(self.footer.index[0].pos, tx_log_store) + .unwrap(); + + SstIter { + sst: self, + curr_nth_index: 0, + curr_rb_iter: Some(BlockScanIter { + block: first_rb, + offset: 0, + accessor, + }), + tx_log_store, + } + } + + /// Scan the whole SST and collect all records. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn access_scan<D: BlockSet + 'static>( + &self, + sync_id: SyncId, + discard_unsynced: bool, + tx_log_store: &Arc<TxLogStore<D>>, + event_listener: Option<&Arc<dyn TxEventListener<K, V>>>, + ) -> Result<Vec<(K, ValueEx<V>)>> { + let all_records = self + .iter(sync_id, discard_unsynced, tx_log_store, event_listener) + .collect(); + Ok(all_records) + } + + // Building functions below + + /// Builds a SST given a bunch of records, after the SST becomes immutable. + /// The given `event_listener` (optional) is used on adding records. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. 
+ pub fn build<'a, D: BlockSet + 'static, I, KVex>( + records_iter: I, + sync_id: SyncId, + tx_log: &'a Arc<TxLog<D>>, + event_listener: Option<&'a Arc<dyn TxEventListener<K, V>>>, + ) -> Result<Self> + where + I: Iterator<Item = KVex>, + KVex: AsKVex<K, V>, + Self: 'a, + { + let mut cache = LruCache::new(NonZeroUsize::new(Self::CACHE_CAP).unwrap()); + let (total_records, index_vec) = + Self::build_record_blocks(records_iter, tx_log, &mut cache, event_listener)?; + let footer = Self::build_footer::<D>(index_vec, total_records, sync_id, tx_log)?; + + Ok(Self { + id: tx_log.id(), + footer, + cache: Mutex::new(cache), + phantom: PhantomData, + }) + } + + /// Builds all the record blocks from the given records. Put the blocks to the log + /// and the cache. + fn build_record_blocks<'a, D: BlockSet + 'static, I, KVex>( + records_iter: I, + tx_log: &'a TxLog<D>, + cache: &mut LruCache<BlockId, Arc<RecordBlock>>, + event_listener: Option<&'a Arc<dyn TxEventListener<K, V>>>, + ) -> Result<(usize, Vec<IndexEntry<K>>)> + where + I: Iterator<Item = KVex>, + KVex: AsKVex<K, V>, + Self: 'a, + { + let mut index_vec = Vec::new(); + let mut total_records = 0; + let mut pos = 0 as BlockId; + let (mut first_k, mut curr_k) = (None, None); + let mut inner_offset = 0; + + let mut block_buf = Vec::with_capacity(RECORD_BLOCK_SIZE); + for kv_ex in records_iter { + let (key, value_ex) = (*kv_ex.key(), kv_ex.value_ex()); + total_records += 1; + + if inner_offset == 0 { + debug_assert!(block_buf.is_empty()); + let _ = first_k.insert(key); + } + let _ = curr_k.insert(key); + + block_buf.extend_from_slice(key.as_bytes()); + inner_offset += Self::K_SIZE; + + match value_ex { + ValueEx::Synced(v) => { + block_buf.push(RecordFlag::Synced as u8); + block_buf.extend_from_slice(v.as_bytes()); + + if let Some(listener) = event_listener { + listener.on_add_record(&(&key, v))?; + } + inner_offset += 1 + Self::V_SIZE; + } + ValueEx::Unsynced(v) => { + block_buf.push(RecordFlag::Unsynced as u8); + block_buf.extend_from_slice(v.as_bytes()); + + if let Some(listener) = event_listener { + listener.on_add_record(&(&key, v))?; + } + inner_offset += 1 + Self::V_SIZE; + } + ValueEx::SyncedAndUnsynced(sv, usv) => { + block_buf.push(RecordFlag::SyncedAndUnsynced as u8); + block_buf.extend_from_slice(sv.as_bytes()); + block_buf.extend_from_slice(usv.as_bytes()); + + if let Some(listener) = event_listener { + listener.on_add_record(&(&key, sv))?; + listener.on_add_record(&(&key, usv))?; + } + inner_offset += Self::MAX_RECORD_SIZE; + } + } + + let cap_remained = RECORD_BLOCK_SIZE - inner_offset; + if cap_remained >= Self::MAX_RECORD_SIZE { + continue; + } + + let index_entry = IndexEntry { + pos, + first: first_k.unwrap(), + last: key, + }; + build_one_record_block(&index_entry, &mut block_buf, tx_log, cache)?; + index_vec.push(index_entry); + + pos += RECORD_BLOCK_NBLOCKS; + inner_offset = 0; + block_buf.clear(); + } + debug_assert!(total_records > 0); + + if !block_buf.is_empty() { + let last_entry = IndexEntry { + pos, + first: first_k.unwrap(), + last: curr_k.unwrap(), + }; + build_one_record_block(&last_entry, &mut block_buf, tx_log, cache)?; + index_vec.push(last_entry); + } + + fn build_one_record_block<K: RecordKey<K>, D: BlockSet + 'static>( + entry: &IndexEntry<K>, + buf: &mut Vec<u8>, + tx_log: &TxLog<D>, + cache: &mut LruCache<BlockId, Arc<RecordBlock>>, + ) -> Result<()> { + buf.resize(RECORD_BLOCK_SIZE, 0); + let record_block = RecordBlock::from_buf(buf.clone()); + + 
tx_log.append(BufRef::try_from(record_block.as_slice()).unwrap())?; + cache.put(entry.pos, Arc::new(record_block)); + Ok(()) + } + + Ok((total_records, index_vec)) + } + + /// Builds the footer from the given index entries. The footer block will be appended + /// to the SST log's end. + fn build_footer<'a, D: BlockSet + 'static>( + index_vec: Vec<IndexEntry<K>>, + total_records: usize, + sync_id: SyncId, + tx_log: &'a TxLog<D>, + ) -> Result<Footer<K>> + where + Self: 'a, + { + let footer_buf_len = align_up( + index_vec.len() * Self::INDEX_ENTRY_SIZE + FOOTER_META_SIZE, + BLOCK_SIZE, + ); + let mut append_buf = Vec::with_capacity(footer_buf_len); + for entry in &index_vec { + append_buf.extend_from_slice(&entry.pos.to_le_bytes()); + append_buf.extend_from_slice(entry.first.as_bytes()); + append_buf.extend_from_slice(entry.last.as_bytes()); + } + append_buf.resize(footer_buf_len, 0); + let meta = FooterMeta { + num_index: index_vec.len() as _, + index_nblocks: (footer_buf_len / BLOCK_SIZE) as _, + total_records: total_records as _, + record_block_size: RECORD_BLOCK_SIZE as _, + sync_id, + }; + append_buf[footer_buf_len - FOOTER_META_SIZE..].copy_from_slice(meta.as_bytes()); + tx_log.append(BufRef::try_from(&append_buf[..]).unwrap())?; + + Ok(Footer { + meta, + index: index_vec, + }) + } + + /// Builds a SST from a `TxLog`, loads the footer and the index blocks. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn from_log<D: BlockSet + 'static>(tx_log: &Arc<TxLog<D>>) -> Result<Self> { + let nblocks = tx_log.nblocks(); + + let mut rbuf = Buf::alloc(1)?; + // Load footer block (last block) + tx_log.read(nblocks - 1, rbuf.as_mut())?; + let meta = FooterMeta::from_bytes(&rbuf.as_slice()[BLOCK_SIZE - FOOTER_META_SIZE..]); + + let mut rbuf = Buf::alloc(meta.index_nblocks as _)?; + tx_log.read(nblocks - meta.index_nblocks as usize, rbuf.as_mut())?; + let mut index = Vec::with_capacity(meta.num_index as _); + let mut cache = LruCache::new(NonZeroUsize::new(Self::CACHE_CAP).unwrap()); + let mut record_block = vec![0; RECORD_BLOCK_SIZE]; + for i in 0..meta.num_index as _ { + let buf = + &rbuf.as_slice()[i * Self::INDEX_ENTRY_SIZE..(i + 1) * Self::INDEX_ENTRY_SIZE]; + + let pos = BlockId::from_le_bytes(buf[..BID_SIZE].try_into().unwrap()); + let first = K::from_bytes(&buf[BID_SIZE..BID_SIZE + Self::K_SIZE]); + let last = + K::from_bytes(&buf[Self::INDEX_ENTRY_SIZE - Self::K_SIZE..Self::INDEX_ENTRY_SIZE]); + + tx_log.read(pos, BufMut::try_from(&mut record_block[..]).unwrap())?; + let _ = cache.put(pos, Arc::new(RecordBlock::from_buf(record_block.clone()))); + + index.push(IndexEntry { pos, first, last }) + } + + let footer = Footer { meta, index }; + Ok(Self { + id: tx_log.id(), + footer, + cache: Mutex::new(cache), + phantom: PhantomData, + }) + } +} + +impl<K: RecordKey<K>> IndexEntry<K> { + pub fn range(&self) -> RangeInclusive<K> { + self.first..=self.last + } + + pub fn is_within_range(&self, key: &K) -> bool { + self.range().contains(key) + } + + pub fn overlap_with(&self, rhs_range: &RangeInclusive<K>) -> bool { + let lhs_range = self.range(); + !(lhs_range.end() < rhs_range.start() || lhs_range.start() > rhs_range.end()) + } +} + +impl RecordBlock { + pub fn from_buf(buf: Vec<u8>) -> Self { + debug_assert_eq!(buf.len(), RECORD_BLOCK_SIZE); + Self { buf } + } + + pub fn as_slice(&self) -> &[u8] { + &self.buf + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + &mut self.buf + } +} + +impl<K: RecordKey<K>> QueryAccessor<K> { + pub fn 
hit_target(&self, target: &K) -> bool { + match self { + QueryAccessor::Point(k) => k == target, + QueryAccessor::Range(range) => range.contains(target), + } + } +} + +impl<K: RecordKey<K>, V: RecordValue> Iterator for BlockQueryIter<'_, K, V> { + type Item = (K, Option<V>); + + fn next(&mut self) -> Option<Self::Item> { + let mut offset = self.offset; + let buf_slice = &self.block.buf; + let (k_size, v_size) = (SSTable::<K, V>::K_SIZE, SSTable::<K, V>::V_SIZE); + + if offset + SSTable::<K, V>::MIN_RECORD_SIZE > RECORD_BLOCK_SIZE { + return None; + } + + let key = K::from_bytes(&buf_slice[offset..offset + k_size]); + offset += k_size; + + let flag = RecordFlag::from(buf_slice[offset]); + offset += 1; + if flag == RecordFlag::Invalid { + return None; + } + + let hit_target = self.accessor.hit_target(&key); + let value_opt = match flag { + RecordFlag::Synced | RecordFlag::Unsynced => { + let v_opt = if hit_target { + Some(V::from_bytes(&buf_slice[offset..offset + v_size])) + } else { + None + }; + offset += v_size; + v_opt + } + RecordFlag::SyncedAndUnsynced => { + let v_opt = if hit_target { + Some(V::from_bytes( + &buf_slice[offset + v_size..offset + 2 * v_size], + )) + } else { + None + }; + offset += 2 * v_size; + v_opt + } + _ => unreachable!(), + }; + + self.offset = offset; + Some((key, value_opt)) + } +} + +impl<K: RecordKey<K>, V: RecordValue> Iterator for BlockScanIter<'_, K, V> { + type Item = (K, ValueEx<V>); + + fn next(&mut self) -> Option<Self::Item> { + let mut offset = self.offset; + let buf_slice = &self.block.buf; + let (k_size, v_size) = (SSTable::<K, V>::K_SIZE, SSTable::<K, V>::V_SIZE); + let (all_synced, discard_unsynced, event_listener) = ( + self.accessor.all_synced, + self.accessor.discard_unsynced, + &self.accessor.event_listener, + ); + + let (key, value_ex) = loop { + if offset + SSTable::<K, V>::MIN_RECORD_SIZE > RECORD_BLOCK_SIZE { + return None; + } + + let key = K::from_bytes(&buf_slice[offset..offset + k_size]); + offset += k_size; + + let flag = RecordFlag::from(buf_slice[offset]); + offset += 1; + if flag == RecordFlag::Invalid { + return None; + } + + let v_ex = match flag { + RecordFlag::Synced => { + let v = V::from_bytes(&buf_slice[offset..offset + v_size]); + offset += v_size; + ValueEx::Synced(v) + } + RecordFlag::Unsynced => { + let v = V::from_bytes(&buf_slice[offset..offset + v_size]); + offset += v_size; + if all_synced { + ValueEx::Synced(v) + } else if discard_unsynced { + if let Some(listener) = event_listener { + listener.on_drop_record(&(key, v)).unwrap(); + } + continue; + } else { + ValueEx::Unsynced(v) + } + } + RecordFlag::SyncedAndUnsynced => { + let sv = V::from_bytes(&buf_slice[offset..offset + v_size]); + offset += v_size; + let usv = V::from_bytes(&buf_slice[offset..offset + v_size]); + offset += v_size; + if all_synced { + if let Some(listener) = event_listener { + listener.on_drop_record(&(key, sv)).unwrap(); + } + ValueEx::Synced(usv) + } else if discard_unsynced { + if let Some(listener) = event_listener { + listener.on_drop_record(&(key, usv)).unwrap(); + } + ValueEx::Synced(sv) + } else { + ValueEx::SyncedAndUnsynced(sv, usv) + } + } + _ => unreachable!(), + }; + break (key, v_ex); + }; + + self.offset = offset; + Some((key, value_ex)) + } +} + +impl<K: RecordKey<K>, V: RecordValue, D: BlockSet + 'static> Iterator for SstIter<'_, K, V, D> { + type Item = (K, ValueEx<V>); + + fn next(&mut self) -> Option<Self::Item> { + // Iterate over the current record block first + if let Some(next) = 
self.curr_rb_iter.as_mut().unwrap().next() { + return Some(next); + } + + let curr_rb_iter = self.curr_rb_iter.take().unwrap(); + + self.curr_nth_index += 1; + // Iteration goes to the end + if self.curr_nth_index >= self.sst.footer.meta.num_index as _ { + return None; + } + + // Ready to iterate the next record block + let next_pos = self.sst.footer.index[self.curr_nth_index].pos; + let next_rb = self + .sst + .target_record_block(next_pos, self.tx_log_store) + .unwrap(); + + let mut next_rb_iter = BlockScanIter { + block: next_rb, + offset: 0, + accessor: curr_rb_iter.accessor, + }; + let next = next_rb_iter.next()?; + + let _ = self.curr_rb_iter.insert(next_rb_iter); + Some(next) + } +} + +impl<K: Debug, V> Debug for SSTable<K, V> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SSTable") + .field("id", &self.id) + .field("footer", &self.footer.meta) + .field( + "range", + &RangeInclusive::new( + &self.footer.index[0].first, + &self.footer.index[self.footer.meta.num_index as usize - 1].last, + ), + ) + .finish() + } +} + +/// Flag bit for records in SSTable. +#[derive(PartialEq, Eq, Debug)] +#[repr(u8)] +enum RecordFlag { + Synced = 7, + Unsynced = 11, + SyncedAndUnsynced = 19, + Invalid, +} + +impl From<u8> for RecordFlag { + fn from(value: u8) -> Self { + match value { + 7 => RecordFlag::Synced, + 11 => RecordFlag::Unsynced, + 19 => RecordFlag::SyncedAndUnsynced, + _ => RecordFlag::Invalid, + } + } +} diff --git a/kernel/comps/mlsdisk/src/layers/4-lsm/tx_lsm_tree.rs b/kernel/comps/mlsdisk/src/layers/4-lsm/tx_lsm_tree.rs new file mode 100644 index 00000000..28daf37a --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/4-lsm/tx_lsm_tree.rs @@ -0,0 +1,1037 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Transactional LSM-Tree. +//! +//! API: `format()`, `recover()`, `get()`, `put()`, `get_range()`, `sync()` +//! +//! Responsible for managing two `MemTable`s, WAL and SSTs as `TxLog`s +//! backed by a `TxLogStore`. All operations are executed based +//! on internal transactions. +use alloc::vec; +use core::{ + hash::Hash, + ops::{Add, RangeInclusive, Sub}, + sync::atomic::{AtomicU64, Ordering}, +}; + +use ostd_pod::Pod; + +use super::{ + compaction::Compactor, + mem_table::{MemTableManager, ValueEx}, + range_query_ctx::RangeQueryCtx, + sstable::SSTable, + wal::{WalAppendTx, BUCKET_WAL}, +}; +use crate::{ + layers::{ + bio::BlockSet, + log::{TxLogId, TxLogStore}, + }, + os::{spawn, BTreeMap, RwLock}, + prelude::*, + tx::CurrentTx, +}; + +/// Monotonic incrementing sync ID. +pub type SyncId = u64; + +/// A transactional LSM-Tree, managing `MemTable`s, WALs and SSTs backed by `TxLogStore` (L3). +/// +/// Supports inserting and querying key-value records within transactions. +/// Supports user-defined callbacks in `MemTable`, during compaction and recovery. +pub struct TxLsmTree<K: RecordKey<K>, V, D>(Arc<TreeInner<K, V, D>>); + +/// Inner structures of `TxLsmTree`. +pub(super) struct TreeInner<K: RecordKey<K>, V, D> { + memtable_manager: MemTableManager<K, V>, + sst_manager: RwLock<SstManager<K, V>>, + wal_append_tx: WalAppendTx<D>, + compactor: Compactor<K, V>, + tx_log_store: Arc<TxLogStore<D>>, + listener_factory: Arc<dyn TxEventListenerFactory<K, V>>, + master_sync_id: MasterSyncId, +} + +/// Levels in a `TxLsmTree`. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum LsmLevel { + L0 = 0, + L1, + L2, + L3, + L4, + L5, // Cover over 10 TB data (over 800 TB user data from L5) +} + +/// Manager of all `SSTable`s from every level in a `TxLsmTree`. 
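+///
+/// Internally, it keeps one map per level, each mapping a `TxLogId` to the
+/// in-memory handle of the corresponding `SSTable`.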
+#[derive(Debug)]
+struct SstManager<K, V> {
+    level_ssts: Vec<BTreeMap<TxLogId, Arc<SSTable<K, V>>>>,
+}
+
+/// A factory of per-transaction event listeners.
+pub trait TxEventListenerFactory<K, V>: Send + Sync {
+    /// Creates a new event listener for a given transaction.
+    fn new_event_listener(&self, tx_type: TxType) -> Arc<dyn TxEventListener<K, V>>;
+}
+
+/// An event listener that gets informed when
+/// 1) a new record is added,
+/// 2) an existing record is dropped,
+/// 3) a TX has just begun,
+/// 4) a TX is about to end,
+/// 5) a TX has committed.
+///
+/// `tx_type` indicates an internal transaction of `TxLsmTree`.
+pub trait TxEventListener<K, V> {
+    /// Notify the listener that a new record is added to an LSM-Tree.
+    fn on_add_record(&self, record: &dyn AsKV<K, V>) -> Result<()>;
+
+    /// Notify the listener that an existing record is dropped from an LSM-Tree.
+    fn on_drop_record(&self, record: &dyn AsKV<K, V>) -> Result<()>;
+
+    /// Notify the listener after a TX has just begun.
+    fn on_tx_begin(&self, tx: &mut CurrentTx<'_>) -> Result<()>;
+
+    /// Notify the listener before a TX ends.
+    fn on_tx_precommit(&self, tx: &mut CurrentTx<'_>) -> Result<()>;
+
+    /// Notify the listener after a TX committed.
+    fn on_tx_commit(&self);
+}
+
+/// Types of `TxLsmTree`'s internal transactions.
+#[derive(Copy, Clone, Debug)]
+pub enum TxType {
+    /// A Compaction Transaction merges old `SSTable`s into new ones.
+    Compaction { to_level: LsmLevel },
+    /// A Migration Transaction migrates synced records from old `SSTable`s
+    /// to new ones and discards unsynced records.
+    Migration,
+}
+
+/// A trusted store that stores the master sync ID.
+pub trait SyncIdStore: Send + Sync {
+    /// Read the current master sync ID from the store.
+    fn read(&self) -> Result<SyncId>;
+
+    /// Write the given master sync ID to the store.
+    fn write(&self, id: SyncId) -> Result<()>;
+}
+
+/// Master sync ID to help `TxLsmTree` achieve sync awareness.
+pub(super) struct MasterSyncId {
+    id: AtomicU64,
+    store: Option<Arc<dyn SyncIdStore>>,
+}
+
+/// A trait that represents the key for a record in a `TxLsmTree`.
+pub trait RecordKey<K>:
+    Ord + Pod + Hash + Add<usize, Output = K> + Sub<K, Output = usize> + Debug + Send + Sync + 'static
+{
+}
+/// A trait that represents the value for a record in a `TxLsmTree`.
+pub trait RecordValue: Pod + Debug + Send + Sync + 'static {}
+
+/// Represents any type that includes a key and a value.
+pub trait AsKV<K, V>: Send + Sync {
+    fn key(&self) -> &K;
+
+    fn value(&self) -> &V;
+}
+
+/// A callback that will be called when a record is dropped.
+pub type OnDropRecodeFn<K, V> = dyn Fn(&dyn AsKV<K, V>) + Send + Sync;
+
+/// Represents any type that includes a key and a sync-aware value.
+pub(super) trait AsKVex<K, V> {
+    fn key(&self) -> &K;
+
+    fn value_ex(&self) -> &ValueEx<V>;
+}
+
+/// Capacity of each `MemTable` and `SSTable`.
+pub(super) const MEMTABLE_CAPACITY: usize = 2097152; // 96 MiB MemTable, covers 8 GiB of data // TBD
+pub(super) const SSTABLE_CAPACITY: usize = MEMTABLE_CAPACITY;
+
+impl<K: RecordKey<K>, V: RecordValue, D: BlockSet + 'static> TxLsmTree<K, V, D> {
+    /// Format a `TxLsmTree` from a given `TxLogStore`.
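+    ///
+    /// An illustrative sketch of formatting a tree (not compiled as a doc test;
+    /// `Factory` and `Value` refer to the test-only types defined at the bottom of this file):
+    ///
+    /// ```ignore
+    /// let mem_disk = MemDisk::create(1024)?;
+    /// let tx_log_store = Arc::new(TxLogStore::format(mem_disk, Key::random())?);
+    /// let tree: TxLsmTree<BlockId, Value, MemDisk> =
+    ///     TxLsmTree::format(tx_log_store, Arc::new(Factory), None, None)?;
+    /// ```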
+ pub fn format( + tx_log_store: Arc<TxLogStore<D>>, + listener_factory: Arc<dyn TxEventListenerFactory<K, V>>, + on_drop_record_in_memtable: Option<Arc<OnDropRecodeFn<K, V>>>, + sync_id_store: Option<Arc<dyn SyncIdStore>>, + ) -> Result<Self> { + let inner = TreeInner::format( + tx_log_store, + listener_factory, + on_drop_record_in_memtable, + sync_id_store, + )?; + Ok(Self(Arc::new(inner))) + } + + /// Recover a `TxLsmTree` from a given `TxLogStore`. + pub fn recover( + tx_log_store: Arc<TxLogStore<D>>, + listener_factory: Arc<dyn TxEventListenerFactory<K, V>>, + on_drop_record_in_memtable: Option<Arc<OnDropRecodeFn<K, V>>>, + sync_id_store: Option<Arc<dyn SyncIdStore>>, + ) -> Result<Self> { + let inner = TreeInner::recover( + tx_log_store, + listener_factory, + on_drop_record_in_memtable, + sync_id_store, + )?; + Ok(Self(Arc::new(inner))) + } + + /// Gets a target value given a key. + pub fn get(&self, key: &K) -> Result<V> { + self.0.get(key) + } + + /// Gets a range of target values given a range of keys. + pub fn get_range(&self, range_query_ctx: &mut RangeQueryCtx<K, V>) -> Result<()> { + self.0.get_range(range_query_ctx) + } + + /// Puts a key-value record to the tree. + pub fn put(&self, key: K, value: V) -> Result<()> { + let inner = &self.0; + let record = (key, value); + + // Write the record to WAL + inner.wal_append_tx.append(&record)?; + + // Put the record into `MemTable` + let at_capacity = inner.memtable_manager.put(key, value); + if !at_capacity { + return Ok(()); + } + + // Commit WAL TX before compaction + // TODO: Error handling: try twice or ignore + let wal_id = inner.wal_append_tx.commit()?; + + // Wait asynchronous compaction to finish + // TODO: Error handling for compaction: try twice or become read-only + inner.compactor.wait_compaction()?; + + inner.memtable_manager.switch().unwrap(); + + // Trigger compaction when `MemTable` is at capacity + self.do_compaction_tx(wal_id) + } + + /// Persist all in-memory data of `TxLsmTree` to the backed storage. + pub fn sync(&self) -> Result<()> { + self.0.sync() + } + + /// Do a compaction TX. + /// The given `wal_id` is used to identify the WAL for discarding. 
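+    /// The compaction work runs on a spawned task; its handle is recorded in the
+    /// `Compactor` so that later operations can wait for its completion.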
+ fn do_compaction_tx(&self, wal_id: TxLogId) -> Result<()> { + let inner = self.0.clone(); + let handle = spawn(move || -> Result<()> { + // Do major compaction first if necessary + if inner + .sst_manager + .read() + .require_major_compaction(LsmLevel::L0) + { + inner.do_major_compaction(LsmLevel::L1)?; + } + + // Do minor compaction + inner.do_minor_compaction(wal_id)?; + + Ok(()) + }); + + // handle.join().unwrap()?; // synchronous + self.0.compactor.record_handle(handle); // asynchronous + Ok(()) + } +} + +impl<K: RecordKey<K>, V: RecordValue, D: BlockSet + 'static> TreeInner<K, V, D> { + pub fn format( + tx_log_store: Arc<TxLogStore<D>>, + listener_factory: Arc<dyn TxEventListenerFactory<K, V>>, + on_drop_record_in_memtable: Option<Arc<OnDropRecodeFn<K, V>>>, + sync_id_store: Option<Arc<dyn SyncIdStore>>, + ) -> Result<Self> { + let sync_id: SyncId = 0; + Ok(Self { + memtable_manager: MemTableManager::new( + sync_id, + MEMTABLE_CAPACITY, + on_drop_record_in_memtable, + ), + sst_manager: RwLock::new(SstManager::new()), + wal_append_tx: WalAppendTx::new(&tx_log_store, sync_id), + compactor: Compactor::new(), + tx_log_store, + listener_factory, + master_sync_id: MasterSyncId::new(sync_id_store, sync_id)?, + }) + } + + pub fn recover( + tx_log_store: Arc<TxLogStore<D>>, + listener_factory: Arc<dyn TxEventListenerFactory<K, V>>, + on_drop_record_in_memtable: Option<Arc<OnDropRecodeFn<K, V>>>, + sync_id_store: Option<Arc<dyn SyncIdStore>>, + ) -> Result<Self> { + let (synced_records, wal_sync_id) = Self::recover_from_wal(&tx_log_store)?; + let (sst_manager, ssts_sync_id) = Self::recover_sst_manager(&tx_log_store)?; + + let max_sync_id = wal_sync_id.max(ssts_sync_id); + let master_sync_id = MasterSyncId::new(sync_id_store, max_sync_id)?; + let sync_id = master_sync_id.id(); + + let memtable_manager = Self::recover_memtable_manager( + sync_id, + synced_records.into_iter(), + on_drop_record_in_memtable, + ); + + let recov_self = Self { + memtable_manager, + sst_manager: RwLock::new(sst_manager), + wal_append_tx: WalAppendTx::new(&tx_log_store, sync_id), + compactor: Compactor::new(), + tx_log_store, + listener_factory, + master_sync_id, + }; + + recov_self.do_migration_tx()?; + + debug!("[SwornDisk TxLsmTree] Recovery completed: {recov_self:?}"); + Ok(recov_self) + } + + /// Recover the synced records and the maximum sync ID from the latest WAL. + fn recover_from_wal(tx_log_store: &Arc<TxLogStore<D>>) -> Result<(Vec<(K, V)>, SyncId)> { + let tx = tx_log_store.new_tx(); + let res: Result<_> = tx.context(|| { + let wal_res = tx_log_store.open_log_in(BUCKET_WAL); + if let Err(e) = &wal_res + && e.errno() == NotFound + { + return Ok((vec![], 0)); + } + let wal = wal_res?; + // Only synced records count, all unsynced are discarded + WalAppendTx::collect_synced_records_and_sync_id::<K, V>(&wal) + }); + if res.is_ok() { + tx.commit()?; + } else { + tx.abort(); + return_errno_with_msg!(TxAborted, "recover from WAL failed"); + } + res + } + + /// Recover `MemTable` from the given synced records. + fn recover_memtable_manager( + sync_id: SyncId, + synced_records: impl Iterator<Item = (K, V)>, + on_drop_record_in_memtable: Option<Arc<OnDropRecodeFn<K, V>>>, + ) -> MemTableManager<K, V> { + let memtable_manager = + MemTableManager::new(sync_id, MEMTABLE_CAPACITY, on_drop_record_in_memtable); + synced_records.into_iter().for_each(|(k, v)| { + let _ = memtable_manager.put(k, v); + }); + memtable_manager + } + + /// Recover `SSTable`s from the given log store. 
+ /// Return the recovered `SstManager` and the maximum sync ID present. + fn recover_sst_manager( + tx_log_store: &Arc<TxLogStore<D>>, + ) -> Result<(SstManager<K, V>, SyncId)> { + let mut manager = SstManager::new(); + let mut max_sync_id: SyncId = 0; + let tx = tx_log_store.new_tx(); + let res: Result<_> = tx.context(|| { + for (level, bucket) in LsmLevel::iter() { + let log_ids = tx_log_store.list_logs_in(bucket); + if let Err(e) = &log_ids + && e.errno() == NotFound + { + continue; + } + + for id in log_ids? { + let log = tx_log_store.open_log(id, false)?; + let sst = SSTable::<K, V>::from_log(&log)?; + max_sync_id = max_sync_id.max(sst.sync_id()); + manager.insert(SSTable::from_log(&log)?, level); + } + } + Ok(()) + }); + if res.is_ok() { + tx.commit()?; + } else { + tx.abort(); + return_errno_with_msg!(TxAborted, "recover TxLsmTree failed"); + } + Ok((manager, max_sync_id)) + } + + pub fn get(&self, key: &K) -> Result<V> { + // 1. Search from MemTables + if let Some(value) = self.memtable_manager.get(key) { + return Ok(value); + } + + // 2. Search from SSTs (do Read TX) + let value = self.do_read_tx(key)?; + + Ok(value) + } + + pub fn get_range(&self, range_query_ctx: &mut RangeQueryCtx<K, V>) -> Result<()> { + let is_completed = self.memtable_manager.get_range(range_query_ctx); + if is_completed { + return Ok(()); + } + + self.do_read_range_tx(range_query_ctx)?; + + Ok(()) + } + + pub fn sync(&self) -> Result<()> { + let master_sync_id = self.master_sync_id.id() + 1; + + // Wait asynchronous compaction to finish + // TODO: Error handling for compaction: try twice or become read-only + self.compactor.wait_compaction()?; + + // TODO: Error handling for WAL: try twice or become read-only + self.wal_append_tx.sync(master_sync_id)?; + + self.memtable_manager.sync(master_sync_id); + + // TODO: Error handling: try twice or ignore + self.master_sync_id.increment()?; + Ok(()) + } + + /// Read TX. + fn do_read_tx(&self, key: &K) -> Result<V> { + let tx = self.tx_log_store.new_tx(); + + let read_res: Result<_> = tx.context(|| { + // Search each level from top to bottom (newer to older) + let sst_manager = self.sst_manager.read(); + + for (level, _bucket) in LsmLevel::iter() { + for (_id, sst) in sst_manager.list_level(level) { + if !sst.is_within_range(key) { + continue; + } + + if let Ok(target_value) = sst.access_point(key, &self.tx_log_store) { + return Ok(target_value); + } + } + } + + return_errno_with_msg!(NotFound, "target sst not found"); + }); + if read_res.as_ref().is_err_and(|e| e.errno() != NotFound) { + tx.abort(); + return_errno_with_msg!(TxAborted, "read TX failed") + } + + tx.commit()?; + + read_res + } + + /// Read Range TX. 
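+    /// Searches the SSTs level by level, from newer to older, until the range query context is completed.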
+ fn do_read_range_tx(&self, range_query_ctx: &mut RangeQueryCtx<K, V>) -> Result<()> { + debug_assert!(!range_query_ctx.is_completed()); + let tx = self.tx_log_store.new_tx(); + + let read_res: Result<_> = tx.context(|| { + // Search each level from top to bottom (newer to older) + let sst_manager = self.sst_manager.read(); + for (level, _bucket) in LsmLevel::iter() { + for (_id, sst) in sst_manager.list_level(level) { + if !sst.overlap_with(&range_query_ctx.range_uncompleted().unwrap()) { + continue; + } + + sst.access_range(range_query_ctx, &self.tx_log_store)?; + + if range_query_ctx.is_completed() { + return Ok(()); + } + } + } + + return_errno_with_msg!(NotFound, "target sst not found"); + }); + if read_res.as_ref().is_err_and(|e| e.errno() != NotFound) { + tx.abort(); + return_errno_with_msg!(TxAborted, "read TX failed") + } + + tx.commit()?; + + read_res + } + + /// Minor Compaction TX { to_level: LsmLevel::L0 }. + fn do_minor_compaction(&self, wal_id: TxLogId) -> Result<()> { + let mut tx = self.tx_log_store.new_tx(); + // Prepare TX listener + let tx_type = TxType::Compaction { + to_level: LsmLevel::L0, + }; + let event_listener = self.listener_factory.new_event_listener(tx_type); + event_listener.on_tx_begin(&mut tx).map_err(|_| { + tx.abort(); + Error::with_msg( + TxAborted, + "minor compaction TX callback 'on_tx_begin' failed", + ) + })?; + + let res: Result<_> = tx.context(|| { + let tx_log = self.tx_log_store.create_log(LsmLevel::L0.bucket())?; + + // Cook records in immutable MemTable into a new SST + let immutable_memtable = self.memtable_manager.immutable_memtable(); + let records_iter = immutable_memtable.iter(); + let sync_id = immutable_memtable.sync_id(); + + let sst = SSTable::build(records_iter, sync_id, &tx_log, Some(&event_listener))?; + self.tx_log_store.delete_log(wal_id)?; + Ok(sst) + }); + let new_sst = res.map_err(|_| { + tx.abort(); + Error::with_msg(TxAborted, "minor compaction TX failed") + })?; + + event_listener.on_tx_precommit(&mut tx).map_err(|_| { + tx.abort(); + Error::with_msg( + TxAborted, + "minor compaction TX callback 'on_tx_precommit' failed", + ) + })?; + + tx.commit()?; + event_listener.on_tx_commit(); + + self.sst_manager.write().insert(new_sst, LsmLevel::L0); + + debug!("[SwornDisk TxLsmTree] Minor Compaction completed: {self:?}"); + Ok(()) + } + + /// Major Compaction TX { to_level: LsmLevel::L1~LsmLevel::L5 }. 
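+    /// Picks the oldest SST of the upper level and merges it with all overlapping SSTs in
+    /// `to_level`; if none overlap, the SST is simply moved down a level.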
+ fn do_major_compaction(&self, to_level: LsmLevel) -> Result<()> { + let from_level = to_level.upper_level(); + let mut tx = self.tx_log_store.new_tx(); + + // Prepare TX listener + let tx_type = TxType::Compaction { to_level }; + let event_listener = self.listener_factory.new_event_listener(tx_type); + event_listener.on_tx_begin(&mut tx).map_err(|_| { + tx.abort(); + Error::with_msg( + TxAborted, + "major compaction TX callback 'on_tx_begin' failed", + ) + })?; + + let master_sync_id = self.master_sync_id.id(); + let tx_log_store = self.tx_log_store.clone(); + let listener = event_listener.clone(); + let res: Result<_> = tx.context(move || { + let (mut created_ssts, mut deleted_ssts) = (vec![], vec![]); + + // Collect overlapped SSTs + let sst_manager = self.sst_manager.read(); + let (upper_sst_id, upper_sst) = sst_manager + .list_level(from_level) + .last() // Choose the oldest SST from upper level + .map(|(id, sst)| (*id, sst.clone())) + .unwrap(); + let lower_ssts: Vec<(TxLogId, Arc<SSTable<K, V>>)> = { + let mut ssts = sst_manager + .find_overlapped_ssts(&upper_sst.range(), to_level) + .map(|(id, sst)| (*id, sst.clone())) + .collect::<Vec<_>>(); + ssts.sort_by_key(|(_, sst)| *sst.range().start()); + ssts + }; + drop(sst_manager); + + // If there are no overlapped SSTs, just move the upper SST to the lower level + if lower_ssts.is_empty() { + tx_log_store.move_log(upper_sst_id, from_level.bucket(), to_level.bucket())?; + self.sst_manager + .write() + .move_sst(upper_sst_id, from_level, to_level); + return Ok((created_ssts, deleted_ssts)); + } + + let upper_records_iter = + upper_sst.iter(master_sync_id, false, &tx_log_store, Some(&listener)); + + let lower_records_iter = lower_ssts.iter().flat_map(|(_, sst)| { + sst.iter(master_sync_id, false, &tx_log_store, Some(&listener)) + }); + + // Compact records then build new SSTs + created_ssts = Compactor::compact_records_and_build_ssts( + upper_records_iter, + lower_records_iter, + &tx_log_store, + &listener, + to_level, + master_sync_id, + )?; + + // Delete the old SSTs + for (id, level) in core::iter::once((upper_sst_id, from_level)) + .chain(lower_ssts.into_iter().map(|(id, _)| (id, to_level))) + { + tx_log_store.delete_log(id)?; + deleted_ssts.push((id, level)); + } + Ok((created_ssts, deleted_ssts)) + }); + let (created_ssts, deleted_ssts) = res.map_err(|_| { + tx.abort(); + Error::with_msg(TxAborted, "major compaction TX failed") + })?; + + event_listener.on_tx_precommit(&mut tx).map_err(|_| { + tx.abort(); + Error::with_msg( + TxAborted, + "major compaction TX callback 'on_tx_precommit' failed", + ) + })?; + + tx.commit()?; + event_listener.on_tx_commit(); + + self.update_sst_manager( + created_ssts.into_iter().map(|sst| (sst, to_level)), + deleted_ssts.into_iter(), + ); + + debug!("[SwornDisk TxLsmTree] Major Compaction completed: {self:?}"); + + // Continue to do major compaction if necessary + if self.sst_manager.read().require_major_compaction(to_level) { + self.do_major_compaction(to_level.lower_level())?; + } + Ok(()) + } + + /// Migration TX, primarily to discard all unsynced records in SSTs. 
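+    /// Only SSTs whose sync ID equals the master sync ID may contain unsynced records; each such
+    /// SST is rebuilt from its synced records, or deleted outright if none remain.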
+ fn do_migration_tx(&self) -> Result<()> { + let mut tx = self.tx_log_store.new_tx(); + + // Prepare TX listener + let tx_type = TxType::Migration; + let event_listener = self.listener_factory.new_event_listener(tx_type); + event_listener.on_tx_begin(&mut tx).map_err(|_| { + tx.abort(); + Error::with_msg(TxAborted, "migration TX callback 'on_tx_begin' failed") + })?; + + let master_sync_id = self.master_sync_id.id(); + let tx_log_store = self.tx_log_store.clone(); + let listener = event_listener.clone(); + let res: Result<_> = tx.context(move || { + let (mut created_ssts, mut deleted_ssts) = (vec![], vec![]); + + let sst_manager = self.sst_manager.read(); + for (level, bucket) in LsmLevel::iter() { + let ssts = sst_manager.list_level(level); + // Iterate SSTs whose sync ID is equal to the + // master sync ID, who may have unsynced records + for (&id, sst) in ssts.filter(|(_, sst)| sst.sync_id() == master_sync_id) { + // Collect synced records only + let mut synced_records_iter = sst + .iter(master_sync_id, true, &tx_log_store, Some(&listener)) + .peekable(); + + if synced_records_iter.peek().is_some() { + // Create new migrated SST + let new_log = tx_log_store.create_log(bucket)?; + let new_sst = + SSTable::build(synced_records_iter, master_sync_id, &new_log, None)?; + created_ssts.push((new_sst, level)); + continue; + } + + // Delete the old SST + tx_log_store.delete_log(id)?; + deleted_ssts.push((id, level)); + } + } + + Ok((created_ssts, deleted_ssts)) + }); + let (created_ssts, deleted_ssts) = res.map_err(|_| { + tx.abort(); + Error::with_msg(TxAborted, "migration TX failed") + })?; + + event_listener.on_tx_precommit(&mut tx).map_err(|_| { + tx.abort(); + Error::with_msg(TxAborted, "migration TX callback 'on_tx_precommit' failed") + })?; + + tx.commit()?; + event_listener.on_tx_commit(); + + self.update_sst_manager(created_ssts.into_iter(), deleted_ssts.into_iter()); + Ok(()) + } + + fn update_sst_manager( + &self, + created: impl Iterator<Item = (SSTable<K, V>, LsmLevel)>, + deleted: impl Iterator<Item = (TxLogId, LsmLevel)>, + ) { + let mut sst_manager = self.sst_manager.write(); + created.for_each(|(sst, level)| { + let _ = sst_manager.insert(sst, level); + }); + deleted.for_each(|(id, level)| { + let _ = sst_manager.remove(id, level); + }); + } +} + +impl MasterSyncId { + /// Create a new instance of `MasterSyncId`. + /// Load the master sync ID from the given store if present. + /// If the store is not present, use the default sync ID instead. + pub fn new(store: Option<Arc<dyn SyncIdStore>>, default: SyncId) -> Result<Self> { + let id: SyncId = if let Some(store) = &store { + store.read()? + } else { + default + }; + Ok(Self { + id: AtomicU64::new(id), + store, + }) + } + + /// Get the current master sync ID. + pub fn id(&self) -> SyncId { + self.id.load(Ordering::Acquire) + } + + /// Increment the current master sync ID, + /// store the new ID to the store if present. + /// + /// On success, return the new master sync ID. + pub(super) fn increment(&self) -> Result<SyncId> { + let incremented_id = self.id.fetch_add(1, Ordering::Release) + 1; + if let Some(store) = &self.store { + store.write(incremented_id)?; + } + Ok(incremented_id) + } +} + +impl<K: RecordKey<K>, V, D> Drop for TreeInner<K, V, D> { + fn drop(&mut self) { + // TODO: Should we commit the WAL TX before dropping? 
+ // let _ = self.wal_append_tx.commit(); + } +} + +impl<K: RecordKey<K>, V: RecordValue, D: BlockSet + 'static> Debug for TreeInner<K, V, D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TxLsmTree") + .field("memtable_manager", &self.memtable_manager) + .field("sst_manager", &self.sst_manager.read()) + .field("tx_log_store", &self.tx_log_store) + .finish() + } +} + +impl<K: RecordKey<K>, V: RecordValue, D: BlockSet + 'static> Debug for TxLsmTree<K, V, D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl LsmLevel { + const LEVEL0_RATIO: u16 = 4; + const LEVELI_RATIO: u16 = 10; + + const MAX_NUM_LEVELS: usize = 6; + const LEVEL_BUCKETS: [(LsmLevel, &'static str); Self::MAX_NUM_LEVELS] = [ + (LsmLevel::L0, LsmLevel::L0.bucket()), + (LsmLevel::L1, LsmLevel::L1.bucket()), + (LsmLevel::L2, LsmLevel::L2.bucket()), + (LsmLevel::L3, LsmLevel::L3.bucket()), + (LsmLevel::L4, LsmLevel::L4.bucket()), + (LsmLevel::L5, LsmLevel::L5.bucket()), + ]; + + pub fn iter() -> impl Iterator<Item = (LsmLevel, &'static str)> { + Self::LEVEL_BUCKETS.iter().cloned() + } + + pub fn upper_level(&self) -> LsmLevel { + debug_assert!(*self != LsmLevel::L0); + LsmLevel::from(*self as u8 - 1) + } + + pub fn lower_level(&self) -> LsmLevel { + debug_assert!(*self != LsmLevel::L5); + LsmLevel::from(*self as u8 + 1) + } + + pub const fn bucket(&self) -> &str { + match self { + LsmLevel::L0 => "L0", + LsmLevel::L1 => "L1", + LsmLevel::L2 => "L2", + LsmLevel::L3 => "L3", + LsmLevel::L4 => "L4", + LsmLevel::L5 => "L5", + } + } +} + +impl From<u8> for LsmLevel { + fn from(value: u8) -> Self { + match value { + 0 => LsmLevel::L0, + 1 => LsmLevel::L1, + 2 => LsmLevel::L2, + 3 => LsmLevel::L3, + 4 => LsmLevel::L4, + 5 => LsmLevel::L5, + _ => unreachable!(), + } + } +} + +impl<K: RecordKey<K>, V: RecordValue> SstManager<K, V> { + pub fn new() -> Self { + let level_ssts = (0..LsmLevel::MAX_NUM_LEVELS) + .map(|_| BTreeMap::new()) + .collect(); + Self { level_ssts } + } + + /// List all SSTs of a given level from newer to older. + pub fn list_level( + &self, + level: LsmLevel, + ) -> impl Iterator<Item = (&TxLogId, &Arc<SSTable<K, V>>)> { + self.level_ssts[level as usize].iter().rev() + } + + pub fn insert(&mut self, sst: SSTable<K, V>, level: LsmLevel) -> Option<Arc<SSTable<K, V>>> { + let nth_level = level as usize; + debug_assert!(nth_level < self.level_ssts.len()); + let level_ssts = &mut self.level_ssts[nth_level]; + level_ssts.insert(sst.id(), Arc::new(sst)) + } + + pub fn remove(&mut self, id: TxLogId, level: LsmLevel) -> Option<Arc<SSTable<K, V>>> { + let level_ssts = &mut self.level_ssts[level as usize]; + level_ssts.remove(&id) + } + + pub fn move_sst(&mut self, id: TxLogId, from: LsmLevel, to: LsmLevel) { + let moved = self.level_ssts[from as usize].remove(&id).unwrap(); + let _ = self.level_ssts[to as usize].insert(id, moved); + } + + /// Find overlapping SSTs at a given level with a given range. + pub fn find_overlapped_ssts<'a>( + &'a self, + range: &'a RangeInclusive<K>, + level: LsmLevel, + ) -> impl Iterator<Item = (&'a TxLogId, &'a Arc<SSTable<K, V>>)> { + self.list_level(level) + .filter(|(_, sst)| sst.overlap_with(range)) + } + + /// Check whether a major compaction is required from `from_level` to its lower level. 
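+    /// L0 requires compaction once it holds `LEVEL0_RATIO` (4) SSTs; a deeper level `Li` (i >= 1)
+    /// requires it once it holds `LEVELI_RATIO`^i (10^i) SSTs.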
+ pub fn require_major_compaction(&self, from_level: LsmLevel) -> bool { + debug_assert!(from_level != LsmLevel::L5); + + if from_level == LsmLevel::L0 { + return self.level_ssts[from_level as usize].len() >= LsmLevel::LEVEL0_RATIO as _; + } + self.level_ssts[from_level as usize].len() + >= LsmLevel::LEVELI_RATIO.pow(from_level as _) as _ + } +} + +impl<K: Send + Sync, V: Send + Sync> AsKV<K, V> for (K, V) { + fn key(&self) -> &K { + &self.0 + } + fn value(&self) -> &V { + &self.1 + } +} + +impl<K: Send + Sync, V: Send + Sync> AsKV<K, V> for (&K, &V) { + fn key(&self) -> &K { + self.0 + } + fn value(&self) -> &V { + self.1 + } +} + +impl<K: Send + Sync, V: Send + Sync> AsKVex<K, V> for (K, ValueEx<V>) { + fn key(&self) -> &K { + &self.0 + } + fn value_ex(&self) -> &ValueEx<V> { + &self.1 + } +} + +impl<K: Send + Sync, V: Send + Sync> AsKVex<K, V> for (&K, &ValueEx<V>) { + fn key(&self) -> &K { + self.0 + } + fn value_ex(&self) -> &ValueEx<V> { + self.1 + } +} + +#[cfg(test)] +mod tests { + use super::{super::RangeQueryCtx, *}; + use crate::{ + layers::bio::{Buf, MemDisk}, + os::{AeadKey as Key, AeadMac as Mac}, + }; + + struct Factory; + struct Listener; + + impl<K, V> TxEventListenerFactory<K, V> for Factory { + fn new_event_listener(&self, _tx_type: TxType) -> Arc<dyn TxEventListener<K, V>> { + Arc::new(Listener) + } + } + impl<K, V> TxEventListener<K, V> for Listener { + fn on_add_record(&self, _record: &dyn AsKV<K, V>) -> Result<()> { + Ok(()) + } + fn on_drop_record(&self, _record: &dyn AsKV<K, V>) -> Result<()> { + Ok(()) + } + fn on_tx_begin(&self, _tx: &mut Tx) -> Result<()> { + Ok(()) + } + fn on_tx_precommit(&self, _tx: &mut Tx) -> Result<()> { + Ok(()) + } + fn on_tx_commit(&self) {} + } + + #[repr(C)] + #[derive(Copy, Clone, Pod, Debug)] + struct Value { + pub hba: BlockId, + pub key: Key, + pub mac: Mac, + } + + impl RecordKey<BlockId> for BlockId {} + impl RecordValue for Value {} + + #[test] + fn tx_lsm_tree_fns() -> Result<()> { + let nblocks = 102400; + let mem_disk = MemDisk::create(nblocks)?; + let tx_log_store = Arc::new(TxLogStore::format(mem_disk, Key::random())?); + let tx_lsm_tree: TxLsmTree<BlockId, Value, MemDisk> = + TxLsmTree::format(tx_log_store.clone(), Arc::new(Factory), None, None)?; + + // Put sufficient records which can trigger compaction before a sync command + let cap = MEMTABLE_CAPACITY; + let start = 0; + for i in start..start + cap { + let (k, v) = ( + i as BlockId, + Value { + hba: i as BlockId, + key: Key::random(), + mac: Mac::random(), + }, + ); + tx_lsm_tree.put(k, v)?; + } + let target_value = tx_lsm_tree.get(&5).unwrap(); + assert_eq!(target_value.hba, 5); + + tx_lsm_tree.sync()?; + + let target_value = tx_lsm_tree.get(&500).unwrap(); + assert_eq!(target_value.hba, 500); + + // Put sufficient records which can trigger compaction after a sync command + let start = 500; + for i in start..start + cap { + let (k, v) = ( + i as BlockId, + Value { + hba: (i * 2) as BlockId, + key: Key::random(), + mac: Mac::random(), + }, + ); + tx_lsm_tree.put(k, v)?; + } + + let target_value = tx_lsm_tree.get(&500).unwrap(); + assert_eq!(target_value.hba, 1000); + let target_value = tx_lsm_tree.get(&25).unwrap(); + assert_eq!(target_value.hba, 25); + + // Recover the `TxLsmTree`, all unsynced records should be discarded + drop(tx_lsm_tree); + let tx_lsm_tree: TxLsmTree<BlockId, Value, MemDisk> = + TxLsmTree::recover(tx_log_store.clone(), Arc::new(Factory), None, None)?; + + assert!(tx_lsm_tree.get(&(600 + cap)).is_err()); + + let cnt = 16; + let mut 
range_query_ctx = RangeQueryCtx::new(500, cnt);
+        tx_lsm_tree.get_range(&mut range_query_ctx).unwrap();
+        let res = range_query_ctx.into_results();
+        assert_eq!(res[0].1.hba, 500);
+        assert_eq!(res[cnt - 1].1.hba, 500 + cnt - 1);
+        Ok(())
+    }
+}
diff --git a/kernel/comps/mlsdisk/src/layers/4-lsm/wal.rs b/kernel/comps/mlsdisk/src/layers/4-lsm/wal.rs
new file mode 100644
index 00000000..ba24087b
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/layers/4-lsm/wal.rs
@@ -0,0 +1,279 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Transactions in the Write-Ahead Log (WAL).
+use alloc::vec;
+use core::{fmt::Debug, mem::size_of};
+
+use ostd_pod::Pod;
+
+use super::{AsKV, SyncId};
+use crate::{
+    layers::{
+        bio::{BlockId, BlockSet, Buf, BufRef},
+        log::{TxLog, TxLogId, TxLogStore},
+    },
+    os::Mutex,
+    prelude::*,
+    tx::CurrentTx,
+};
+
+/// The bucket name of WAL.
+pub(super) const BUCKET_WAL: &str = "WAL";
+
+/// WAL append TX in `TxLsmTree`.
+///
+/// A `WalAppendTx` is used to append records, sync and discard WALs.
+/// A WAL stores and manages the key-value records that are going to
+/// be put into the `MemTable`. Its space is backed by a `TxLog` (L3).
+#[derive(Clone)]
+pub(super) struct WalAppendTx<D> {
+    inner: Arc<Mutex<WalTxInner<D>>>,
+}
+
+struct WalTxInner<D> {
+    /// The appended WAL of ongoing Tx.
+    appended_log: Option<Arc<TxLog<D>>>,
+    /// Current log ID of WAL for later use.
+    log_id: Option<TxLogId>,
+    /// Store current sync ID as the first record of WAL.
+    sync_id: SyncId,
+    /// A buffer to cache appended records.
+    record_buf: Vec<u8>,
+    /// Store for WALs.
+    tx_log_store: Arc<TxLogStore<D>>,
+}
+
+impl<D: BlockSet + 'static> WalAppendTx<D> {
+    const BUF_CAP: usize = 1024 * BLOCK_SIZE;
+
+    /// Prepare a new WAL TX.
+    pub fn new(store: &Arc<TxLogStore<D>>, sync_id: SyncId) -> Self {
+        Self {
+            inner: Arc::new(Mutex::new(WalTxInner {
+                appended_log: None,
+                log_id: None,
+                sync_id,
+                record_buf: Vec::with_capacity(Self::BUF_CAP),
+                tx_log_store: store.clone(),
+            })),
+        }
+    }
+
+    /// Append phase for an Append TX, mainly to append newly added records to the WAL.
+    pub fn append<K: Pod, V: Pod>(&self, record: &dyn AsKV<K, V>) -> Result<()> {
+        let mut inner = self.inner.lock();
+        if inner.appended_log.is_none() {
+            inner.prepare()?;
+        }
+
+        {
+            let record_buf = &mut inner.record_buf;
+            record_buf.push(WalAppendFlag::Record as u8);
+            record_buf.extend_from_slice(record.key().as_bytes());
+            record_buf.extend_from_slice(record.value().as_bytes());
+        }
+
+        const MAX_RECORD_SIZE: usize = 49;
+        if inner.record_buf.len() <= Self::BUF_CAP - MAX_RECORD_SIZE {
+            return Ok(());
+        }
+
+        inner.align_record_buf();
+        let wal_tx = inner.tx_log_store.current_tx();
+        let wal_log = inner.appended_log.as_ref().unwrap();
+        self.flush_buf(&inner.record_buf, &wal_tx, wal_log)?;
+        inner.record_buf.clear();
+
+        Ok(())
+    }
+
+    /// Commit phase for an Append TX, mainly to commit (or abort) the TX.
+    /// After committing, the WAL is sealed. Return the corresponding log ID.
+    ///
+    /// # Panics
+    ///
+    /// This method panics if the current WAL's TX does not exist.
+ pub fn commit(&self) -> Result<TxLogId> { + let mut inner = self.inner.lock(); + let wal_log = inner + .appended_log + .take() + .expect("current WAL TX must exist"); + let wal_id = inner.log_id.take().unwrap(); + debug_assert_eq!(wal_id, wal_log.id()); + + if !inner.record_buf.is_empty() { + inner.align_record_buf(); + let wal_tx = inner.tx_log_store.current_tx(); + self.flush_buf(&inner.record_buf, &wal_tx, &wal_log)?; + inner.record_buf.clear(); + } + + drop(wal_log); + inner.tx_log_store.current_tx().commit()?; + Ok(wal_id) + } + + /// Appends current sync ID to WAL then commit the TX to ensure WAL's persistency. + /// Save the log ID for later appending. + pub fn sync(&self, sync_id: SyncId) -> Result<()> { + let mut inner = self.inner.lock(); + if inner.appended_log.is_none() { + inner.prepare()?; + } + inner.record_buf.push(WalAppendFlag::Sync as u8); + inner.record_buf.extend_from_slice(&sync_id.to_le_bytes()); + inner.sync_id = sync_id; + + inner.align_record_buf(); + let wal_log = inner.appended_log.take().unwrap(); + self.flush_buf( + &inner.record_buf, + &inner.tx_log_store.current_tx(), + &wal_log, + )?; + inner.record_buf.clear(); + + drop(wal_log); + inner.tx_log_store.current_tx().commit() + } + + /// Flushes the buffer to the backed log. + fn flush_buf( + &self, + record_buf: &[u8], + wal_tx: &CurrentTx<'_>, + log: &Arc<TxLog<D>>, + ) -> Result<()> { + debug_assert!(!record_buf.is_empty() && record_buf.len() % BLOCK_SIZE == 0); + let res = wal_tx.context(|| { + let buf = BufRef::try_from(record_buf).unwrap(); + log.append(buf) + }); + if res.is_err() { + wal_tx.abort(); + } + res + } + + /// Collects the synced records only and the maximum sync ID in the WAL. + pub fn collect_synced_records_and_sync_id<K: Pod, V: Pod>( + wal: &TxLog<D>, + ) -> Result<(Vec<(K, V)>, SyncId)> { + let nblocks = wal.nblocks(); + let mut records = Vec::new(); + + // TODO: Allocate separate buffers for large WAL + let mut buf = Buf::alloc(nblocks)?; + wal.read(0 as BlockId, buf.as_mut())?; + let buf_slice = buf.as_slice(); + + let k_size = size_of::<K>(); + let v_size = size_of::<V>(); + let total_bytes = nblocks * BLOCK_SIZE; + let mut offset = 0; + let (mut max_sync_id, mut synced_len) = (None, 0); + loop { + const MIN_RECORD_SIZE: usize = 9; + if offset > total_bytes - MIN_RECORD_SIZE { + break; + } + + let flag = WalAppendFlag::try_from(buf_slice[offset]); + offset += 1; + if flag.is_err() { + continue; + } + + match flag.unwrap() { + WalAppendFlag::Record => { + let record = { + let k = K::from_bytes(&buf_slice[offset..offset + k_size]); + let v = + V::from_bytes(&buf_slice[offset + k_size..offset + k_size + v_size]); + offset += k_size + v_size; + (k, v) + }; + + records.push(record); + } + WalAppendFlag::Sync => { + let sync_id = SyncId::from_le_bytes( + buf_slice[offset..offset + size_of::<SyncId>()] + .try_into() + .unwrap(), + ); + offset += size_of::<SyncId>(); + + let _ = max_sync_id.insert(sync_id); + synced_len = records.len(); + } + } + } + + if let Some(max_sync_id) = max_sync_id { + records.truncate(synced_len); + Ok((records, max_sync_id)) + } else { + Ok((vec![], 0)) + } + } +} + +impl<D: BlockSet + 'static> WalTxInner<D> { + /// Prepare phase for an Append TX, mainly to create new TX and WAL. 
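+    /// Reopens the WAL whose log ID was saved earlier, or creates a new one, then records
+    /// the current sync ID as the first entry of the record buffer.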
+ pub fn prepare(&mut self) -> Result<()> { + debug_assert!(self.appended_log.is_none()); + let appended_log = { + let store = &self.tx_log_store; + let wal_tx = store.new_tx(); + let log_id_opt = self.log_id; + let res = wal_tx.context(|| { + if let Some(log_id) = log_id_opt { + store.open_log(log_id, true) + } else { + store.create_log(BUCKET_WAL) + } + }); + if res.is_err() { + wal_tx.abort(); + } + let wal_log = res?; + let _ = self.log_id.insert(wal_log.id()); + wal_log + }; + let _ = self.appended_log.insert(appended_log); + + // Record the sync ID at the beginning of the WAL + debug_assert!(self.record_buf.is_empty()); + self.record_buf.push(WalAppendFlag::Sync as u8); + self.record_buf + .extend_from_slice(&self.sync_id.to_le_bytes()); + Ok(()) + } + + fn align_record_buf(&mut self) { + let aligned_len = align_up(self.record_buf.len(), BLOCK_SIZE); + self.record_buf.resize(aligned_len, 0); + } +} + +/// Two content kinds in a WAL. +#[derive(PartialEq, Eq, Debug)] +#[repr(u8)] +enum WalAppendFlag { + Record = 13, + Sync = 23, +} + +impl TryFrom<u8> for WalAppendFlag { + type Error = Error; + + fn try_from(value: u8) -> Result<Self> { + match value { + 13 => Ok(WalAppendFlag::Record), + 23 => Ok(WalAppendFlag::Sync), + _ => Err(Error::new(InvalidArgs)), + } + } +} diff --git a/kernel/comps/mlsdisk/src/layers/5-disk/bio.rs b/kernel/comps/mlsdisk/src/layers/5-disk/bio.rs new file mode 100644 index 00000000..70589f04 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/5-disk/bio.rs @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Block I/O (BIO). +use alloc::collections::VecDeque; +use core::{ + any::{Any, TypeId}, + ptr::NonNull, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use hashbrown::HashMap; + +use crate::{ + os::{Mutex, MutexGuard}, + prelude::*, + Buf, +}; + +/// A queue for managing block I/O requests (`BioReq`). +/// It provides a concurrency-safe way to store and manage +/// block I/O requests that need to be processed by a block device. +pub struct BioReqQueue { + queue: Mutex<VecDeque<BioReq>>, + num_reqs: AtomicUsize, +} + +impl BioReqQueue { + /// Create a new `BioReqQueue` instance. + pub fn new() -> Self { + Self { + queue: Mutex::new(VecDeque::new()), + num_reqs: AtomicUsize::new(0), + } + } + + /// Enqueue a block I/O request. + pub fn enqueue(&self, req: BioReq) -> Result<()> { + req.submit(); + self.queue.lock().push_back(req); + self.num_reqs.fetch_add(1, Ordering::Release); + Ok(()) + } + + /// Dequeue a block I/O request. + pub fn dequeue(&self) -> Option<BioReq> { + if let Some(req) = self.queue.lock().pop_front() { + self.num_reqs.fetch_sub(1, Ordering::Release); + Some(req) + } else { + debug_assert_eq!(self.num_reqs.load(Ordering::Acquire), 0); + None + } + } + + /// Returns the number of pending requests in this queue. + pub fn num_reqs(&self) -> usize { + self.num_reqs.load(Ordering::Acquire) + } + + /// Returns whether there are no pending requests in this queue. + pub fn is_empty(&self) -> bool { + self.num_reqs() == 0 + } +} + +/// A block I/O request. +pub struct BioReq { + type_: BioType, + addr: BlockId, + nblocks: u32, + bufs: Mutex<Vec<Buf>>, + status: Mutex<BioStatus>, + on_complete: Option<BioReqOnCompleteFn>, + ext: Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>, +} + +/// The type of a block request. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BioType { + /// A read request. + Read, + /// A write request. + Write, + /// A sync request. + Sync, +} + +/// A response from a block device. 
+pub type BioResp = Result<()>; + +/// The type of the callback function invoked upon the completion of +/// a block I/O request. +pub type BioReqOnCompleteFn = fn(/* req = */ &BioReq, /* resp = */ &BioResp); + +/// The status describing a block I/O request. +#[derive(Clone, Debug)] +enum BioStatus { + Init, + Submitted, + Completed(BioResp), +} + +impl BioReq { + /// Returns the type of the request. + pub fn type_(&self) -> BioType { + self.type_ + } + + /// Returns the starting address of requested blocks. + /// + /// The return value is meaningless if the request is not a read or write. + pub fn addr(&self) -> BlockId { + self.addr + } + + /// Access the immutable buffers with a closure. + pub fn access_bufs_with<F, R>(&self, mut f: F) -> R + where + F: FnMut(&[Buf]) -> R, + { + let bufs = self.bufs.lock(); + (f)(&bufs) + } + + /// Access the mutable buffers with a closure. + pub(super) fn access_mut_bufs_with<F, R>(&self, mut f: F) -> R + where + F: FnMut(&mut [Buf]) -> R, + { + let mut bufs = self.bufs.lock(); + (f)(&mut bufs) + } + + /// Take the buffers out of the request. + pub(super) fn take_bufs(&self) -> Vec<Buf> { + let mut bufs = self.bufs.lock(); + let mut ret_bufs = Vec::new(); + core::mem::swap(&mut *bufs, &mut ret_bufs); + ret_bufs + } + + /// Returns the number of buffers associated with the request. + /// + /// If the request is a flush, then the returned value is meaningless. + pub fn nbufs(&self) -> usize { + self.bufs.lock().len() + } + + /// Returns the number of blocks to read or write by this request. + /// + /// If the request is a flush, then the returned value is meaningless. + pub fn nblocks(&self) -> usize { + self.nblocks as usize + } + + /// Returns the extensions of the request. + /// + /// The extensions of a request is a set of objects that may be added, removed, + /// or accessed by block devices and their users. Each of the extension objects + /// must have a different type. To avoid conflicts, it is recommended to use only + /// private types for the extension objects. + pub fn ext(&self) -> MutexGuard<HashMap<TypeId, Box<dyn Any + Send + Sync>>> { + self.ext.lock() + } + + /// Update the status of the request to "completed" by giving the response + /// to the request. + /// + /// After the invoking this API, the request is considered completed, which + /// means the request must have taken effect. For example, a completed read + /// request must have all its buffers filled with data. + /// + /// # Panics + /// + /// If the request has not been submitted yet, or has been completed already, + /// this method will panic. + pub(super) fn complete(&self, resp: BioResp) { + let mut status = self.status.lock(); + match *status { + BioStatus::Submitted => { + if let Some(on_complete) = self.on_complete { + (on_complete)(self, &resp); + } + + *status = BioStatus::Completed(resp); + } + _ => panic!("cannot complete before submitting or complete twice"), + } + } + + /// Mark the request as submitted. + pub(super) fn submit(&self) { + let mut status = self.status.lock(); + match *status { + BioStatus::Init => *status = BioStatus::Submitted, + _ => unreachable!(), + } + } +} + +/// A builder for `BioReq`. +pub struct BioReqBuilder { + type_: BioType, + addr: Option<BlockId>, + bufs: Option<Vec<Buf>>, + on_complete: Option<BioReqOnCompleteFn>, + ext: Option<HashMap<TypeId, Box<dyn Any + Send + Sync>>>, +} + +impl BioReqBuilder { + /// Creates a builder of a block request of the given type. 
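+    ///
+    /// An illustrative sketch of assembling a write request (not compiled as a doc test):
+    ///
+    /// ```ignore
+    /// let req = BioReqBuilder::new(BioType::Write)
+    ///     .addr(0 as BlockId)
+    ///     .bufs(vec![Buf::alloc(1)?])
+    ///     .on_complete(|_req, resp| debug_assert!(resp.is_ok()))
+    ///     .build();
+    /// ```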
+ pub fn new(type_: BioType) -> Self { + Self { + type_, + addr: None, + bufs: None, + on_complete: None, + ext: None, + } + } + + /// Specify the block address of the request. + pub fn addr(mut self, addr: BlockId) -> Self { + self.addr = Some(addr); + self + } + + /// Give the buffers of the request. + pub fn bufs(mut self, bufs: Vec<Buf>) -> Self { + self.bufs = Some(bufs); + self + } + + /// Specify a callback invoked when the request is complete. + pub fn on_complete(mut self, on_complete: BioReqOnCompleteFn) -> Self { + self.on_complete = Some(on_complete); + self + } + + /// Add an extension object to the request. + pub fn ext<T: Any + Send + Sync + Sized>(mut self, obj: T) -> Self { + if self.ext.is_none() { + self.ext = Some(HashMap::new()); + } + let _ = self + .ext + .as_mut() + .unwrap() + .insert(TypeId::of::<T>(), Box::new(obj)); + self + } + + /// Build the request. + pub fn build(mut self) -> BioReq { + let type_ = self.type_; + if type_ == BioType::Sync { + debug_assert!( + self.addr.is_none(), + "addr is only meaningful for a read or write", + ); + debug_assert!( + self.bufs.is_none(), + "bufs is only meaningful for a read or write", + ); + } + + let addr = self.addr.unwrap_or(0 as BlockId); + + let bufs = self.bufs.take().unwrap_or_default(); + let nblocks = { + let nbytes = bufs + .iter() + .map(|buf| buf.as_slice().len()) + .fold(0_usize, |sum, len| sum.saturating_add(len)); + (nbytes / BLOCK_SIZE) as u32 + }; + + let ext = self.ext.take().unwrap_or_default(); + let on_complete = self.on_complete.take(); + + BioReq { + type_, + addr, + nblocks, + bufs: Mutex::new(bufs), + status: Mutex::new(BioStatus::Init), + on_complete, + ext: Mutex::new(ext), + } + } +} diff --git a/kernel/comps/mlsdisk/src/layers/5-disk/block_alloc.rs b/kernel/comps/mlsdisk/src/layers/5-disk/block_alloc.rs new file mode 100644 index 00000000..a8b2ac2c --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/5-disk/block_alloc.rs @@ -0,0 +1,403 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Block allocation. +use alloc::vec; +use core::{ + mem::size_of, + num::NonZeroUsize, + sync::atomic::{AtomicBool, AtomicUsize, Ordering}, +}; + +use ostd_pod::Pod; +use serde::{Deserialize, Serialize}; + +use super::sworndisk::Hba; +use crate::{ + layers::{ + bio::{BlockSet, Buf, BufRef, BID_SIZE}, + log::{TxLog, TxLogStore}, + }, + os::{BTreeMap, Condvar, CvarMutex, Mutex}, + prelude::*, + util::BitMap, +}; + +/// The bucket name of block validity table. +const BUCKET_BLOCK_VALIDITY_TABLE: &str = "BVT"; +/// The bucket name of block alloc/dealloc log. +const BUCKET_BLOCK_ALLOC_LOG: &str = "BAL"; + +/// Block validity table. Global allocator for `SwornDisk`, +/// which manages validities of user data blocks. +pub(super) struct AllocTable { + bitmap: Mutex<BitMap>, + next_avail: AtomicUsize, + nblocks: NonZeroUsize, + is_dirty: AtomicBool, + cvar: Condvar, + num_free: CvarMutex<usize>, +} + +/// Per-TX block allocator in `SwornDisk`, recording validities +/// of user data blocks within each TX. All metadata will be stored in +/// `TxLog`s of bucket `BAL` during TX for durability and recovery purpose. +pub(super) struct BlockAlloc<D> { + alloc_table: Arc<AllocTable>, // Point to the global allocator + diff_table: Mutex<BTreeMap<Hba, AllocDiff>>, // Per-TX diffs of block validity + store: Arc<TxLogStore<D>>, // Store for diff log from L3 + diff_log: Mutex<Option<Arc<TxLog<D>>>>, // Opened diff log (currently not in-use) +} + +/// Incremental diff of block validity. 
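+/// On disk, each record of a `BAL` log is one `AllocDiff` byte followed by the affected `Hba`,
+/// i.e., `DIFF_RECORD_SIZE` bytes in total.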
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[repr(u8)] +enum AllocDiff { + Alloc = 3, + Dealloc = 7, + Invalid, +} +const DIFF_RECORD_SIZE: usize = size_of::<AllocDiff>() + size_of::<Hba>(); + +impl AllocTable { + /// Create a new `AllocTable` given the total number of blocks. + pub fn new(nblocks: NonZeroUsize) -> Self { + Self { + bitmap: Mutex::new(BitMap::repeat(true, nblocks.get())), + next_avail: AtomicUsize::new(0), + nblocks, + is_dirty: AtomicBool::new(false), + cvar: Condvar::new(), + num_free: CvarMutex::new(nblocks.get()), + } + } + + /// Allocate a free slot for a new block, returns `None` + /// if there are no free slots. + pub fn alloc(&self) -> Option<Hba> { + let mut bitmap = self.bitmap.lock(); + let next_avail = self.next_avail.load(Ordering::Acquire); + + let hba = if let Some(hba) = bitmap.first_one(next_avail) { + hba + } else { + bitmap.first_one(0)? + }; + bitmap.set(hba, false); + + self.next_avail.store(hba + 1, Ordering::Release); + Some(hba as Hba) + } + + /// Allocate multiple free slots for a bunch of new blocks, returns `None` + /// if there are no free slots for all. + pub fn alloc_batch(&self, count: NonZeroUsize) -> Result<Vec<Hba>> { + let cnt = count.get(); + let mut num_free = self.num_free.lock().unwrap(); + while *num_free < cnt { + // TODO: May not be woken, may require manual triggering of a compaction in L4 + num_free = self.cvar.wait(num_free).unwrap(); + } + debug_assert!(*num_free >= cnt); + + let hbas = self.do_alloc_batch(count).unwrap(); + debug_assert_eq!(hbas.len(), cnt); + + *num_free -= cnt; + let _ = self + .is_dirty + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed); + Ok(hbas) + } + + fn do_alloc_batch(&self, count: NonZeroUsize) -> Option<Vec<Hba>> { + let count = count.get(); + debug_assert!(count > 0); + let mut bitmap = self.bitmap.lock(); + let mut next_avail = self.next_avail.load(Ordering::Acquire); + + if next_avail + count > self.nblocks.get() { + next_avail = bitmap.first_one(0)?; + } + + let hbas = if let Some(hbas) = bitmap.first_ones(next_avail, count) { + hbas + } else { + next_avail = bitmap.first_one(0)?; + bitmap.first_ones(next_avail, count)? + }; + hbas.iter().for_each(|hba| bitmap.set(*hba, false)); + + next_avail = hbas.last().unwrap() + 1; + self.next_avail.store(next_avail, Ordering::Release); + Some(hbas) + } + + /// Recover the `AllocTable` from the latest `BVT` log and a bunch of `BAL` logs + /// in the given store. + pub fn recover<D: BlockSet + 'static>( + nblocks: NonZeroUsize, + store: &Arc<TxLogStore<D>>, + ) -> Result<Self> { + let tx = store.new_tx(); + let res: Result<_> = tx.context(|| { + // Recover the block validity table from `BVT` log first + let bvt_log_res = store.open_log_in(BUCKET_BLOCK_VALIDITY_TABLE); + let mut bitmap = match bvt_log_res { + Ok(bvt_log) => { + let mut buf = Buf::alloc(bvt_log.nblocks())?; + bvt_log.read(0 as BlockId, buf.as_mut())?; + postcard::from_bytes(buf.as_slice()).map_err(|_| { + Error::with_msg(InvalidArgs, "deserialize block validity table failed") + })? 
+ } + Err(e) => { + if e.errno() != NotFound { + return Err(e); + } + BitMap::repeat(true, nblocks.get()) + } + }; + + // Iterate each `BAL` log and apply each diff, from older to newer + let bal_log_ids_res = store.list_logs_in(BUCKET_BLOCK_ALLOC_LOG); + if let Err(e) = &bal_log_ids_res + && e.errno() == NotFound + { + let next_avail = bitmap.first_one(0).unwrap_or(0); + let num_free = bitmap.count_ones(); + return Ok(Self { + bitmap: Mutex::new(bitmap), + next_avail: AtomicUsize::new(next_avail), + nblocks, + is_dirty: AtomicBool::new(false), + cvar: Condvar::new(), + num_free: CvarMutex::new(num_free), + }); + } + let mut bal_log_ids = bal_log_ids_res?; + bal_log_ids.sort(); + + for bal_log_id in bal_log_ids { + let bal_log_res = store.open_log(bal_log_id, false); + if let Err(e) = &bal_log_res + && e.errno() == NotFound + { + continue; + } + let bal_log = bal_log_res?; + + let log_nblocks = bal_log.nblocks(); + let mut buf = Buf::alloc(log_nblocks)?; + bal_log.read(0 as BlockId, buf.as_mut())?; + let buf_slice = buf.as_slice(); + let mut offset = 0; + while offset <= log_nblocks * BLOCK_SIZE - DIFF_RECORD_SIZE { + let diff = AllocDiff::from(buf_slice[offset]); + offset += 1; + if diff == AllocDiff::Invalid { + continue; + } + let bid = BlockId::from_bytes(&buf_slice[offset..offset + BID_SIZE]); + offset += BID_SIZE; + match diff { + AllocDiff::Alloc => bitmap.set(bid, false), + AllocDiff::Dealloc => bitmap.set(bid, true), + _ => unreachable!(), + } + } + } + let next_avail = bitmap.first_one(0).unwrap_or(0); + let num_free = bitmap.count_ones(); + + Ok(Self { + bitmap: Mutex::new(bitmap), + next_avail: AtomicUsize::new(next_avail), + nblocks, + is_dirty: AtomicBool::new(false), + cvar: Condvar::new(), + num_free: CvarMutex::new(num_free), + }) + }); + let recov_self = res.map_err(|_| { + tx.abort(); + Error::with_msg(TxAborted, "recover block validity table TX aborted") + })?; + tx.commit()?; + + Ok(recov_self) + } + + /// Persist the block validity table to `BVT` log. GC all existed `BAL` logs. + pub fn do_compaction<D: BlockSet + 'static>(&self, store: &Arc<TxLogStore<D>>) -> Result<()> { + if !self.is_dirty.load(Ordering::Relaxed) { + return Ok(()); + } + + // Serialize the block validity table + let bitmap = self.bitmap.lock(); + const BITMAP_MAX_SIZE: usize = 1792 * BLOCK_SIZE; // TBD + let mut ser_buf = vec![0; BITMAP_MAX_SIZE]; + let ser_len = postcard::to_slice::<BitMap>(&bitmap, &mut ser_buf) + .map_err(|_| Error::with_msg(InvalidArgs, "serialize block validity table failed"))? + .len(); + ser_buf.resize(align_up(ser_len, BLOCK_SIZE), 0); + drop(bitmap); + + // Persist the serialized block validity table to `BVT` log + // and GC any old `BVT` logs and `BAL` logs + let tx = store.new_tx(); + let res: Result<_> = tx.context(|| { + if let Ok(bvt_log_ids) = store.list_logs_in(BUCKET_BLOCK_VALIDITY_TABLE) { + for bvt_log_id in bvt_log_ids { + store.delete_log(bvt_log_id)?; + } + } + + let bvt_log = store.create_log(BUCKET_BLOCK_VALIDITY_TABLE)?; + bvt_log.append(BufRef::try_from(&ser_buf[..]).unwrap())?; + + if let Ok(bal_log_ids) = store.list_logs_in(BUCKET_BLOCK_ALLOC_LOG) { + for bal_log_id in bal_log_ids { + store.delete_log(bal_log_id)?; + } + } + Ok(()) + }); + if res.is_err() { + tx.abort(); + return_errno_with_msg!(TxAborted, "persist block validity table TX aborted"); + } + tx.commit()?; + + self.is_dirty.store(false, Ordering::Relaxed); + Ok(()) + } + + /// Mark a specific slot deallocated. 
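+    /// Also notifies one allocator waiting in `alloc_batch` once enough slots have been freed.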
+ pub fn set_deallocated(&self, nth: usize) { + let mut num_free = self.num_free.lock().unwrap(); + self.bitmap.lock().set(nth, true); + + *num_free += 1; + const AVG_ALLOC_COUNT: usize = 1024; + if *num_free >= AVG_ALLOC_COUNT { + self.cvar.notify_one(); + } + } +} + +impl<D: BlockSet + 'static> BlockAlloc<D> { + /// Create a new `BlockAlloc` with the given global allocator and store. + pub fn new(alloc_table: Arc<AllocTable>, store: Arc<TxLogStore<D>>) -> Self { + Self { + alloc_table, + diff_table: Mutex::new(BTreeMap::new()), + store, + diff_log: Mutex::new(None), + } + } + + /// Record a diff of `Alloc`. + pub fn alloc_block(&self, block_id: Hba) -> Result<()> { + let mut diff_table = self.diff_table.lock(); + let replaced = diff_table.insert(block_id, AllocDiff::Alloc); + debug_assert!( + replaced != Some(AllocDiff::Alloc), + "can't allocate a block twice" + ); + Ok(()) + } + + /// Record a diff of `Dealloc`. + pub fn dealloc_block(&self, block_id: Hba) -> Result<()> { + let mut diff_table = self.diff_table.lock(); + let replaced = diff_table.insert(block_id, AllocDiff::Dealloc); + debug_assert!( + replaced != Some(AllocDiff::Dealloc), + "can't deallocate a block twice" + ); + Ok(()) + } + + /// Prepare the block validity diff log. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn prepare_diff_log(&self) -> Result<()> { + // Do nothing for now + Ok(()) + } + + /// Persist the metadata in diff table to the block validity diff log. + /// + /// # Panics + /// + /// This method must be called within a TX. Otherwise, this method panics. + pub fn update_diff_log(&self) -> Result<()> { + let diff_table = self.diff_table.lock(); + if diff_table.is_empty() { + return Ok(()); + } + + let diff_log = self.store.create_log(BUCKET_BLOCK_ALLOC_LOG)?; + + const MAX_BUF_SIZE: usize = 1024 * BLOCK_SIZE; + let mut diff_buf = Vec::with_capacity(MAX_BUF_SIZE); + for (block_id, block_diff) in diff_table.iter() { + diff_buf.push(*block_diff as u8); + diff_buf.extend_from_slice(block_id.as_bytes()); + + if diff_buf.len() + DIFF_RECORD_SIZE > MAX_BUF_SIZE { + diff_buf.resize(align_up(diff_buf.len(), BLOCK_SIZE), 0); + diff_log.append(BufRef::try_from(&diff_buf[..]).unwrap())?; + diff_buf.clear(); + } + } + + if diff_buf.is_empty() { + return Ok(()); + } + diff_buf.resize(align_up(diff_buf.len(), BLOCK_SIZE), 0); + diff_log.append(BufRef::try_from(&diff_buf[..]).unwrap()) + } + + /// Update the metadata in diff table to the in-memory block validity table. 
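+    /// Deallocated blocks are marked valid again in the bitmap, the free-slot counter is raised,
+    /// and a waiting allocator is notified when enough slots become free.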
+ pub fn update_alloc_table(&self) { + let diff_table = self.diff_table.lock(); + let alloc_table = &self.alloc_table; + let mut num_free = alloc_table.num_free.lock().unwrap(); + let mut bitmap = alloc_table.bitmap.lock(); + let mut num_dealloc = 0_usize; + + for (block_id, block_diff) in diff_table.iter() { + match block_diff { + AllocDiff::Alloc => { + debug_assert!(!bitmap[*block_id]); + } + AllocDiff::Dealloc => { + debug_assert!(!bitmap[*block_id]); + bitmap.set(*block_id, true); + num_dealloc += 1; + } + AllocDiff::Invalid => unreachable!(), + }; + } + + *num_free += num_dealloc; + const AVG_ALLOC_COUNT: usize = 1024; + if *num_free >= AVG_ALLOC_COUNT { + alloc_table.cvar.notify_one(); + } + } +} + +impl From<u8> for AllocDiff { + fn from(value: u8) -> Self { + match value { + 3 => AllocDiff::Alloc, + 7 => AllocDiff::Dealloc, + _ => AllocDiff::Invalid, + } + } +} diff --git a/kernel/comps/mlsdisk/src/layers/5-disk/data_buf.rs b/kernel/comps/mlsdisk/src/layers/5-disk/data_buf.rs new file mode 100644 index 00000000..35bfb11c --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/5-disk/data_buf.rs @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Data buffering. +use core::ops::RangeInclusive; + +use super::sworndisk::RecordKey; +use crate::{ + layers::bio::{BufMut, BufRef}, + os::{BTreeMap, Condvar, CvarMutex, Mutex}, + prelude::*, +}; + +/// A buffer to cache data blocks before they are written to disk. +#[derive(Debug)] +pub(super) struct DataBuf { + buf: Mutex<BTreeMap<RecordKey, Arc<DataBlock>>>, + cap: usize, + cvar: Condvar, + is_full: CvarMutex<bool>, +} + +/// User data block. +pub(super) struct DataBlock([u8; BLOCK_SIZE]); + +impl DataBuf { + /// Create a new empty data buffer with a given capacity. + pub fn new(cap: usize) -> Self { + Self { + buf: Mutex::new(BTreeMap::new()), + cap, + cvar: Condvar::new(), + is_full: CvarMutex::new(false), + } + } + + /// Get the buffered data block with the key and copy + /// the content into `buf`. + pub fn get(&self, key: RecordKey, buf: &mut BufMut) -> Option<()> { + debug_assert_eq!(buf.nblocks(), 1); + if let Some(block) = self.buf.lock().get(&key) { + buf.as_mut_slice().copy_from_slice(block.as_slice()); + Some(()) + } else { + None + } + } + + /// Get the buffered data blocks which keys are within the given range. + pub fn get_range(&self, range: RangeInclusive<RecordKey>) -> Vec<(RecordKey, Arc<DataBlock>)> { + self.buf + .lock() + .iter() + .filter_map(|(k, v)| { + if range.contains(k) { + Some((*k, v.clone())) + } else { + None + } + }) + .collect() + } + + /// Put the data block in `buf` into the buffer. Return + /// whether the buffer is full after insertion. + pub fn put(&self, key: RecordKey, buf: BufRef) -> bool { + debug_assert_eq!(buf.nblocks(), 1); + + let mut is_full = self.is_full.lock().unwrap(); + while *is_full { + is_full = self.cvar.wait(is_full).unwrap(); + } + debug_assert!(!*is_full); + + let mut data_buf = self.buf.lock(); + let _ = data_buf.insert(key, DataBlock::from_buf(buf)); + + if data_buf.len() >= self.cap { + *is_full = true; + } + *is_full + } + + /// Return the number of data blocks of the buffer. + pub fn nblocks(&self) -> usize { + self.buf.lock().len() + } + + /// Return whether the buffer is full. + pub fn at_capacity(&self) -> bool { + self.nblocks() >= self.cap + } + + /// Return whether the buffer is empty. + pub fn is_empty(&self) -> bool { + self.nblocks() == 0 + } + + /// Empty the buffer. 
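+    /// Also clears the full flag and wakes up every writer blocked in `put`.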
+    pub fn clear(&self) {
+        let mut is_full = self.is_full.lock().unwrap();
+        self.buf.lock().clear();
+        if *is_full {
+            *is_full = false;
+            self.cvar.notify_all();
+        }
+    }
+
+    /// Return all the buffered data blocks.
+    pub fn all_blocks(&self) -> Vec<(RecordKey, Arc<DataBlock>)> {
+        self.buf
+            .lock()
+            .iter()
+            .map(|(k, v)| (*k, v.clone()))
+            .collect()
+    }
+}
+
+impl DataBlock {
+    /// Create a new data block from the given `buf`.
+    pub fn from_buf(buf: BufRef) -> Arc<Self> {
+        debug_assert_eq!(buf.nblocks(), 1);
+        Arc::new(DataBlock(buf.as_slice().try_into().unwrap()))
+    }
+
+    /// Return the immutable slice of the data block.
+    pub fn as_slice(&self) -> &[u8] {
+        &self.0
+    }
+}
+
+impl Debug for DataBlock {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("DataBlock")
+            .field("first 16 bytes", &&self.0[..16])
+            .finish()
+    }
+}
diff --git a/kernel/comps/mlsdisk/src/layers/5-disk/mod.rs b/kernel/comps/mlsdisk/src/layers/5-disk/mod.rs
new file mode 100644
index 00000000..ff6852c2
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/layers/5-disk/mod.rs
@@ -0,0 +1,41 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! The layer of secure virtual disk.
+//!
+//! `SwornDisk` provides three block I/O interfaces, `read()`, `write()` and `sync()`.
+//! `SwornDisk` protects a logical block of user data using authenticated encryption.
+//! The metadata of the encrypted logical blocks is inserted into a secure index, the `TxLsmTree`.
+//!
+//! The untrusted host disk space backing `SwornDisk` is managed by `BlockAlloc`. Block reclamation
+//! can be delayed via user-defined callbacks on `TxLsmTree`.
+//! `SwornDisk` supports buffering written logical blocks.
+//!
+//! # Usage Example
+//!
+//! Write, sync, then read blocks from `SwornDisk`.
+//!
+//! ```
+//! let nblocks = 1024;
+//! let mem_disk = MemDisk::create(nblocks)?;
+//! let root_key = Key::random();
+//! let sworndisk = SwornDisk::create(mem_disk.clone(), root_key)?;
+//!
+//! let num_rw = 128;
+//! let mut rw_buf = Buf::alloc(1)?;
+//! for i in 0..num_rw {
+//!     rw_buf.as_mut_slice().fill(i as u8);
+//!     sworndisk.write(i as Lba, rw_buf.as_ref())?;
+//! }
+//! sworndisk.sync()?;
+//! for i in 0..num_rw {
+//!     sworndisk.read(i as Lba, rw_buf.as_mut())?;
+//!     assert_eq!(rw_buf.as_slice()[0], i as u8);
+//! }
+//! ```

+mod bio;
+mod block_alloc;
+mod data_buf;
+mod sworndisk;
+
+pub use self::sworndisk::SwornDisk;
diff --git a/kernel/comps/mlsdisk/src/layers/5-disk/sworndisk.rs b/kernel/comps/mlsdisk/src/layers/5-disk/sworndisk.rs
new file mode 100644
index 00000000..ac426131
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/layers/5-disk/sworndisk.rs
@@ -0,0 +1,881 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! SwornDisk as a block device.
+//!
+//! API: submit_bio(), submit_bio_sync(), create(), open(),
+//! read(), readv(), write(), writev(), sync().
+//!
+//! Responsible for managing a `TxLsmTree` that stores the metadata of logical blocks
+//! (whose TX logs, the WAL and SSTs, are kept in a `TxLogStore`), an untrusted disk that
+//! stores user data, and a `BlockAlloc` that manages the allocation metadata of data blocks.
+//! `TxLsmTree` and `BlockAlloc` are manipulated based on internal transactions.
+use core::{ + num::NonZeroUsize, + ops::{Add, Sub}, + sync::atomic::{AtomicBool, Ordering}, +}; + +use ostd::mm::VmIo; +use ostd_pod::Pod; + +use super::{ + bio::{BioReq, BioReqQueue, BioResp, BioType}, + block_alloc::{AllocTable, BlockAlloc}, + data_buf::DataBuf, +}; +use crate::{ + layers::{ + bio::{BlockId, BlockSet, Buf, BufMut, BufRef}, + log::TxLogStore, + lsm::{ + AsKV, LsmLevel, RangeQueryCtx, RecordKey as RecordK, RecordValue as RecordV, + SyncIdStore, TxEventListener, TxEventListenerFactory, TxLsmTree, TxType, + }, + }, + os::{Aead, AeadIv as Iv, AeadKey as Key, AeadMac as Mac, RwLock}, + prelude::*, + tx::CurrentTx, +}; + +/// Logical Block Address. +pub type Lba = BlockId; +/// Host Block Address. +pub type Hba = BlockId; + +/// SwornDisk. +pub struct SwornDisk<D: BlockSet> { + inner: Arc<DiskInner<D>>, +} + +/// Inner structures of `SwornDisk`. +struct DiskInner<D: BlockSet> { + /// Block I/O request queue. + bio_req_queue: BioReqQueue, + /// A `TxLsmTree` to store metadata of the logical blocks. + logical_block_table: TxLsmTree<RecordKey, RecordValue, D>, + /// The underlying disk where user data is stored. + user_data_disk: D, + /// Manage space of the data disk. + block_validity_table: Arc<AllocTable>, + /// TX log store for managing logs in `TxLsmTree` and block alloc logs. + tx_log_store: Arc<TxLogStore<D>>, + /// A buffer to cache data blocks. + data_buf: DataBuf, + /// Root encryption key. + root_key: Key, + /// Whether `SwornDisk` is dropped. + is_dropped: AtomicBool, + /// Scope lock for control write and sync operation. + write_sync_region: RwLock<()>, +} + +impl<D: BlockSet + 'static> aster_block::BlockDevice for SwornDisk<D> { + fn enqueue( + &self, + bio: aster_block::bio::SubmittedBio, + ) -> core::result::Result<(), aster_block::bio::BioEnqueueError> { + use aster_block::bio::{BioStatus, BioType, SubmittedBio}; + + if bio.type_() == BioType::Discard { + warn!("discard operation not supported"); + bio.complete(BioStatus::NotSupported); + return Ok(()); + } + + if bio.type_() == BioType::Flush { + let status = match self.sync() { + Ok(_) => BioStatus::Complete, + Err(_) => BioStatus::IoError, + }; + bio.complete(status); + return Ok(()); + } + + let start_offset = bio.sid_range().start.to_offset(); + let start_lba = start_offset / BLOCK_SIZE; + let end_offset = bio.sid_range().end.to_offset(); + let end_lba = end_offset.div_ceil(BLOCK_SIZE); + let nblocks = end_lba - start_lba; + let Ok(buf) = Buf::alloc(nblocks) else { + bio.complete(BioStatus::NoSpace); + return Ok(()); + }; + + let handle_read_bio = |mut buf: Buf| { + if self.read(start_lba, buf.as_mut()).is_err() { + return BioStatus::IoError; + } + + let mut base = start_offset % BLOCK_SIZE; + bio.segments().iter().for_each(|seg| { + let offset = seg.nbytes(); + let _ = seg.write_bytes(0, &buf.as_slice()[base..base + offset]); + base += offset; + }); + BioStatus::Complete + }; + + let handle_write_bio = |mut buf: Buf| { + let mut base = start_offset % BLOCK_SIZE; + // Read the first unaligned block. + if base != 0 { + let buf_mut = BufMut::try_from(&mut buf.as_mut_slice()[..BLOCK_SIZE]).unwrap(); + if self.read(start_lba, buf_mut).is_err() { + return BioStatus::IoError; + } + } + + // Read the last unaligned block. 
+ if end_offset % BLOCK_SIZE != 0 { + let offset = buf.as_slice().len() - BLOCK_SIZE; + let buf_mut = BufMut::try_from(&mut buf.as_mut_slice()[offset..]).unwrap(); + if self.read(end_lba - 1, buf_mut).is_err() { + return BioStatus::IoError; + } + } + + bio.segments().iter().for_each(|seg| { + let offset = seg.nbytes(); + let _ = seg.read_bytes(0, &mut buf.as_mut_slice()[base..base + offset]); + base += offset; + }); + + if self.write(start_lba, buf.as_ref()).is_err() { + return BioStatus::IoError; + } + BioStatus::Complete + }; + + let status = match bio.type_() { + BioType::Read => handle_read_bio(buf), + BioType::Write => handle_write_bio(buf), + _ => BioStatus::NotSupported, + }; + bio.complete(status); + Ok(()) + } + + fn metadata(&self) -> aster_block::BlockDeviceMeta { + use aster_block::{BlockDeviceMeta, BLOCK_SIZE, SECTOR_SIZE}; + + BlockDeviceMeta { + max_nr_segments_per_bio: usize::MAX, + nr_sectors: (BLOCK_SIZE / SECTOR_SIZE) * self.total_blocks(), + } + } +} + +impl<D: BlockSet + 'static> SwornDisk<D> { + /// Read a specified number of blocks at a logical block address on the device. + /// The block contents will be read into a single contiguous buffer. + pub fn read(&self, lba: Lba, buf: BufMut) -> Result<()> { + self.check_rw_args(lba, buf.nblocks())?; + self.inner.read(lba, buf) + } + + /// Read multiple blocks at a logical block address on the device. + /// The block contents will be read into several scattered buffers. + pub fn readv<'a>(&self, lba: Lba, bufs: &'a mut [BufMut<'a>]) -> Result<()> { + self.check_rw_args(lba, bufs.iter().fold(0, |acc, buf| acc + buf.nblocks()))?; + self.inner.readv(lba, bufs) + } + + /// Write a specified number of blocks at a logical block address on the device. + /// The block contents reside in a single contiguous buffer. + pub fn write(&self, lba: Lba, buf: BufRef) -> Result<()> { + self.check_rw_args(lba, buf.nblocks())?; + let _rguard = self.inner.write_sync_region.read(); + self.inner.write(lba, buf) + } + + /// Write multiple blocks at a logical block address on the device. + /// The block contents reside in several scattered buffers. + pub fn writev(&self, lba: Lba, bufs: &[BufRef]) -> Result<()> { + self.check_rw_args(lba, bufs.iter().fold(0, |acc, buf| acc + buf.nblocks()))?; + let _rguard = self.inner.write_sync_region.read(); + self.inner.writev(lba, bufs) + } + + /// Sync all cached data in the device to the storage medium for durability. + pub fn sync(&self) -> Result<()> { + let _wguard = self.inner.write_sync_region.write(); + // TODO: Error handling the sync operation + self.inner.sync().unwrap(); + + trace!("[SwornDisk] Sync completed. {self:?}"); + Ok(()) + } + + /// Returns the total number of blocks in the device. + pub fn total_blocks(&self) -> usize { + self.inner.user_data_disk.nblocks() + } + + /// Creates a new `SwornDisk` on the given disk, with the root encryption key. 
+ pub fn create( + disk: D, + root_key: Key, + sync_id_store: Option<Arc<dyn SyncIdStore>>, + ) -> Result<Self> { + let data_disk = Self::subdisk_for_data(&disk)?; + let lsm_tree_disk = Self::subdisk_for_logical_block_table(&disk)?; + + let tx_log_store = Arc::new(TxLogStore::format(lsm_tree_disk, root_key)?); + let block_validity_table = Arc::new(AllocTable::new( + NonZeroUsize::new(data_disk.nblocks()).unwrap(), + )); + let listener_factory = Arc::new(TxLsmTreeListenerFactory::new( + tx_log_store.clone(), + block_validity_table.clone(), + )); + + let logical_block_table = { + let table = block_validity_table.clone(); + let on_drop_record_in_memtable = move |record: &dyn AsKV<RecordKey, RecordValue>| { + // Deallocate the host block while the corresponding record is dropped in `MemTable` + table.set_deallocated(record.value().hba); + }; + TxLsmTree::format( + tx_log_store.clone(), + listener_factory, + Some(Arc::new(on_drop_record_in_memtable)), + sync_id_store, + )? + }; + + let new_self = Self { + inner: Arc::new(DiskInner { + bio_req_queue: BioReqQueue::new(), + logical_block_table, + user_data_disk: data_disk, + block_validity_table, + tx_log_store, + data_buf: DataBuf::new(DATA_BUF_CAP), + root_key, + is_dropped: AtomicBool::new(false), + write_sync_region: RwLock::new(()), + }), + }; + + info!("[SwornDisk] Created successfully! {:?}", &new_self); + // XXX: Would `disk::drop()` bring unexpected behavior? + Ok(new_self) + } + + /// Opens the `SwornDisk` on the given disk, with the root encryption key. + pub fn open( + disk: D, + root_key: Key, + sync_id_store: Option<Arc<dyn SyncIdStore>>, + ) -> Result<Self> { + let data_disk = Self::subdisk_for_data(&disk)?; + let lsm_tree_disk = Self::subdisk_for_logical_block_table(&disk)?; + + let tx_log_store = Arc::new(TxLogStore::recover(lsm_tree_disk, root_key)?); + let block_validity_table = Arc::new(AllocTable::recover( + NonZeroUsize::new(data_disk.nblocks()).unwrap(), + &tx_log_store, + )?); + let listener_factory = Arc::new(TxLsmTreeListenerFactory::new( + tx_log_store.clone(), + block_validity_table.clone(), + )); + + let logical_block_table = { + let table = block_validity_table.clone(); + let on_drop_record_in_memtable = move |record: &dyn AsKV<RecordKey, RecordValue>| { + // Deallocate the host block while the corresponding record is dropped in `MemTable` + table.set_deallocated(record.value().hba); + }; + TxLsmTree::recover( + tx_log_store.clone(), + listener_factory, + Some(Arc::new(on_drop_record_in_memtable)), + sync_id_store, + )? + }; + + let opened_self = Self { + inner: Arc::new(DiskInner { + bio_req_queue: BioReqQueue::new(), + logical_block_table, + user_data_disk: data_disk, + block_validity_table, + data_buf: DataBuf::new(DATA_BUF_CAP), + tx_log_store, + root_key, + is_dropped: AtomicBool::new(false), + write_sync_region: RwLock::new(()), + }), + }; + + info!("[SwornDisk] Opened successfully! {:?}", &opened_self); + Ok(opened_self) + } + + /// Submit a new block I/O request and wait its completion (Synchronous). + pub fn submit_bio_sync(&self, bio_req: BioReq) -> BioResp { + bio_req.submit(); + self.inner.handle_bio_req(&bio_req) + } + // TODO: Support handling request asynchronously + + /// Check whether the arguments are valid for read/write operations. 
+ fn check_rw_args(&self, lba: Lba, buf_nblocks: usize) -> Result<()> { + if lba + buf_nblocks > self.inner.user_data_disk.nblocks() { + Err(Error::with_msg( + OutOfDisk, + "read/write out of disk capacity", + )) + } else { + Ok(()) + } + } + + fn subdisk_for_data(disk: &D) -> Result<D> { + disk.subset(0..disk.nblocks() * 15 / 16) // TBD + } + + fn subdisk_for_logical_block_table(disk: &D) -> Result<D> { + disk.subset(disk.nblocks() * 15 / 16..disk.nblocks()) // TBD + } +} + +/// Capacity of the user data blocks buffer. +const DATA_BUF_CAP: usize = 1024; + +impl<D: BlockSet + 'static> DiskInner<D> { + /// Read a specified number of blocks at a logical block address on the device. + /// The block contents will be read into a single contiguous buffer. + pub fn read(&self, lba: Lba, buf: BufMut) -> Result<()> { + let nblocks = buf.nblocks(); + + let res = if nblocks == 1 { + self.read_one_block(lba, buf) + } else { + self.read_multi_blocks(lba, &mut [buf]) + }; + + // Allow empty read + if let Err(e) = &res + && e.errno() == NotFound + { + warn!("[SwornDisk] read contains empty read on lba {lba}"); + return Ok(()); + } + res + } + + /// Read multiple blocks at a logical block address on the device. + /// The block contents will be read into several scattered buffers. + pub fn readv<'a>(&self, lba: Lba, bufs: &'a mut [BufMut<'a>]) -> Result<()> { + let res = self.read_multi_blocks(lba, bufs); + + // Allow empty read + if let Err(e) = &res + && e.errno() == NotFound + { + warn!("[SwornDisk] readv contains empty read on lba {lba}"); + return Ok(()); + } + res + } + + fn read_one_block(&self, lba: Lba, mut buf: BufMut) -> Result<()> { + debug_assert_eq!(buf.nblocks(), 1); + // Search in `DataBuf` first + if self.data_buf.get(RecordKey { lba }, &mut buf).is_some() { + return Ok(()); + } + + // Search in `TxLsmTree` then + let value = self.logical_block_table.get(&RecordKey { lba })?; + + // Perform disk read and decryption + let mut cipher = Buf::alloc(1)?; + self.user_data_disk.read(value.hba, cipher.as_mut())?; + Aead::new().decrypt( + cipher.as_slice(), + &value.key, + &Iv::new_zeroed(), + &[], + &value.mac, + buf.as_mut_slice(), + )?; + + Ok(()) + } + + fn read_multi_blocks<'a>(&self, lba: Lba, bufs: &'a mut [BufMut<'a>]) -> Result<()> { + let mut buf_vec = BufMutVec::from_bufs(bufs); + let nblocks = buf_vec.nblocks(); + + let mut range_query_ctx = + RangeQueryCtx::<RecordKey, RecordValue>::new(RecordKey { lba }, nblocks); + + // Search in `DataBuf` first + for (key, data_block) in self + .data_buf + .get_range(range_query_ctx.range_uncompleted().unwrap()) + { + buf_vec + .nth_buf_mut_slice(key.lba - lba) + .copy_from_slice(data_block.as_slice()); + range_query_ctx.mark_completed(key); + } + if range_query_ctx.is_completed() { + return Ok(()); + } + + // Search in `TxLsmTree` then + self.logical_block_table.get_range(&mut range_query_ctx)?; + // Allow empty read + debug_assert!(range_query_ctx.is_completed()); + + let mut res = range_query_ctx.into_results(); + let record_batches = { + res.sort_by(|(_, v1), (_, v2)| v1.hba.cmp(&v2.hba)); + res.chunk_by(|(_, v1), (_, v2)| v2.hba - v1.hba == 1) + }; + + // Perform disk read in batches and decryption + let mut cipher_buf = Buf::alloc(nblocks)?; + let cipher_slice = cipher_buf.as_mut_slice(); + for record_batch in record_batches { + self.user_data_disk.read( + record_batch.first().unwrap().1.hba, + BufMut::try_from(&mut cipher_slice[..record_batch.len() * BLOCK_SIZE]).unwrap(), + )?; + + for (nth, (key, value)) in record_batch.iter().enumerate() { 
+ Aead::new().decrypt( + &cipher_slice[nth * BLOCK_SIZE..(nth + 1) * BLOCK_SIZE], + &value.key, + &Iv::new_zeroed(), + &[], + &value.mac, + buf_vec.nth_buf_mut_slice(key.lba - lba), + )?; + } + } + + Ok(()) + } + + /// Write a specified number of blocks at a logical block address on the device. + /// The block contents reside in a single contiguous buffer. + pub fn write(&self, mut lba: Lba, buf: BufRef) -> Result<()> { + // Write block contents to `DataBuf` directly + for block_buf in buf.iter() { + let buf_at_capacity = self.data_buf.put(RecordKey { lba }, block_buf); + + // Flush all data blocks in `DataBuf` to disk if it's full + if buf_at_capacity { + // TODO: Error handling: Should discard current write in `DataBuf` + self.flush_data_buf()?; + } + lba += 1; + } + Ok(()) + } + + /// Write multiple blocks at a logical block address on the device. + /// The block contents reside in several scattered buffers. + pub fn writev(&self, mut lba: Lba, bufs: &[BufRef]) -> Result<()> { + for buf in bufs { + self.write(lba, *buf)?; + lba += buf.nblocks(); + } + Ok(()) + } + + fn flush_data_buf(&self) -> Result<()> { + let records = self.write_blocks_from_data_buf()?; + // Insert new records of data blocks to `TxLsmTree` + for (key, value) in records { + // TODO: Error handling: Should dealloc the written blocks + self.logical_block_table.put(key, value)?; + } + + self.data_buf.clear(); + Ok(()) + } + + fn write_blocks_from_data_buf(&self) -> Result<Vec<(RecordKey, RecordValue)>> { + let data_blocks = self.data_buf.all_blocks(); + + let num_write = data_blocks.len(); + let mut records = Vec::with_capacity(num_write); + if num_write == 0 { + return Ok(records); + } + + // Allocate slots for data blocks + let hbas = self + .block_validity_table + .alloc_batch(NonZeroUsize::new(num_write).unwrap())?; + debug_assert_eq!(hbas.len(), num_write); + let hba_batches = hbas.chunk_by(|hba1, hba2| hba2 - hba1 == 1); + + // Perform encryption and batch disk write + let mut cipher_buf = Buf::alloc(num_write)?; + let mut cipher_slice = cipher_buf.as_mut_slice(); + let mut nth = 0; + for hba_batch in hba_batches { + for (i, &hba) in hba_batch.iter().enumerate() { + let (lba, data_block) = &data_blocks[nth]; + let key = Key::random(); + let mac = Aead::new().encrypt( + data_block.as_slice(), + &key, + &Iv::new_zeroed(), + &[], + &mut cipher_slice[i * BLOCK_SIZE..(i + 1) * BLOCK_SIZE], + )?; + + records.push((*lba, RecordValue { hba, key, mac })); + nth += 1; + } + + self.user_data_disk.write( + *hba_batch.first().unwrap(), + BufRef::try_from(&cipher_slice[..hba_batch.len() * BLOCK_SIZE]).unwrap(), + )?; + cipher_slice = &mut cipher_slice[hba_batch.len() * BLOCK_SIZE..]; + } + + Ok(records) + } + + /// Sync all cached data in the device to the storage medium for durability. + pub fn sync(&self) -> Result<()> { + self.flush_data_buf()?; + debug_assert!(self.data_buf.is_empty()); + + self.logical_block_table.sync()?; + + // XXX: May impact performance when there comes frequent syncs + self.block_validity_table + .do_compaction(&self.tx_log_store)?; + + self.tx_log_store.sync()?; + + self.user_data_disk.flush() + } + + /// Handle one block I/O request. Mark the request completed when finished, + /// return any error that occurs. + pub fn handle_bio_req(&self, req: &BioReq) -> BioResp { + let res = match req.type_() { + BioType::Read => self.do_read(req), + BioType::Write => self.do_write(req), + BioType::Sync => self.do_sync(req), + }; + + req.complete(res.clone()); + res + } + + /// Handle a read I/O request. 
+ fn do_read(&self, req: &BioReq) -> BioResp { + debug_assert_eq!(req.type_(), BioType::Read); + + let lba = req.addr() as Lba; + let mut req_bufs = req.take_bufs(); + let mut bufs = { + let mut bufs = Vec::with_capacity(req.nbufs()); + for buf in req_bufs.iter_mut() { + bufs.push(BufMut::try_from(buf.as_mut_slice())?); + } + bufs + }; + + if bufs.len() == 1 { + let buf = bufs.remove(0); + return self.read(lba, buf); + } + + self.readv(lba, &mut bufs) + } + + /// Handle a write I/O request. + fn do_write(&self, req: &BioReq) -> BioResp { + debug_assert_eq!(req.type_(), BioType::Write); + + let lba = req.addr() as Lba; + let req_bufs = req.take_bufs(); + let bufs = { + let mut bufs = Vec::with_capacity(req.nbufs()); + for buf in req_bufs.iter() { + bufs.push(BufRef::try_from(buf.as_slice())?); + } + bufs + }; + + self.writev(lba, &bufs) + } + + /// Handle a sync I/O request. + fn do_sync(&self, req: &BioReq) -> BioResp { + debug_assert_eq!(req.type_(), BioType::Sync); + self.sync() + } +} + +impl<D: BlockSet> Drop for SwornDisk<D> { + fn drop(&mut self) { + self.inner.is_dropped.store(true, Ordering::Release); + } +} + +impl<D: BlockSet + 'static> Debug for SwornDisk<D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SwornDisk") + .field("user_data_nblocks", &self.inner.user_data_disk.nblocks()) + .field("logical_block_table", &self.inner.logical_block_table) + .finish() + } +} + +/// A wrapper for `[BufMut]` used in `readv()`. +struct BufMutVec<'a> { + bufs: &'a mut [BufMut<'a>], + nblocks: usize, +} + +impl<'a> BufMutVec<'a> { + pub fn from_bufs(bufs: &'a mut [BufMut<'a>]) -> Self { + debug_assert!(!bufs.is_empty()); + let nblocks = bufs + .iter() + .map(|buf| buf.nblocks()) + .fold(0_usize, |sum, nblocks| sum.saturating_add(nblocks)); + Self { bufs, nblocks } + } + + pub fn nblocks(&self) -> usize { + self.nblocks + } + + pub fn nth_buf_mut_slice(&mut self, mut nth: usize) -> &mut [u8] { + debug_assert!(nth < self.nblocks); + for buf in self.bufs.iter_mut() { + let nblocks = buf.nblocks(); + if nth >= buf.nblocks() { + nth -= nblocks; + } else { + return &mut buf.as_mut_slice()[nth * BLOCK_SIZE..(nth + 1) * BLOCK_SIZE]; + } + } + &mut [] + } +} + +/// Listener factory for `TxLsmTree`. +struct TxLsmTreeListenerFactory<D> { + store: Arc<TxLogStore<D>>, + alloc_table: Arc<AllocTable>, +} + +impl<D> TxLsmTreeListenerFactory<D> { + fn new(store: Arc<TxLogStore<D>>, alloc_table: Arc<AllocTable>) -> Self { + Self { store, alloc_table } + } +} + +impl<D: BlockSet + 'static> TxEventListenerFactory<RecordKey, RecordValue> + for TxLsmTreeListenerFactory<D> +{ + fn new_event_listener( + &self, + tx_type: TxType, + ) -> Arc<dyn TxEventListener<RecordKey, RecordValue>> { + Arc::new(TxLsmTreeListener::new( + tx_type, + Arc::new(BlockAlloc::new( + self.alloc_table.clone(), + self.store.clone(), + )), + )) + } +} + +/// Event listener for `TxLsmTree`. +struct TxLsmTreeListener<D> { + tx_type: TxType, + block_alloc: Arc<BlockAlloc<D>>, +} + +impl<D> TxLsmTreeListener<D> { + fn new(tx_type: TxType, block_alloc: Arc<BlockAlloc<D>>) -> Self { + Self { + tx_type, + block_alloc, + } + } +} + +/// Register callbacks for different TXs in `TxLsmTree`. 
+impl<D: BlockSet + 'static> TxEventListener<RecordKey, RecordValue> for TxLsmTreeListener<D> { + fn on_add_record(&self, record: &dyn AsKV<RecordKey, RecordValue>) -> Result<()> { + match self.tx_type { + TxType::Compaction { + to_level: LsmLevel::L0, + } => self.block_alloc.alloc_block(record.value().hba), + // Major Compaction TX and Migration TX do not add new records + TxType::Compaction { .. } | TxType::Migration => { + // Do nothing + Ok(()) + } + } + } + + fn on_drop_record(&self, record: &dyn AsKV<RecordKey, RecordValue>) -> Result<()> { + match self.tx_type { + // Minor Compaction TX doesn't compact records + TxType::Compaction { + to_level: LsmLevel::L0, + } => { + unreachable!(); + } + TxType::Compaction { .. } | TxType::Migration => { + self.block_alloc.dealloc_block(record.value().hba) + } + } + } + + fn on_tx_begin(&self, tx: &mut CurrentTx<'_>) -> Result<()> { + match self.tx_type { + TxType::Compaction { .. } | TxType::Migration => { + tx.context(|| self.block_alloc.prepare_diff_log().unwrap()) + } + } + Ok(()) + } + + fn on_tx_precommit(&self, tx: &mut CurrentTx<'_>) -> Result<()> { + match self.tx_type { + TxType::Compaction { .. } | TxType::Migration => { + tx.context(|| self.block_alloc.update_diff_log().unwrap()) + } + } + Ok(()) + } + + fn on_tx_commit(&self) { + match self.tx_type { + TxType::Compaction { .. } | TxType::Migration => self.block_alloc.update_alloc_table(), + } + } +} + +/// Key-Value record for `TxLsmTree`. +pub(super) struct Record { + key: RecordKey, + value: RecordValue, +} + +/// The key of a `Record`. +#[repr(C)] +#[derive(Clone, Copy, Pod, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub(super) struct RecordKey { + /// Logical block address of user data block. + pub lba: Lba, +} + +/// The value of a `Record`. +#[repr(C)] +#[derive(Clone, Copy, Pod, Debug)] +pub(super) struct RecordValue { + /// Host block address of user data block. + pub hba: Hba, + /// Encryption key of the data block. + pub key: Key, + /// Encrypted MAC of the data block. 
+ pub mac: Mac, +} + +impl Add<usize> for RecordKey { + type Output = Self; + + fn add(self, other: usize) -> Self::Output { + Self { + lba: self.lba + other, + } + } +} + +impl Sub<RecordKey> for RecordKey { + type Output = usize; + + fn sub(self, other: RecordKey) -> Self::Output { + self.lba - other.lba + } +} + +impl RecordK<RecordKey> for RecordKey {} +impl RecordV for RecordValue {} + +impl AsKV<RecordKey, RecordValue> for Record { + fn key(&self) -> &RecordKey { + &self.key + } + + fn value(&self) -> &RecordValue { + &self.value + } +} + +#[cfg(test)] +mod tests { + use core::ptr::NonNull; + use std::thread; + + use super::*; + use crate::layers::{bio::MemDisk, disk::bio::BioReqBuilder}; + + #[test] + fn sworndisk_fns() -> Result<()> { + let nblocks = 64 * 1024; + let mem_disk = MemDisk::create(nblocks)?; + let root_key = Key::random(); + // Create a new `SwornDisk` then do some writes + let sworndisk = SwornDisk::create(mem_disk.clone(), root_key, None)?; + let num_rw = 1024; + + // Submit a write block I/O request + let mut bufs = Vec::with_capacity(num_rw); + (0..num_rw).for_each(|i| { + let mut buf = Buf::alloc(1).unwrap(); + buf.as_mut_slice().fill(i as u8); + bufs.push(buf); + }); + let bio_req = BioReqBuilder::new(BioType::Write) + .addr(0 as BlockId) + .bufs(bufs) + .build(); + sworndisk.submit_bio_sync(bio_req)?; + + // Sync the `SwornDisk` then do some reads + sworndisk.submit_bio_sync(BioReqBuilder::new(BioType::Sync).build())?; + + let mut rbuf = Buf::alloc(1)?; + for i in 0..num_rw { + sworndisk.read(i as Lba, rbuf.as_mut())?; + assert_eq!(rbuf.as_slice()[0], i as u8); + } + + // Open the closed `SwornDisk` then test its data's existence + drop(sworndisk); + thread::spawn(move || -> Result<()> { + let opened_sworndisk = SwornDisk::open(mem_disk, root_key, None)?; + let mut rbuf = Buf::alloc(2)?; + opened_sworndisk.read(5 as Lba, rbuf.as_mut())?; + assert_eq!(rbuf.as_slice()[0], 5u8); + assert_eq!(rbuf.as_slice()[4096], 6u8); + Ok(()) + }) + .join() + .unwrap() + } +} diff --git a/kernel/comps/mlsdisk/src/layers/mod.rs b/kernel/comps/mlsdisk/src/layers/mod.rs new file mode 100644 index 00000000..23b6c562 --- /dev/null +++ b/kernel/comps/mlsdisk/src/layers/mod.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MPL-2.0 + +#[path = "0-bio/mod.rs"] +pub mod bio; +#[path = "1-crypto/mod.rs"] +pub mod crypto; +#[path = "5-disk/mod.rs"] +pub mod disk; +#[path = "2-edit/mod.rs"] +pub mod edit; +#[path = "3-log/mod.rs"] +pub mod log; +#[path = "4-lsm/mod.rs"] +pub mod lsm; diff --git a/kernel/comps/mlsdisk/src/lib.rs b/kernel/comps/mlsdisk/src/lib.rs new file mode 100644 index 00000000..d430dffd --- /dev/null +++ b/kernel/comps/mlsdisk/src/lib.rs @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MPL-2.0 + +#![no_std] +#![deny(unsafe_code)] +#![feature(let_chains)] +#![feature(negative_impls)] +#![feature(slice_as_chunks)] +#![allow(dead_code, unused_imports)] + +mod error; +mod layers; +mod os; +mod prelude; +mod tx; +mod util; + +extern crate alloc; + +pub use self::{ + error::{Errno, Error}, + layers::{ + bio::{BlockId, BlockSet, Buf, BufMut, BufRef, BLOCK_SIZE}, + disk::SwornDisk, + }, + os::{Aead, AeadIv, AeadKey, AeadMac, Rng}, + util::{Aead as _, RandomInit, Rng as _}, +}; diff --git a/kernel/comps/mlsdisk/src/os/mod.rs b/kernel/comps/mlsdisk/src/os/mod.rs new file mode 100644 index 00000000..cba877a8 --- /dev/null +++ b/kernel/comps/mlsdisk/src/os/mod.rs @@ -0,0 +1,404 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! OS-specific or OS-dependent APIs. 
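Each `RecordValue` above ties a logical block to the host block holding its ciphertext, plus the per-block key and GCM tag protecting it. The sketch below shows that scheme in isolation, using the same `aes-gcm` 0.9 API that the `os` module wraps; `BlockRecord`, the helper names, and the constant demo key are illustrative (the real code draws a fresh key per block via `Key::random()`).

```rust
use aes_gcm::{
    aead::{AeadInPlace, Key, NewAead, Nonce, Tag},
    Aes128Gcm,
};

/// Per-block metadata, mirroring the shape of `RecordValue` (hba, key, mac).
struct BlockRecord {
    hba: u64,
    key: [u8; 16],
    mac: [u8; 16],
}

/// Encrypts one block in place and produces its record.
/// A zero IV is acceptable here only because every block uses a one-time key.
fn seal_block(hba: u64, key: [u8; 16], block: &mut [u8]) -> BlockRecord {
    let cipher = Aes128Gcm::new(Key::<Aes128Gcm>::from_slice(&key));
    let tag = cipher
        .encrypt_in_place_detached(Nonce::<Aes128Gcm>::from_slice(&[0u8; 12]), &[], block)
        .expect("encryption failed");
    let mut mac = [0u8; 16];
    mac.copy_from_slice(&tag);
    BlockRecord { hba, key, mac }
}

/// Decrypts a block in place, verifying it against the stored record.
fn open_block(record: &BlockRecord, block: &mut [u8]) -> bool {
    let cipher = Aes128Gcm::new(Key::<Aes128Gcm>::from_slice(&record.key));
    cipher
        .decrypt_in_place_detached(
            Nonce::<Aes128Gcm>::from_slice(&[0u8; 12]),
            &[],
            block,
            Tag::<Aes128Gcm>::from_slice(&record.mac),
        )
        .is_ok()
}

fn main() {
    let mut block = vec![0xabu8; 4096];
    // In SwornDisk the key comes from `Key::random()`; a constant is used for brevity.
    let record = seal_block(42, [7u8; 16], &mut block);
    assert!(open_block(&record, &mut block));
    assert!(block.iter().all(|&b| b == 0xab));
}
```

Because each key is used for exactly one block write, nonce reuse is not a concern, which is why an all-zero IV suffices; the MAC stored in the record is what makes tampering with the on-disk ciphertext detectable at read time.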
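The `enqueue()` implementation earlier converts a bio's sector range into whole blocks and performs a read-modify-write whenever the first or last block is only partially covered. Below is a small, self-contained sketch of that address arithmetic; the function name, tuple layout, and fixed 4 KiB block size are assumptions made for the example.

```rust
const BLOCK_SIZE: usize = 4096;

/// For a byte range [start_offset, end_offset), returns the covering block range
/// and whether the head/tail blocks are partially covered (and thus must be read
/// back before the write, as `handle_write_bio` does).
fn block_range(start_offset: usize, end_offset: usize) -> (usize, usize, bool, bool) {
    let start_lba = start_offset / BLOCK_SIZE;
    let end_lba = end_offset.div_ceil(BLOCK_SIZE);
    let head_partial = start_offset % BLOCK_SIZE != 0;
    let tail_partial = end_offset % BLOCK_SIZE != 0;
    (start_lba, end_lba, head_partial, tail_partial)
}

fn main() {
    // A 1 KiB write starting 512 bytes into block 3 touches only block 3;
    // both edges are partial, so the block is read, patched, then rewritten.
    assert_eq!(
        block_range(3 * BLOCK_SIZE + 512, 3 * BLOCK_SIZE + 1536),
        (3, 4, true, true)
    );
    // A fully aligned two-block write needs no read-modify-write.
    assert_eq!(
        block_range(8 * BLOCK_SIZE, 10 * BLOCK_SIZE),
        (8, 10, false, false)
    );
}
```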
+ +pub use alloc::{ + boxed::Box, + collections::BTreeMap, + string::{String, ToString}, + sync::{Arc, Weak}, + vec::Vec, +}; +use core::{ + fmt, + sync::atomic::{AtomicBool, Ordering}, +}; + +use aes_gcm::{ + aead::{AeadInPlace, Key, NewAead, Nonce, Tag}, + aes::Aes128, + Aes128Gcm, +}; +use ctr::cipher::{NewCipher, StreamCipher}; +pub use hashbrown::{HashMap, HashSet}; +pub use ostd::sync::{Mutex, MutexGuard, RwLock, SpinLock}; +use ostd::{ + arch::read_random, + sync::{self, PreemptDisabled, WaitQueue}, + task::{Task, TaskOptions}, +}; +use ostd_pod::Pod; +use serde::{Deserialize, Serialize}; + +use crate::{ + error::{Errno, Error}, + prelude::Result, +}; + +pub type RwLockReadGuard<'a, T> = sync::RwLockReadGuard<'a, T, PreemptDisabled>; +pub type RwLockWriteGuard<'a, T> = sync::RwLockWriteGuard<'a, T, PreemptDisabled>; +pub type SpinLockGuard<'a, T> = sync::SpinLockGuard<'a, T, PreemptDisabled>; +pub type Tid = u32; + +/// A struct to get a unique identifier for the current thread. +pub struct CurrentThread; + +impl CurrentThread { + /// Returns the Tid of current kernel thread. + pub fn id() -> Tid { + let Some(task) = Task::current() else { + return 0; + }; + + task.data() as *const _ as u32 + } +} + +/// A `Condvar` (Condition Variable) is a synchronization primitive that can block threads +/// until a certain condition becomes true. +/// +/// This is a copy from `aster-nix`. +pub struct Condvar { + waitqueue: Arc<WaitQueue>, + counter: SpinLock<Inner>, +} + +struct Inner { + waiter_count: u64, + notify_count: u64, +} + +impl Condvar { + /// Creates a new condition variable. + pub fn new() -> Self { + Condvar { + waitqueue: Arc::new(WaitQueue::new()), + counter: SpinLock::new(Inner { + waiter_count: 0, + notify_count: 0, + }), + } + } + + /// Atomically releases the given `MutexGuard`, + /// blocking the current thread until the condition variable + /// is notified, after which the mutex will be reacquired. + /// + /// Returns a new `MutexGuard` if the operation is successful, + /// or returns the provided guard + /// within a `LockErr` if the waiting operation fails. + pub fn wait<'a, T>(&self, guard: MutexGuard<'a, T>) -> Result<MutexGuard<'a, T>> { + let cond = || { + // Check if the notify counter is greater than 0. + let mut counter = self.counter.lock(); + if counter.notify_count > 0 { + // Decrement the notify counter. + counter.notify_count -= 1; + Some(()) + } else { + None + } + }; + { + let mut counter = self.counter.lock(); + counter.waiter_count += 1; + } + let lock = MutexGuard::get_lock(&guard); + drop(guard); + self.waitqueue.wait_until(cond); + Ok(lock.lock()) + } + + /// Wakes up one blocked thread waiting on this condition variable. + /// + /// If there is a waiting thread, it will be unblocked + /// and allowed to reacquire the associated mutex. + /// If no threads are waiting, this function is a no-op. + pub fn notify_one(&self) { + let mut counter = self.counter.lock(); + if counter.waiter_count == 0 { + return; + } + counter.notify_count += 1; + self.waitqueue.wake_one(); + counter.waiter_count -= 1; + } + + /// Wakes up all blocked threads waiting on this condition variable. + /// + /// This method will unblock all waiting threads + /// and they will be allowed to reacquire the associated mutex. + /// If no threads are waiting, this function is a no-op. 
+ pub fn notify_all(&self) { + let mut counter = self.counter.lock(); + if counter.waiter_count == 0 { + return; + } + counter.notify_count = counter.waiter_count; + self.waitqueue.wake_all(); + counter.waiter_count = 0; + } +} + +impl fmt::Debug for Condvar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Condvar").finish_non_exhaustive() + } +} + +/// Wrap the `Mutex` provided by kernel, used for `Condvar`. +#[repr(transparent)] +pub struct CvarMutex<T> { + inner: Mutex<T>, +} + +// TODO: add distinguish guard type for `CvarMutex` if needed. + +impl<T> CvarMutex<T> { + /// Constructs a new `Mutex` lock, using the kernel's `struct mutex`. + pub fn new(t: T) -> Self { + Self { + inner: Mutex::new(t), + } + } + + /// Acquires the lock and gives the caller access to the data protected by it. + pub fn lock(&self) -> Result<MutexGuard<'_, T>> { + let guard = self.inner.lock(); + Ok(guard) + } +} + +impl<T: fmt::Debug> fmt::Debug for CvarMutex<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("No data, since `CvarMutex` does't support `try_lock` now") + } +} + +/// Spawns a new thread, returning a `JoinHandle` for it. +pub fn spawn<F, T>(f: F) -> JoinHandle<T> +where + F: FnOnce() -> T + Send + Sync + 'static, + T: Send + 'static, +{ + let is_finished = Arc::new(AtomicBool::new(false)); + let data = Arc::new(SpinLock::new(None)); + + let is_finished_clone = is_finished.clone(); + let data_clone = data.clone(); + let task = TaskOptions::new(move || { + let data = f(); + *data_clone.lock() = Some(data); + is_finished_clone.store(true, Ordering::Release); + }) + .spawn() + .unwrap(); + + JoinHandle { + task, + is_finished, + data, + } +} + +/// An owned permission to join on a thread (block on its termination). +/// +/// This struct is created by the `spawn` function. +pub struct JoinHandle<T> { + task: Arc<Task>, + is_finished: Arc<AtomicBool>, + data: Arc<SpinLock<Option<T>>>, +} + +impl<T> JoinHandle<T> { + /// Checks if the associated thread has finished running its main function. + pub fn is_finished(&self) -> bool { + self.is_finished.load(Ordering::Acquire) + } + + /// Waits for the associated thread to finish. + pub fn join(self) -> Result<T> { + while !self.is_finished() { + Task::yield_now(); + } + + let data = self.data.lock().take().unwrap(); + Ok(data) + } +} + +impl<T> fmt::Debug for JoinHandle<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JoinHandle").finish_non_exhaustive() + } +} + +/// A random number generator. +pub struct Rng; + +impl crate::util::Rng for Rng { + fn new(_seed: &[u8]) -> Self { + Self + } + + fn fill_bytes(&self, dest: &mut [u8]) -> Result<()> { + let (chunks, remain) = dest.as_chunks_mut::<8>(); + chunks.iter_mut().for_each(|chunk| { + chunk.copy_from_slice(read_random().unwrap_or(0u64).as_bytes()); + }); + remain.copy_from_slice(&read_random().unwrap_or(0u64).as_bytes()[..remain.len()]); + Ok(()) + } +} + +/// A macro to define byte_array_types used by `Aead` or `Skcipher`. +macro_rules! 
new_byte_array_type { + ($name:ident, $n:expr) => { + #[repr(C)] + #[derive(Copy, Clone, Pod, Debug, Default, Deserialize, Serialize)] + pub struct $name([u8; $n]); + + impl core::ops::Deref for $name { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.0.as_slice() + } + } + + impl core::ops::DerefMut for $name { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.as_mut_slice() + } + } + + impl crate::util::RandomInit for $name { + fn random() -> Self { + use crate::util::Rng; + + let mut result = Self::default(); + let rng = self::Rng::new(&[]); + rng.fill_bytes(&mut result).unwrap_or_default(); + result + } + } + }; +} + +const AES_GCM_KEY_SIZE: usize = 16; +const AES_GCM_IV_SIZE: usize = 12; +const AES_GCM_MAC_SIZE: usize = 16; + +new_byte_array_type!(AeadKey, AES_GCM_KEY_SIZE); +new_byte_array_type!(AeadIv, AES_GCM_IV_SIZE); +new_byte_array_type!(AeadMac, AES_GCM_MAC_SIZE); + +/// An `AEAD` cipher. +#[derive(Debug, Default)] +pub struct Aead; + +impl Aead { + /// Construct an `Aead` instance. + pub fn new() -> Self { + Self + } +} + +impl crate::util::Aead for Aead { + type Key = AeadKey; + type Iv = AeadIv; + type Mac = AeadMac; + + fn encrypt( + &self, + input: &[u8], + key: &AeadKey, + iv: &AeadIv, + aad: &[u8], + output: &mut [u8], + ) -> Result<AeadMac> { + let key = Key::<Aes128Gcm>::from_slice(key); + let nonce = Nonce::<Aes128Gcm>::from_slice(iv); + let cipher = Aes128Gcm::new(key); + + output.copy_from_slice(input); + let tag = cipher + .encrypt_in_place_detached(nonce, aad, output) + .map_err(|_| Error::with_msg(Errno::EncryptFailed, "aes-128-gcm encryption failed"))?; + + let mut aead_mac = AeadMac::new_zeroed(); + aead_mac.copy_from_slice(&tag); + Ok(aead_mac) + } + + fn decrypt( + &self, + input: &[u8], + key: &AeadKey, + iv: &AeadIv, + aad: &[u8], + mac: &AeadMac, + output: &mut [u8], + ) -> Result<()> { + let key = Key::<Aes128Gcm>::from_slice(key); + let nonce = Nonce::<Aes128Gcm>::from_slice(iv); + let tag = Tag::<Aes128Gcm>::from_slice(mac); + let cipher = Aes128Gcm::new(key); + + output.copy_from_slice(input); + cipher + .decrypt_in_place_detached(nonce, aad, output, tag) + .map_err(|_| Error::with_msg(Errno::DecryptFailed, "aes-128-gcm decryption failed")) + } +} + +type Aes128Ctr = ctr::Ctr128LE<Aes128>; + +const AES_CTR_KEY_SIZE: usize = 16; +const AES_CTR_IV_SIZE: usize = 16; + +new_byte_array_type!(SkcipherKey, AES_CTR_KEY_SIZE); +new_byte_array_type!(SkcipherIv, AES_CTR_IV_SIZE); + +/// A symmetric key cipher. +#[derive(Debug, Default)] +pub struct Skcipher; + +// TODO: impl `Skcipher` with linux kernel Crypto API. +impl Skcipher { + /// Construct a `Skcipher` instance. 
+ pub fn new() -> Self { + Self + } +} + +impl crate::util::Skcipher for Skcipher { + type Key = SkcipherKey; + type Iv = SkcipherIv; + + fn encrypt( + &self, + input: &[u8], + key: &SkcipherKey, + iv: &SkcipherIv, + output: &mut [u8], + ) -> Result<()> { + let mut cipher = Aes128Ctr::new_from_slices(key, iv).unwrap(); + output.copy_from_slice(input); + cipher.apply_keystream(output); + Ok(()) + } + + fn decrypt( + &self, + input: &[u8], + key: &SkcipherKey, + iv: &SkcipherIv, + output: &mut [u8], + ) -> Result<()> { + let mut cipher = Aes128Ctr::new_from_slices(key, iv).unwrap(); + output.copy_from_slice(input); + cipher.apply_keystream(output); + Ok(()) + } +} diff --git a/kernel/comps/mlsdisk/src/prelude.rs b/kernel/comps/mlsdisk/src/prelude.rs new file mode 100644 index 00000000..ef5b58f5 --- /dev/null +++ b/kernel/comps/mlsdisk/src/prelude.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MPL-2.0 + +pub(crate) use crate::{ + error::{Errno::*, Error}, + layers::bio::{BlockId, BLOCK_SIZE}, + os::{Arc, Box, String, ToString, Vec, Weak}, + return_errno, return_errno_with_msg, + util::{align_down, align_up, Aead as _, RandomInit, Rng as _, Skcipher as _}, +}; + +pub(crate) type Result<T> = core::result::Result<T, Error>; + +pub(crate) use core::fmt::{self, Debug}; + +pub(crate) use log::{debug, error, info, trace, warn}; diff --git a/kernel/comps/mlsdisk/src/tx/current.rs b/kernel/comps/mlsdisk/src/tx/current.rs new file mode 100644 index 00000000..b436f5ff --- /dev/null +++ b/kernel/comps/mlsdisk/src/tx/current.rs @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Get and set the current transaction of the current thread. +use core::sync::atomic::Ordering::{Acquire, Release}; + +use super::{Tx, TxData, TxId, TxProvider, TxStatus}; +use crate::{os::CurrentThread, prelude::*}; + +/// The current transaction on a thread. +#[derive(Clone)] +pub struct CurrentTx<'a> { + provider: &'a TxProvider, +} + +// CurrentTx is only useful and valid for the current thread +impl !Send for CurrentTx<'_> {} +impl !Sync for CurrentTx<'_> {} + +impl<'a> CurrentTx<'a> { + pub(super) fn new(provider: &'a TxProvider) -> Self { + Self { provider } + } + + /// Enter the context of the current TX. + /// + /// While within the context of a TX, the implementation side of a TX + /// can get the current TX via `TxProvider::current`. + pub fn context<F, R>(&self, f: F) -> R + where + F: FnOnce() -> R, + { + let tx_table = self.provider.tx_table.lock(); + let tid = CurrentThread::id(); + if !tx_table.contains_key(&tid) { + panic!("there should be one Tx exited on the current thread"); + } + + assert!(tx_table.get(&tid).unwrap().status() == TxStatus::Ongoing); + drop(tx_table); + + f() + } + + /// Commits the current TX. + /// + /// If the returned value is `Ok`, then the TX is committed successfully. + /// Otherwise, the TX is aborted. + pub fn commit(&self) -> Result<()> { + let mut tx_table = self.provider.tx_table.lock(); + let Some(mut tx) = tx_table.remove(&CurrentThread::id()) else { + panic!("there should be one Tx exited on the current thread"); + }; + debug_assert!(tx.status() == TxStatus::Ongoing); + + let res = self.provider.call_precommit_handlers(); + if res.is_ok() { + self.provider.call_commit_handlers(); + tx.set_status(TxStatus::Committed); + } else { + self.provider.call_abort_handlers(); + tx.set_status(TxStatus::Aborted); + } + + res + } + + /// Aborts the current TX. 
+ pub fn abort(&self) { + let mut tx_table = self.provider.tx_table.lock(); + let Some(mut tx) = tx_table.remove(&CurrentThread::id()) else { + panic!("there should be one Tx exited on the current thread"); + }; + debug_assert!(tx.status() == TxStatus::Ongoing); + + self.provider.call_abort_handlers(); + tx.set_status(TxStatus::Aborted); + } + + /// The ID of the transaction. + pub fn id(&self) -> TxId { + self.get_current_mut_with(|tx| tx.id()) + } + + /// Get immutable access to some type of the per-transaction data within a closure. + /// + /// # Panics + /// + /// The `data_with` method must _not_ be called recursively. + pub fn data_with<T: TxData, F, R>(&self, f: F) -> R + where + F: FnOnce(&T) -> R, + { + self.get_current_mut_with(|tx| { + let data = tx.data::<T>(); + f(data) + }) + } + + /// Get mutable access to some type of the per-transaction data within a closure. + pub fn data_mut_with<T: TxData, F, R>(&mut self, f: F) -> R + where + F: FnOnce(&mut T) -> R, + { + self.get_current_mut_with(|tx| { + let data = tx.data_mut::<T>(); + f(data) + }) + } + + /// Get a _mutable_ reference to the current transaction of the current thread, + /// passing it to a given closure. + /// + /// # Panics + /// + /// The `get_current_mut_with` method must be called within the closure + /// of `set_and_exec_with`. + /// + /// In addition, the `get_current_mut_with` method must _not_ be called + /// recursively. + #[allow(dropping_references)] + fn get_current_mut_with<F, R>(&self, f: F) -> R + where + F: FnOnce(&mut Tx) -> R, + { + let mut tx_table = self.provider.tx_table.lock(); + let Some(tx) = tx_table.get_mut(&CurrentThread::id()) else { + panic!("there should be one Tx exited on the current thread"); + }; + + if tx.is_accessing_data.swap(true, Acquire) { + panic!("get_current_mut_with must not be called recursively"); + } + + let retval: R = f(tx); + + // SAFETY. At any given time, at most one mutable reference will be constructed + // between the Acquire-Release section. And it is safe to drop `&mut Tx` after + // `Release`, since drop the reference does nothing to the `Tx` itself. + tx.is_accessing_data.store(false, Release); + + retval + } +} diff --git a/kernel/comps/mlsdisk/src/tx/mod.rs b/kernel/comps/mlsdisk/src/tx/mod.rs new file mode 100644 index 00000000..335f6466 --- /dev/null +++ b/kernel/comps/mlsdisk/src/tx/mod.rs @@ -0,0 +1,435 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Transaction management. +//! +//! Transaction management APIs serve two sides: +//! +//! * The user side of TXs uses `Tx` to use, commit, or abort TXs. +//! * The implementation side of TXs uses `TxProvider` to get notified +//! when TXs are created, committed, or aborted by register callbacks. +mod current; + +use core::{ + any::{Any, TypeId}, + sync::atomic::{AtomicBool, AtomicU64, Ordering}, +}; + +pub use self::current::CurrentTx; +use crate::{ + os::{CurrentThread, HashMap, Mutex, RwLock, Tid}, + prelude::*, +}; + +/// A transaction provider. +#[allow(clippy::type_complexity)] +pub struct TxProvider { + id: u64, + initializer_map: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>, + precommit_handlers: RwLock<Vec<Box<dyn Fn(CurrentTx<'_>) -> Result<()> + Send + Sync>>>, + commit_handlers: RwLock<Vec<Box<dyn Fn(CurrentTx<'_>) + Send + Sync>>>, + abort_handlers: RwLock<Vec<Box<dyn Fn(CurrentTx<'_>) + Send + Sync>>>, + weak_self: Weak<Self>, + tx_table: Mutex<HashMap<Tid, Tx>>, +} + +impl TxProvider { + /// Creates a new TX provider. 
+ pub fn new() -> Arc<Self> { + static NEXT_ID: AtomicU64 = AtomicU64::new(0); + Arc::new_cyclic(|weak_self| Self { + id: NEXT_ID.fetch_add(1, Ordering::Release), + initializer_map: RwLock::new(HashMap::new()), + precommit_handlers: RwLock::new(Vec::new()), + commit_handlers: RwLock::new(Vec::new()), + abort_handlers: RwLock::new(Vec::new()), + weak_self: weak_self.clone(), + tx_table: Mutex::new(HashMap::new()), + }) + } + + /// Creates a new TX that is attached to this TX provider. + pub fn new_tx(&self) -> CurrentTx<'_> { + let mut tx_table = self.tx_table.lock(); + let tid = CurrentThread::id(); + if tx_table.contains_key(&tid) { + return self.current(); + } + + let tx = Tx::new(self.weak_self.clone()); + let _ = tx_table.insert(tid, tx); + self.current() + } + + /// Get the current TX. + /// + /// # Panics + /// + /// The caller of this method must be within the closure passed to + /// `Tx::context`. Otherwise, the method would panic. + pub fn current(&self) -> CurrentTx<'_> { + CurrentTx::new(self) + } + + /// Register a per-TX data initializer. + /// + /// The registered initializer function will be called upon the creation of + /// a TX. + pub fn register_data_initializer<T>(&self, f: Box<dyn Fn() -> T + Send + Sync>) + where + T: TxData, + { + let mut initializer_map = self.initializer_map.write(); + initializer_map.insert(TypeId::of::<T>(), Box::new(f)); + } + + fn init_data<T>(&self) -> T + where + T: TxData, + { + let initializer_map = self.initializer_map.read(); + let init_fn = initializer_map + .get(&TypeId::of::<T>()) + .unwrap() + .downcast_ref::<Box<dyn Fn() -> T>>() + .unwrap(); + init_fn() + } + + /// Register a callback for the pre-commit stage, + /// which is before the commit stage. + /// + /// Committing a TX triggers the pre-commit stage as well as the commit + /// stage of the TX. + /// On the pre-commit stage, the register callbacks will be called. + /// Pre-commit callbacks are allowed to fail (unlike commit callbacks). + /// If any pre-commit callbacks failed, the TX would be aborted and + /// the commit callbacks would not get called. + pub fn register_precommit_handler<F>(&self, f: F) + where + F: Fn(CurrentTx<'_>) -> Result<()> + Send + Sync + 'static, + { + let f = Box::new(f); + let mut precommit_handlers = self.precommit_handlers.write(); + precommit_handlers.push(f); + } + + fn call_precommit_handlers(&self) -> Result<()> { + let current = self.current(); + let precommit_handlers = self.precommit_handlers.read(); + for precommit_func in precommit_handlers.iter().rev() { + precommit_func(current.clone())?; + } + Ok(()) + } + + /// Register a callback for the commit stage, + /// which is after the pre-commit stage. + /// + /// Committing a TX triggers first the pre-commit stage of the TX and then + /// the commit stage. The callbacks for the commit stage is not allowed + /// to fail. + pub fn register_commit_handler<F>(&self, f: F) + where + F: Fn(CurrentTx<'_>) + Send + Sync + 'static, + { + let f = Box::new(f); + let mut commit_handlers = self.commit_handlers.write(); + commit_handlers.push(f); + } + + fn call_commit_handlers(&self) { + let current = self.current(); + let commit_handlers = self.commit_handlers.read(); + for commit_func in commit_handlers.iter().rev() { + commit_func(current.clone()) + } + } + + /// Register a callback for the abort stage. + /// + /// A TX enters the abort stage when the TX is aborted by the user + /// (via `Tx::abort`) or by a callback in the pre-commit stage. 
+ pub fn register_abort_handler<F>(&self, f: F) + where + F: Fn(CurrentTx<'_>) + Send + Sync + 'static, + { + let f = Box::new(f); + let mut abort_handlers = self.abort_handlers.write(); + abort_handlers.push(f); + } + + fn call_abort_handlers(&self) { + let current = self.current(); + let abort_handlers = self.abort_handlers.read(); + for abort_func in abort_handlers.iter().rev() { + abort_func(current.clone()) + } + } +} + +/// A transaction. +pub struct Tx { + id: TxId, + provider: Weak<TxProvider>, + data_map: HashMap<TypeId, Box<dyn Any + Send + Sync>>, + status: TxStatus, + is_accessing_data: AtomicBool, +} + +impl Tx { + fn new(provider: Weak<TxProvider>) -> Self { + static NEXT_ID: AtomicU64 = AtomicU64::new(0); + + Self { + id: NEXT_ID.fetch_add(1, Ordering::Release), + provider, + data_map: HashMap::new(), + status: TxStatus::Ongoing, + is_accessing_data: AtomicBool::new(false), + } + } + + /// Returns the TX ID. + pub fn id(&self) -> TxId { + self.id + } + + /// Returns the status of the TX. + pub fn status(&self) -> TxStatus { + self.status + } + + /// Sets the status of the Tx. + pub fn set_status(&mut self, status: TxStatus) { + self.status = status; + } + + fn provider(&self) -> Arc<TxProvider> { + self.provider.upgrade().unwrap() + } + + fn data<T>(&mut self) -> &T + where + T: TxData, + { + self.data_mut::<T>() + } + + fn data_mut<T>(&mut self) -> &mut T + where + T: TxData, + { + let exists = self.data_map.contains_key(&TypeId::of::<T>()); + if !exists { + // Slow path, need to initialize the data + let provider = self.provider(); + let data: T = provider.init_data::<T>(); + self.data_map.insert(TypeId::of::<T>(), Box::new(data)); + } + + // Fast path + self.data_map + .get_mut(&TypeId::of::<T>()) + .unwrap() + .downcast_mut::<T>() + .unwrap() + } +} + +impl Drop for Tx { + fn drop(&mut self) { + assert!( + self.status() != TxStatus::Ongoing, + "transactions must be committed or aborted explicitly" + ); + } +} + +/// The status of a transaction. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TxStatus { + Ongoing, + Committed, + Aborted, +} + +/// The ID of a transaction. +pub type TxId = u64; + +/// Per-transaction data. +/// +/// Using `TxProvider::register_data_initiailzer` to inject per-transaction data +/// and using `CurrentTx::data_with` or `CurrentTx::data_mut_with` to access +/// per-transaction data. +pub trait TxData: Any + Send + Sync {} + +#[cfg(test)] +mod tests { + use alloc::collections::BTreeSet; + + use super::*; + + /// `Db<T>` is a toy implementation of in-memory database for + /// a set of items of type `T`. + /// + /// The most interesting feature of `Db<T>` is the support + /// of transactions. All queries and insertions to the database must + /// be performed within transactions. These transactions ensure + /// the atomicity of insertions even in the presence of concurrent execution. + /// If transactions are aborted, their changes won't take effect. + /// + /// The main limitation of `Db<T>` is that it only supports + /// querying and inserting items, but not deleting. + /// The lack of support of deletions rules out the possibilities + /// of concurrent transactions conflicting with each other. + pub struct Db<T> { + all_items: Arc<Mutex<BTreeSet<T>>>, + tx_provider: Arc<TxProvider>, + } + + struct DbUpdate<T> { + new_items: BTreeSet<T>, + } + + impl<T: 'static> TxData for DbUpdate<T> {} + + impl<T> Db<T> + where + T: Ord + 'static, + { + /// Creates an empty database. 
+ pub fn new() -> Self { + let new_self = Self { + all_items: Arc::new(Mutex::new(BTreeSet::new())), + tx_provider: TxProvider::new(), + }; + + new_self + .tx_provider + .register_data_initializer(Box::new(|| DbUpdate { + new_items: BTreeSet::<T>::new(), + })); + new_self.tx_provider.register_commit_handler({ + let all_items = new_self.all_items.clone(); + move |mut current: CurrentTx<'_>| { + current.data_mut_with(|update: &mut DbUpdate<T>| { + let mut all_items = all_items.lock(); + all_items.append(&mut update.new_items); + }); + } + }); + + new_self + } + + /// Creates a new DB transaction. + pub fn new_tx(&self) -> CurrentTx<'_> { + self.tx_provider.new_tx() + } + + /// Returns whether an item is contained. + /// + /// # Transaction + /// + /// This method must be called within the context of a transaction. + pub fn contains(&self, item: &T) -> bool { + let is_new_item = { + let current_tx = self.tx_provider.current(); + current_tx.data_with(|update: &DbUpdate<T>| update.new_items.contains(item)) + }; + if is_new_item { + return true; + } + + let all_items = self.all_items.lock(); + all_items.contains(item) + } + + /// Inserts a new item into the DB. + /// + /// # Transaction + /// + /// This method must be called within the context of a transaction. + pub fn insert(&self, item: T) { + let all_items = self.all_items.lock(); + if all_items.contains(&item) { + return; + } + + let mut current_tx = self.tx_provider.current(); + current_tx.data_mut_with(|update: &mut DbUpdate<_>| { + update.new_items.insert(item); + }); + } + + /// Collects all items of the DB. + /// + /// # Transaction + /// + /// This method must be called within the context of a transaction. + pub fn collect(&self) -> Vec<T> + where + T: Copy, + { + let all_items = self.all_items.lock(); + let current_tx = self.tx_provider.current(); + current_tx.data_with(|update: &DbUpdate<T>| { + all_items.union(&update.new_items).cloned().collect() + }) + } + + /// Returns the number of items in the DB. + /// + /// # Transaction + /// + /// This method must be called within the context of a transaction. 
+ pub fn len(&self) -> usize { + let all_items = self.all_items.lock(); + let current_tx = self.tx_provider.current(); + let new_items_len = current_tx.data_with(|update: &DbUpdate<T>| update.new_items.len()); + all_items.len() + new_items_len + } + } + + #[test] + fn commit_takes_effect() { + let db: Db<u32> = Db::new(); + let items = vec![1, 2, 3]; + new_tx_and_insert_items::<u32, alloc::vec::IntoIter<u32>>(&db, items.clone().into_iter()) + .commit() + .unwrap(); + assert!(collect_items(&db) == items); + } + + #[test] + fn abort_has_no_effect() { + let db: Db<u32> = Db::new(); + let items = vec![1, 2, 3]; + new_tx_and_insert_items::<u32, alloc::vec::IntoIter<u32>>(&db, items.into_iter()).abort(); + assert!(collect_items(&db).len() == 0); + } + + fn new_tx_and_insert_items<T, I>(db: &Db<T>, new_items: I) -> Tx + where + I: Iterator<Item = T>, + T: Copy + Ord + 'static, + { + let mut tx = db.new_tx(); + tx.context(move || { + for new_item in new_items { + db.insert(new_item); + } + }); + tx + } + + fn collect_items<T>(db: &Db<T>) -> Vec<T> + where + T: Copy + Ord + 'static, + { + let mut tx = db.new_tx(); + let items = tx.context(|| db.collect()); + tx.commit().unwrap(); + items + } +} diff --git a/kernel/comps/mlsdisk/src/util/bitmap.rs b/kernel/comps/mlsdisk/src/util/bitmap.rs new file mode 100644 index 00000000..3ec513c5 --- /dev/null +++ b/kernel/comps/mlsdisk/src/util/bitmap.rs @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::ops::Index; + +use bittle::{Bits, BitsMut}; +use serde::{Deserialize, Serialize}; + +use crate::prelude::*; + +/// A compact array of bits. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct BitMap { + bits: Vec<u64>, + nbits: usize, +} + +impl BitMap { + /// The one bit represents `true`. + const ONE: bool = true; + + /// The zero bit represents `false`. + const ZERO: bool = false; + + /// Create a new `BitMap` by repeating the `value` for the desired length. + pub fn repeat(value: bool, nbits: usize) -> Self { + let vec_len = nbits.div_ceil(64); + let mut bits = Vec::with_capacity(vec_len); + if value == Self::ONE { + bits.resize(vec_len, !0u64); + } else { + bits.resize(vec_len, 0u64); + } + + // Set the unused bits in the last u64 with zero. + if nbits % 64 != 0 { + bits[vec_len - 1] + .iter_ones() + .filter(|index| (*index as usize) >= nbits % 64) + .for_each(|index| bits[vec_len - 1].clear_bit(index)); + } + + Self { bits, nbits } + } + + /// Return the total number of bits. + pub fn len(&self) -> usize { + self.nbits + } + + fn check_index(&self, index: usize) { + if index >= self.len() { + panic!( + "bitmap index {} is out of range, total bits {}", + index, self.nbits, + ); + } + } + + /// Test if the given bit is set. + /// + /// Return `true` if the given bit is one bit. + /// + /// # Panics + /// + /// The `index` must be within the total number of bits. Otherwise, this method panics. + pub fn test_bit(&self, index: usize) -> bool { + self.check_index(index); + self.bits.test_bit(index as _) + } + + /// Set the given bit with one bit. + /// + /// # Panics + /// + /// The `index` must be within the total number of bits. Otherwise, this method panics. + pub fn set_bit(&mut self, index: usize) { + self.check_index(index); + self.bits.set_bit(index as _); + } + + /// Clear the given bit with zero bit. + /// + /// # Panics + /// + /// The `index` must be within the total number of bits. Otherwise, this method panics. 
+ pub fn clear_bit(&mut self, index: usize) { + self.check_index(index); + self.bits.clear_bit(index as _) + } + + /// Set the given bit with `value`. + /// + /// One bit is set for `true`, and zero bit for `false`. + /// + /// # Panics + /// + /// The `index` must be within the total number of bits. Otherwise, this method panics. + pub fn set(&mut self, index: usize, value: bool) { + if value == Self::ONE { + self.set_bit(index); + } else { + self.clear_bit(index); + } + } + + fn bits_not_in_use(&self) -> usize { + self.bits.len() * 64 - self.nbits + } + + /// Get the number of one bits in the bitmap. + pub fn count_ones(&self) -> usize { + self.bits.count_ones() as _ + } + + /// Get the number of zero bits in the bitmap. + pub fn count_zeros(&self) -> usize { + let total_zeros = self.bits.count_zeros() as usize; + total_zeros - self.bits_not_in_use() + } + + /// Find the index of the first one bit, starting from the given index (inclusively). + /// + /// Return `None` if no one bit is found. + /// + /// # Panics + /// + /// The `from` index must be within the total number of bits. Otherwise, this method panics. + pub fn first_one(&self, from: usize) -> Option<usize> { + self.check_index(from); + let first_u64_index = from / 64; + + self.bits[first_u64_index..] + .iter_ones() + .map(|index| first_u64_index * 64 + (index as usize)) + .find(|&index| index >= from) + } + + /// Find `count` indexes of the first one bits, starting from the given index (inclusively). + /// + /// Return `None` if fewer than `count` one bits are found. + /// + /// # Panics + /// + /// The `from + count` index must be within the total number of bits. Otherwise, this method panics. + pub fn first_ones(&self, from: usize, count: usize) -> Option<Vec<usize>> { + self.check_index(from + count - 1); + let first_u64_index = from / 64; + + let ones: Vec<_> = self.bits[first_u64_index..] + .iter_ones() + .map(|index| first_u64_index * 64 + (index as usize)) + .filter(|&index| index >= from) + .take(count) + .collect(); + if ones.len() == count { + Some(ones) + } else { + None + } + } + + /// Find the index of the last one bit. + /// + /// Return `None` if no one bit is found. + pub fn last_one(&self) -> Option<usize> { + self.bits + .iter_ones() + .rev() + .map(|index| index as usize) + .next() + } + + /// Find the index of the first zero bit, starting from the given index (inclusively). + /// + /// Return `None` if no zero bit is found. + /// + /// # Panics + /// + /// The `from` index must be within the total number of bits. Otherwise, this method panics. + pub fn first_zero(&self, from: usize) -> Option<usize> { + self.check_index(from); + let first_u64_index = from / 64; + + self.bits[first_u64_index..] + .iter_zeros() + .map(|index| first_u64_index * 64 + (index as usize)) + .find(|&index| index >= from && index < self.len()) + } + + /// Find `count` indexes of the first zero bits, starting from the given index (inclusively). + /// + /// Return `None` if fewer than `count` zero bits are found. + /// + /// # Panics + /// + /// The `from + count` index must be within the total number of bits. Otherwise, this method panics. + pub fn first_zeros(&self, from: usize, count: usize) -> Option<Vec<usize>> { + self.check_index(from + count - 1); + let first_u64_index = from / 64; + + let zeros: Vec<_> = self.bits[first_u64_index..] 
+            .iter_zeros()
+            .map(|index| first_u64_index * 64 + (index as usize))
+            .filter(|&index| index >= from && index < self.len())
+            .take(count)
+            .collect();
+        if zeros.len() == count {
+            Some(zeros)
+        } else {
+            None
+        }
+    }
+
+    /// Find the index of the last zero bit.
+    ///
+    /// Return `None` if no zero bit is found.
+    pub fn last_zero(&self) -> Option<usize> {
+        self.bits
+            .iter_zeros()
+            .rev()
+            .skip(self.bits_not_in_use())
+            .map(|index| index as usize)
+            .next()
+    }
+}
+
+impl Index<usize> for BitMap {
+    type Output = bool;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        if self.test_bit(index) {
+            &BitMap::ONE
+        } else {
+            &BitMap::ZERO
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::BitMap;
+
+    #[test]
+    fn all_true() {
+        let bm = BitMap::repeat(true, 100);
+        assert_eq!(bm.len(), 100);
+        assert_eq!(bm.count_ones(), 100);
+        assert_eq!(bm.count_zeros(), 0);
+    }
+
+    #[test]
+    fn all_false() {
+        let bm = BitMap::repeat(false, 100);
+        assert_eq!(bm.len(), 100);
+        assert_eq!(bm.count_ones(), 0);
+        assert_eq!(bm.count_zeros(), 100);
+    }
+
+    #[test]
+    fn bit_ops() {
+        let mut bm = BitMap::repeat(false, 100);
+
+        assert_eq!(bm.count_ones(), 0);
+
+        bm.set_bit(32);
+        assert_eq!(bm.count_ones(), 1);
+        assert_eq!(bm.test_bit(32), true);
+
+        bm.set(64, true);
+        assert_eq!(bm.count_ones(), 2);
+        assert_eq!(bm.test_bit(64), true);
+
+        bm.clear_bit(32);
+        assert_eq!(bm.count_ones(), 1);
+        assert_eq!(bm.test_bit(32), false);
+
+        bm.set(64, false);
+        assert_eq!(bm.count_ones(), 0);
+        assert_eq!(bm.test_bit(64), false);
+    }
+
+    #[test]
+    fn find_first_last() {
+        let mut bm = BitMap::repeat(false, 100);
+        bm.set_bit(64);
+        assert_eq!(bm.first_one(0), Some(64));
+        assert_eq!(bm.first_one(64), Some(64));
+        assert_eq!(bm.first_one(65), None);
+        assert_eq!(bm.first_ones(0, 1), Some(vec![64]));
+        assert_eq!(bm.first_ones(0, 2), None);
+        assert_eq!(bm.last_one(), Some(64));
+
+        let mut bm = BitMap::repeat(true, 100);
+        bm.clear_bit(64);
+        assert_eq!(bm.first_zero(0), Some(64));
+        assert_eq!(bm.first_zero(64), Some(64));
+        assert_eq!(bm.first_zero(65), None);
+        assert_eq!(bm.first_zeros(0, 1), Some(vec![64]));
+        assert_eq!(bm.first_zeros(0, 2), None);
+        assert_eq!(bm.last_zero(), Some(64));
+    }
+}
diff --git a/kernel/comps/mlsdisk/src/util/crypto.rs b/kernel/comps/mlsdisk/src/util/crypto.rs
new file mode 100644
index 00000000..e17d1b1c
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/util/crypto.rs
@@ -0,0 +1,89 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use core::ops::Deref;
+
+use crate::prelude::Result;
+
+/// Random initialization for Key, Iv and Mac.
+pub trait RandomInit: Default {
+    fn random() -> Self;
+}
+
+/// Authenticated Encryption with Associated Data (AEAD) algorithm.
+pub trait Aead {
+    type Key: Deref<Target = [u8]> + RandomInit;
+    type Iv: Deref<Target = [u8]> + RandomInit;
+    type Mac: Deref<Target = [u8]> + RandomInit;
+
+    /// Encrypt plaintext referred to by `input`, with a secret `Key`,
+    /// initialization vector `Iv` and additional associated data `aad`.
+    ///
+    /// If the operation succeeds, the ciphertext will be written to `output`
+    /// and a message authentication code `Mac` will be returned. Otherwise,
+    /// an `Error` is returned on any fault.
+    fn encrypt(
+        &self,
+        input: &[u8],
+        key: &Self::Key,
+        iv: &Self::Iv,
+        aad: &[u8],
+        output: &mut [u8],
+    ) -> Result<Self::Mac>;
+
+    /// Decrypt ciphertext referred to by `input`, with a secret `Key` and
+    /// message authentication code `Mac`, initialization vector `Iv` and
+    /// additional associated data `aad`.
+    ///
+    /// If the operation succeeds, the plaintext will be written to `output`.
+    /// Otherwise, an `Error` is returned on any fault.
+    fn decrypt(
+        &self,
+        input: &[u8],
+        key: &Self::Key,
+        iv: &Self::Iv,
+        aad: &[u8],
+        mac: &Self::Mac,
+        output: &mut [u8],
+    ) -> Result<()>;
+}
+
+/// Symmetric key cipher algorithm.
+pub trait Skcipher {
+    type Key: Deref<Target = [u8]> + RandomInit;
+    type Iv: Deref<Target = [u8]> + RandomInit;
+
+    /// Encrypt plaintext referred to by `input`, with a secret `Key` and
+    /// initialization vector `Iv`.
+    ///
+    /// If the operation succeeds, the ciphertext will be written to `output`.
+    /// Otherwise, an `Error` is returned on any fault.
+    fn encrypt(
+        &self,
+        input: &[u8],
+        key: &Self::Key,
+        iv: &Self::Iv,
+        output: &mut [u8],
+    ) -> Result<()>;
+
+    /// Decrypt ciphertext referred to by `input` with a secret `Key` and
+    /// initialization vector `Iv`.
+    ///
+    /// If the operation succeeds, the plaintext will be written to `output`.
+    /// Otherwise, an `Error` is returned on any fault.
+    fn decrypt(
+        &self,
+        input: &[u8],
+        key: &Self::Key,
+        iv: &Self::Iv,
+        output: &mut [u8],
+    ) -> Result<()>;
+}
+
+/// Random number generator.
+pub trait Rng {
+    /// Create an instance, with `seed` to provide secure entropy.
+    fn new(seed: &[u8]) -> Self;
+
+    /// Fill `dest` with random bytes.
+    fn fill_bytes(&self, dest: &mut [u8]) -> Result<()>;
+}
diff --git a/kernel/comps/mlsdisk/src/util/lazy_delete.rs b/kernel/comps/mlsdisk/src/util/lazy_delete.rs
new file mode 100644
index 00000000..79a0c194
--- /dev/null
+++ b/kernel/comps/mlsdisk/src/util/lazy_delete.rs
@@ -0,0 +1,105 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use core::{
+    fmt,
+    ops::{Deref, DerefMut},
+    sync::atomic::{AtomicBool, Ordering},
+};
+
+use crate::prelude::*;
+
+/// An object that may be deleted lazily.
+///
+/// Lazy-deletion is a technique to postpone the real deletion of an object.
+/// This technique allows an object to remain usable even after a decision
+/// to delete the object has been made. Of course, after the "real" deletion
+/// is carried out, the object will no longer be usable.
+///
+/// A classic example is file deletion in UNIX file systems.
+///
+/// ```ignore
+/// int fd = open("path/to/my_file", O_RDONLY);
+/// unlink("path/to/my_file");
+/// // fd is still valid after unlink
+/// ```
+///
+/// `LazyDelete<T>` enables lazy deletion of any object of `T`.
+/// Here is a simple example.
+///
+/// ```
+/// use sworndisk_v2::lazy_delete::*;
+///
+/// let lazy_delete_u32 = LazyDelete::new(123_u32, |obj| {
+///     println!("the real deletion happens in this closure");
+/// });
+///
+/// // The object is still usable after it is deleted (lazily)
+/// LazyDelete::delete(&lazy_delete_u32);
+/// assert!(*lazy_delete_u32 == 123);
+///
+/// // The deletion operation will be carried out when it is dropped
+/// drop(lazy_delete_u32);
+/// ```
+#[allow(clippy::type_complexity)]
+pub struct LazyDelete<T> {
+    obj: T,
+    is_deleted: AtomicBool,
+    delete_fn: Option<Box<dyn FnOnce(&mut T) + Send + Sync>>,
+}
+
+impl<T: fmt::Debug> fmt::Debug for LazyDelete<T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("LazyDelete")
+            .field("obj", &self.obj)
+            .field("is_deleted", &Self::is_deleted(self))
+            .finish()
+    }
+}
+
+impl<T> LazyDelete<T> {
+    /// Creates a new instance of `LazyDelete`.
+    ///
+    /// The `delete_fn` will be called only if this instance of `LazyDelete` is
+    /// marked deleted by the `delete` method and only when this instance
+    /// of `LazyDelete` is dropped.
+ pub fn new<F: FnOnce(&mut T) + Send + Sync + 'static>(obj: T, delete_fn: F) -> Self { + Self { + obj, + is_deleted: AtomicBool::new(false), + delete_fn: Some(Box::new(delete_fn) as _), + } + } + + /// Mark this instance deleted. + pub fn delete(this: &Self) { + this.is_deleted.store(true, Ordering::Release); + } + + /// Returns whether this instance has been marked deleted. + pub fn is_deleted(this: &Self) -> bool { + this.is_deleted.load(Ordering::Acquire) + } +} + +impl<T> Deref for LazyDelete<T> { + type Target = T; + + fn deref(&self) -> &T { + &self.obj + } +} + +impl<T> DerefMut for LazyDelete<T> { + fn deref_mut(&mut self) -> &mut T { + &mut self.obj + } +} + +impl<T> Drop for LazyDelete<T> { + fn drop(&mut self) { + if Self::is_deleted(self) { + let delete_fn = self.delete_fn.take().unwrap(); + (delete_fn)(&mut self.obj); + } + } +} diff --git a/kernel/comps/mlsdisk/src/util/mod.rs b/kernel/comps/mlsdisk/src/util/mod.rs new file mode 100644 index 00000000..9c1a50ce --- /dev/null +++ b/kernel/comps/mlsdisk/src/util/mod.rs @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Utilities. +mod bitmap; +mod crypto; +mod lazy_delete; + +pub use self::{ + bitmap::BitMap, + crypto::{Aead, RandomInit, Rng, Skcipher}, + lazy_delete::LazyDelete, +}; + +/// Aligns `x` up to the next multiple of `align`. +pub(crate) const fn align_up(x: usize, align: usize) -> usize { + x.div_ceil(align) * align +} + +/// Aligns `x` down to the previous multiple of `align`. +pub(crate) const fn align_down(x: usize, align: usize) -> usize { + (x / align) * align +}
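
To make the relationship between these utilities concrete, here is a minimal usage sketch written as a `#[cfg(test)]` module that could sit at the bottom of `util/mod.rs`. It is illustrative only and not part of the patch: the test name and the constants are assumptions; it relies solely on the APIs shown above (`BitMap`, `LazyDelete`, and the `pub(crate)` helpers `align_up`/`align_down`, all reachable via `super` from a child module).

#[cfg(test)]
mod tests {
    use core::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    #[test]
    fn util_walkthrough() {
        // BitMap as a free-slot map: find the first free slot and claim it.
        let mut free_map = BitMap::repeat(true, 128); // all 128 slots free
        let slot = free_map.first_one(0).expect("no free slot");
        free_map.clear_bit(slot); // mark the slot as allocated
        assert_eq!(free_map.count_zeros(), 1);

        // LazyDelete: the wrapped object stays usable after `delete`;
        // the deletion closure only runs when the wrapper is dropped.
        static FREED: AtomicUsize = AtomicUsize::new(0);
        let pending = LazyDelete::new(slot, |_slot| {
            FREED.fetch_add(1, Ordering::Relaxed);
        });
        LazyDelete::delete(&pending);
        assert_eq!(*pending, slot); // still readable after (lazy) deletion
        drop(pending); // the "real" deletion happens here
        assert_eq!(FREED.load(Ordering::Relaxed), 1);

        // Block-size rounding with the alignment helpers.
        assert_eq!(align_up(4097, 4096), 2 * 4096);
        assert_eq!(align_down(4097, 4096), 4096);
    }
}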