Rewrite cpio-decoder with Read trait as the input parameter

This commit is contained in:
LI Qing 2023-07-04 18:06:49 +08:00 committed by Tate, Hongliang Tian
parent 6d621dc4ef
commit 4c83ff9411
3 changed files with 140 additions and 100 deletions

View File

@ -6,4 +6,5 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
int-to-c-enum = { path = "../../libs/int-to-c-enum" }
int-to-c-enum = { path = "../../libs/int-to-c-enum" }
core2 = { version = "0.4", default_features = false, features = ["alloc"] }

View File

@ -5,7 +5,8 @@
//! ```rust
//! use cpio_decoder::CpioDecoder;
//!
//! let decoder = CpioDecoder::new(&[]);
//! let short_buffer: Vec<u8> = Vec::new();
//! let mut decoder = CpioDecoder::new(short_buffer.as_slice());
//! for entry_result in decoder.decode_entries() {
//! println!("The entry_result is: {:?}", entry_result);
//! }
@ -13,8 +14,15 @@
#![cfg_attr(not(test), no_std)]
#![forbid(unsafe_code)]
#![allow(dead_code)]
extern crate alloc;
use crate::error::{Error, Result};
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use core2::io::Read;
use int_to_c_enum::TryFromInt;
pub mod error;
@ -32,62 +40,64 @@ mod test;
///
/// All the fields in the header are ISO 646 (approximately ASCII) strings
/// of hexadecimal numbers, left padded, not NULL terminated.
pub struct CpioDecoder<'a> {
buffer: &'a [u8],
pub struct CpioDecoder<R> {
inner: R,
}
impl<'a> CpioDecoder<'a> {
impl<R> CpioDecoder<R>
where
R: Read,
{
/// create a decoder to decode the CPIO.
pub fn new(buffer: &'a [u8]) -> Self {
Self { buffer }
pub fn new(inner: R) -> Self {
Self { inner }
}
/// Return an iterator trying to decode the entries in the CPIO.
pub fn decode_entries(&'a self) -> CpioEntryIter<'a> {
CpioEntryIter::new(self)
pub fn decode_entries(&mut self) -> CpioEntryIter<R> {
CpioEntryIter::new(&mut self.inner)
}
}
/// An iterator over the results of CPIO entries.
///
/// It stops if reaches to the trailer entry or encounters an error.
pub struct CpioEntryIter<'a> {
buffer: &'a [u8],
offset: usize,
pub struct CpioEntryIter<'a, R> {
reader: &'a mut R,
is_error: bool,
}
impl<'a> CpioEntryIter<'a> {
fn new(decoder: &'a CpioDecoder) -> Self {
impl<'a, R> CpioEntryIter<'a, R>
where
R: Read,
{
fn new(reader: &'a mut R) -> Self {
Self {
buffer: decoder.buffer,
offset: 0,
reader,
is_error: false,
}
}
}
impl<'a> Iterator for CpioEntryIter<'a> {
type Item = Result<CpioEntry<'a>>;
impl<'a, R> Iterator for CpioEntryIter<'a, R>
where
R: Read,
{
type Item = Result<CpioEntry>;
fn next(&mut self) -> Option<Result<CpioEntry<'a>>> {
fn next(&mut self) -> Option<Result<CpioEntry>> {
// Stop to iterate entries if encounters an error.
if self.is_error {
return None;
}
let entry_result = if self.offset >= self.buffer.len() {
Err(Error::BufferShortError)
} else {
CpioEntry::new(&self.buffer[self.offset..])
};
let entry_result = CpioEntry::new(self.reader);
match &entry_result {
Ok(entry) => {
// A correct CPIO buffer must end with a trailer.
if entry.is_trailer() {
return None;
}
self.offset += entry.archive_offset();
}
Err(_) => {
self.is_error = true;
@ -99,43 +109,59 @@ impl<'a> Iterator for CpioEntryIter<'a> {
/// A file entry in the CPIO.
#[derive(Debug)]
pub struct CpioEntry<'a> {
pub struct CpioEntry {
metadata: FileMetadata,
name: &'a str,
data: &'a [u8],
name: String,
data: Vec<u8>,
}
impl<'a> CpioEntry<'a> {
fn new(bytes: &'a [u8]) -> Result<Self> {
impl CpioEntry {
fn new<R>(reader: &mut R) -> Result<Self>
where
R: Read,
{
let (metadata, name, data) = {
let header = Header::new(bytes)?;
let header = Header::new(reader)?;
let name = {
let bytes_remain = &bytes[HEADER_LEN..];
let name_size = read_hex_bytes_to_u32(header.name_size)? as usize;
if bytes_remain.len() < name_size {
return Err(Error::BufferShortError);
}
let name = core::ffi::CStr::from_bytes_with_nul(&bytes_remain[..name_size])
let name_size = read_hex_bytes_to_u32(&header.name_size)? as usize;
let mut name_bytes = vec![0u8; name_size];
reader
.read_exact(&mut name_bytes)
.map_err(|_| Error::BufferShortError)?;
let name = core::ffi::CStr::from_bytes_with_nul(&name_bytes)
.map_err(|_| Error::FileNameError)?;
name.to_str().map_err(|_| Error::Utf8Error)?
name.to_str().map_err(|_| Error::Utf8Error)?.to_string()
};
let metadata = if name == TRAILER_NAME {
Default::default()
} else {
FileMetadata::new(header)?
FileMetadata::new(&header)?
};
let data = {
let data_size = metadata.size as usize;
if data_size == 0 {
&[]
} else {
let data_offset = align_up(HEADER_LEN + name.len() + 1, 4);
if data_offset + data_size > bytes.len() {
return Err(Error::BufferShortError);
}
&bytes[data_offset..data_offset + data_size]
let pad_header_len = align_up_pad(header.len() + name.len() + 1, 4);
if pad_header_len > 0 {
let mut pad_buf = vec![0u8; pad_header_len];
reader
.read_exact(&mut pad_buf)
.map_err(|_| Error::BufferShortError)?;
}
let mut data: Vec<u8> = Vec::new();
let data_size = metadata.size as usize;
if data_size > 0 {
data.resize_with(data_size, Default::default);
reader
.read_exact(&mut data)
.map_err(|_| Error::BufferShortError)?;
}
data
};
let pad_data_len = align_up_pad(data.len(), 4);
if pad_data_len > 0 {
let mut pad_buf = vec![0u8; pad_data_len];
reader
.read_exact(&mut pad_buf)
.map_err(|_| Error::BufferShortError)?;
}
(metadata, name, data)
};
@ -153,20 +179,16 @@ impl<'a> CpioEntry<'a> {
/// The name of the file.
pub fn name(&self) -> &str {
self.name
&self.name
}
/// The data of the file.
pub fn data(&self) -> &[u8] {
self.data
&self.data
}
fn is_trailer(&self) -> bool {
self.name == TRAILER_NAME
}
fn archive_offset(&self) -> usize {
align_up(HEADER_LEN + self.name.len() + 1, 4) + align_up(self.data.len(), 4)
&self.name == TRAILER_NAME
}
}
@ -188,7 +210,7 @@ pub struct FileMetadata {
}
impl FileMetadata {
fn new(header: Header) -> Result<Self> {
fn new(header: &Header) -> Result<Self> {
const MODE_MASK: u32 = 0o7777;
const TYPE_MASK: u32 = 0o170000;
let raw_mode = read_hex_bytes_to_u32(&header.mode)?;
@ -296,51 +318,60 @@ impl Default for FileType {
}
}
const HEADER_LEN: usize = 110;
const MAGIC: &[u8] = b"070701";
const TRAILER_NAME: &str = "TRAILER!!!";
#[rustfmt::skip]
struct Header<'a> {
// magic: &'a [u8], // [u8; 6]
ino: &'a [u8], // [u8; 8]
mode: &'a [u8], // [u8; 8]
uid: &'a [u8], // [u8; 8]
gid: &'a [u8], // [u8; 8]
nlink: &'a [u8], // [u8; 8]
mtime: &'a [u8], // [u8; 8]
file_size: &'a [u8], // [u8; 8]
dev_maj: &'a [u8], // [u8; 8]
dev_min: &'a [u8], // [u8; 8]
rdev_maj: &'a [u8], // [u8; 8]
rdev_min: &'a [u8], // [u8; 8]
name_size: &'a [u8], // [u8; 8]
// chksum: &'a [u8], // [u8; 8]
struct Header {
magic: [u8; 6],
ino: [u8; 8],
mode: [u8; 8],
uid: [u8; 8],
gid: [u8; 8],
nlink: [u8; 8],
mtime: [u8; 8],
file_size: [u8; 8],
dev_maj: [u8; 8],
dev_min: [u8; 8],
rdev_maj: [u8; 8],
rdev_min: [u8; 8],
name_size: [u8; 8],
chksum: [u8; 8],
}
impl<'a> Header<'a> {
pub fn new(bytes: &'a [u8]) -> Result<Self> {
if bytes.len() < HEADER_LEN {
return Err(Error::BufferShortError);
}
let magic = &bytes[..6];
if magic != MAGIC {
impl Header {
pub fn new<R>(reader: &mut R) -> Result<Self>
where
R: Read,
{
let mut buf = vec![0u8; core::mem::size_of::<Self>()];
reader
.read_exact(&mut buf)
.map_err(|_| Error::BufferShortError)?;
let header = Self {
magic: <[u8; 6]>::try_from(&buf[0..6]).unwrap(),
ino: <[u8; 8]>::try_from(&buf[6..14]).unwrap(),
mode: <[u8; 8]>::try_from(&buf[14..22]).unwrap(),
uid: <[u8; 8]>::try_from(&buf[22..30]).unwrap(),
gid: <[u8; 8]>::try_from(&buf[30..38]).unwrap(),
nlink: <[u8; 8]>::try_from(&buf[38..46]).unwrap(),
mtime: <[u8; 8]>::try_from(&buf[46..54]).unwrap(),
file_size: <[u8; 8]>::try_from(&buf[54..62]).unwrap(),
dev_maj: <[u8; 8]>::try_from(&buf[62..70]).unwrap(),
dev_min: <[u8; 8]>::try_from(&buf[70..78]).unwrap(),
rdev_maj: <[u8; 8]>::try_from(&buf[78..86]).unwrap(),
rdev_min: <[u8; 8]>::try_from(&buf[86..94]).unwrap(),
name_size: <[u8; 8]>::try_from(&buf[94..102]).unwrap(),
chksum: <[u8; 8]>::try_from(&buf[102..110]).unwrap(),
};
if header.magic != MAGIC {
return Err(Error::MagicError);
}
Ok(Self {
ino: &bytes[6..14],
mode: &bytes[14..22],
uid: &bytes[22..30],
gid: &bytes[30..38],
nlink: &bytes[38..46],
mtime: &bytes[46..54],
file_size: &bytes[54..62],
dev_maj: &bytes[62..70],
dev_min: &bytes[70..78],
rdev_maj: &bytes[78..86],
rdev_min: &bytes[86..94],
name_size: &bytes[94..102],
})
Ok(header)
}
fn len(&self) -> usize {
core::mem::size_of::<Self>()
}
}
@ -351,6 +382,10 @@ fn read_hex_bytes_to_u32(bytes: &[u8]) -> Result<u32> {
Ok(num)
}
fn align_up_pad(size: usize, align: usize) -> usize {
align_up(size, align) - size
}
fn align_up(size: usize, align: usize) -> usize {
debug_assert!(align >= 2 && align.is_power_of_two());
(size + align - 1) & !(align - 1)

View File

@ -24,10 +24,9 @@ fn test_decoder() {
output.stdout
};
let decoder = CpioDecoder::new(&buffer);
assert!(decoder.decode_entries().count() > 3);
assert!(CpioDecoder::new(buffer.as_slice()).decode_entries().count() > 3);
let mut decoder = CpioDecoder::new(buffer.as_slice());
for (idx, entry_result) in decoder.decode_entries().enumerate() {
assert!(entry_result.is_ok());
let entry = entry_result.unwrap();
if idx == 0 {
assert!(entry.name() == ".");
@ -40,7 +39,11 @@ fn test_decoder() {
assert!(entry.metadata().ino() > 0);
}
if idx == 2 {
assert!(entry.name() == "src/lib.rs");
assert!(
entry.name() == "src/lib.rs"
|| entry.name() == "src/test.rs"
|| entry.name() == "src/error.rs"
);
assert!(entry.metadata().file_type() == FileType::File);
assert!(entry.metadata().ino() > 0);
}
@ -49,7 +52,8 @@ fn test_decoder() {
#[test]
fn test_short_buffer() {
let decoder = CpioDecoder::new(&[]);
let short_buffer: Vec<u8> = Vec::new();
let mut decoder = CpioDecoder::new(short_buffer.as_slice());
for entry_result in decoder.decode_entries() {
assert!(entry_result.is_err());
assert!(entry_result.err() == Some(Error::BufferShortError));
@ -59,7 +63,7 @@ fn test_short_buffer() {
#[test]
fn test_invalid_buffer() {
let buffer: &[u8] = b"invalidmagic.invalidmagic.invalidmagic.invalidmagic.invalidmagic.invalidmagic.invalidmagic.invalidmagic.invalidmagic.invalidmagic";
let decoder = CpioDecoder::new(buffer);
let mut decoder = CpioDecoder::new(buffer);
for entry_result in decoder.decode_entries() {
assert!(entry_result.is_err());
assert!(entry_result.err() == Some(Error::MagicError));