Skip to content

Commit

Permalink
Engine::internal_decode now returns DecodeSliceError
Browse files Browse the repository at this point in the history
Implementations must now precisely, not conservatively, return an error when the output length is too small.
  • Loading branch information
marshallpierce committed Mar 2, 2024
1 parent a8a60f4 commit 9e9c7ab
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 159 deletions.
3 changes: 1 addition & 2 deletions benches/benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ fn do_encode_bench_slice(b: &mut Bencher, &size: &usize) {
fn do_encode_bench_stream(b: &mut Bencher, &size: &usize) {
let mut v: Vec<u8> = Vec::with_capacity(size);
fill(&mut v);
let mut buf = Vec::new();
let mut buf = Vec::with_capacity(size * 2);

buf.reserve(size * 2);
b.iter(|| {
buf.clear();
let mut stream_enc = write::EncoderWriter::new(&mut buf, &STANDARD);
Expand Down
4 changes: 1 addition & 3 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ impl error::Error for DecodeError {}
pub enum DecodeSliceError {
/// A [DecodeError] occurred
DecodeError(DecodeError),
/// The provided slice _may_ be too small.
///
/// The check is conservative (assumes the last triplet of output bytes will all be needed).
/// The provided slice is too small.
OutputSliceTooSmall,
}

Expand Down
75 changes: 51 additions & 24 deletions src/engine/general_purpose/decode.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
use crate::{
engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
DecodeError, PAD_BYTE,
DecodeError, DecodeSliceError, PAD_BYTE,
};

#[doc(hidden)]
pub struct GeneralPurposeEstimate {
/// input len % 4
rem: usize,
conservative_len: usize,
conservative_decoded_len: usize,
}

impl GeneralPurposeEstimate {
pub(crate) fn new(encoded_len: usize) -> Self {
let rem = encoded_len % 4;
Self {
rem,
conservative_len: (encoded_len / 4 + (rem > 0) as usize) * 3,
conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3,
}
}
}

impl DecodeEstimate for GeneralPurposeEstimate {
fn decoded_len_estimate(&self) -> usize {
self.conservative_len
self.conservative_decoded_len
}
}

Expand All @@ -38,25 +39,9 @@ pub(crate) fn decode_helper(
decode_table: &[u8; 256],
decode_allow_trailing_bits: bool,
padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeError> {
// detect a trailing invalid byte, like a newline, as a user convenience
if estimate.rem == 1 {
let last_byte = input[input.len() - 1];
// exclude pad bytes; might be part of padding that extends from earlier in the input
if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
return Err(DecodeError::InvalidByte(input.len() - 1, last_byte));
}
}

// skip last quad, even if it's complete, as it may have padding
let input_complete_nonterminal_quads_len = input
.len()
.saturating_sub(estimate.rem)
// if rem was 0, subtract 4 to avoid padding
.saturating_sub((estimate.rem == 0) as usize * 4);
debug_assert!(
input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
);
) -> Result<DecodeMetadata, DecodeSliceError> {
let input_complete_nonterminal_quads_len =
complete_quads_len(input, estimate.rem, output.len(), decode_table)?;

const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;
Expand Down Expand Up @@ -135,6 +120,48 @@ pub(crate) fn decode_helper(
)
}

/// Returns the length of complete quads, except for the last one, even if it is complete.
///
/// Returns an error if the output len is not big enough for decoding those complete quads, or if
/// the input % 4 == 1, and that last byte is an invalid value other than a pad byte.
///
/// - `input` is the base64 input
/// - `input_len_rem` is input len % 4
/// - `output_len` is the length of the output slice
pub(crate) fn complete_quads_len(
input: &[u8],
input_len_rem: usize,
output_len: usize,
decode_table: &[u8; 256],
) -> Result<usize, DecodeSliceError> {
debug_assert!(input.len() % 4 == input_len_rem);

// detect a trailing invalid byte, like a newline, as a user convenience
if input_len_rem == 1 {
let last_byte = input[input.len() - 1];
// exclude pad bytes; might be part of padding that extends from earlier in the input
if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into());
}
};

// skip last quad, even if it's complete, as it may have padding
let input_complete_nonterminal_quads_len = input
.len()
.saturating_sub(input_len_rem)
// if rem was 0, subtract 4 to avoid padding
.saturating_sub((input_len_rem == 0) as usize * 4);
debug_assert!(
input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
);

// check that everything except the last quad handled by decode_suffix will fit
if output_len < input_complete_nonterminal_quads_len / 4 * 3 {
return Err(DecodeSliceError::OutputSliceTooSmall);
};
Ok(input_complete_nonterminal_quads_len)
}

/// Decode 8 bytes of input into 6 bytes of output.
///
/// `input` is the 8 bytes to decode.
Expand Down Expand Up @@ -321,7 +348,7 @@ mod tests {
let len_128 = encoded_len as u128;

let estimate = GeneralPurposeEstimate::new(encoded_len);
assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_len as u128);
assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_decoded_len as u128);
})
}
}
59 changes: 25 additions & 34 deletions src/engine/general_purpose/decode_suffix.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode},
DecodeError, PAD_BYTE,
DecodeError, DecodeSliceError, PAD_BYTE,
};

/// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided
Expand All @@ -16,11 +16,11 @@ pub(crate) fn decode_suffix(
decode_table: &[u8; 256],
decode_allow_trailing_bits: bool,
padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeError> {
) -> Result<DecodeMetadata, DecodeSliceError> {
debug_assert!((input.len() - input_index) <= 4);

// Decode any leftovers that might not be a complete input chunk of 8 bytes.
// Use a u64 as a stack-resident 8 byte buffer.
// Decode any leftovers that might not be a complete input chunk of 4 bytes.
// Use a u32 as a stack-resident 4 byte buffer.
let mut morsels_in_leftover = 0;
let mut padding_bytes_count = 0;
// offset from input_index
Expand All @@ -44,22 +44,14 @@ pub(crate) fn decode_suffix(
// may be treated as an error condition.

if leftover_index < 2 {
// Check for case #2.
let bad_padding_index = input_index
+ if padding_bytes_count > 0 {
// If we've already seen padding, report the first padding index.
// This is to be consistent with the normal decode logic: it will report an
// error on the first padding character (since it doesn't expect to see
// anything but actual encoded data).
// This could only happen if the padding started in the previous quad since
// otherwise this case would have been hit at i == 4 if it was the same
// quad.
first_padding_offset
} else {
// haven't seen padding before, just use where we are now
leftover_index
};
return Err(DecodeError::InvalidByte(bad_padding_index, b));
// Check for error #2.
// Either the previous byte was padding, in which case we would have already hit
// this case, or it wasn't, in which case this is the first such error.
debug_assert!(
leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0)
);
let bad_padding_index = input_index + leftover_index;
return Err(DecodeError::InvalidByte(bad_padding_index, b).into());
}

if padding_bytes_count == 0 {
Expand All @@ -75,10 +67,9 @@ pub(crate) fn decode_suffix(
// non-suffix '=' in trailing chunk either. Report error as first
// erroneous padding.
if padding_bytes_count > 0 {
return Err(DecodeError::InvalidByte(
input_index + first_padding_offset,
PAD_BYTE,
));
return Err(
DecodeError::InvalidByte(input_index + first_padding_offset, PAD_BYTE).into(),
);
}

last_symbol = b;
Expand All @@ -87,7 +78,7 @@ pub(crate) fn decode_suffix(
// Pack the leftovers from left to right.
let morsel = decode_table[b as usize];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(input_index + leftover_index, b));
return Err(DecodeError::InvalidByte(input_index + leftover_index, b).into());
}

morsels[morsels_in_leftover] = morsel;
Expand All @@ -97,24 +88,22 @@ pub(crate) fn decode_suffix(
// If there was 1 trailing byte, and it was valid, and we got to this point without hitting
// an invalid byte, now we can report invalid length
if !input.is_empty() && morsels_in_leftover < 2 {
return Err(DecodeError::InvalidLength(
input_index + morsels_in_leftover,
));
return Err(DecodeError::InvalidLength(input_index + morsels_in_leftover).into());
}

match padding_mode {
DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ }
DecodePaddingMode::RequireCanonical => {
// allow empty input
if (padding_bytes_count + morsels_in_leftover) % 4 != 0 {
return Err(DecodeError::InvalidPadding);
return Err(DecodeError::InvalidPadding.into());
}
}
DecodePaddingMode::RequireNone => {
if padding_bytes_count > 0 {
// check at the end to make sure we let the cases of padding that should be InvalidByte
// get hit
return Err(DecodeError::InvalidPadding);
return Err(DecodeError::InvalidPadding.into());
}
}
}
Expand All @@ -127,7 +116,7 @@ pub(crate) fn decode_suffix(
// bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a
// mask based on how many bits are used for just the canonical encoding, and optionally
// error if any other bits are set. In the example of one encoded byte -> 2 symbols,
// 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and
// 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and
// useless since there are no more symbols to provide the necessary 4 additional bits
// to finish the second original byte.

Expand All @@ -147,16 +136,18 @@ pub(crate) fn decode_suffix(
return Err(DecodeError::InvalidLastSymbol(
input_index + morsels_in_leftover - 1,
last_symbol,
));
)
.into());
}

// Strangely, this approach benchmarks better than writing bytes one at a time,
// or copy_from_slice into output.
for _ in 0..leftover_bytes_to_append {
let hi_byte = (leftover_num >> 24) as u8;
leftover_num <<= 8;
// TODO use checked writes
output[output_index] = hi_byte;
*output
.get_mut(output_index)
.ok_or(DecodeSliceError::OutputSliceTooSmall)? = hi_byte;
output_index += 1;
}

Expand Down
6 changes: 3 additions & 3 deletions src/engine/general_purpose/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use crate::{
alphabet,
alphabet::Alphabet,
engine::{Config, DecodeMetadata, DecodePaddingMode},
DecodeError,
DecodeSliceError,
};
use core::convert::TryInto;

mod decode;
pub(crate) mod decode;
pub(crate) mod decode_suffix;

pub use decode::GeneralPurposeEstimate;
Expand Down Expand Up @@ -173,7 +173,7 @@ impl super::Engine for GeneralPurpose {
input: &[u8],
output: &mut [u8],
estimate: Self::DecodeEstimate,
) -> Result<DecodeMetadata, DecodeError> {
) -> Result<DecodeMetadata, DecodeSliceError> {
decode::decode_helper(
input,
estimate,
Expand Down
41 changes: 26 additions & 15 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,13 @@ pub trait Engine: Send + Sync {
///
/// Non-canonical trailing bits in the final tokens or non-canonical padding must be reported as
/// errors unless the engine is configured otherwise.
///
/// # Panics
///
/// Panics if `output` is too small.
#[doc(hidden)]
fn internal_decode(
&self,
input: &[u8],
output: &mut [u8],
decode_estimate: Self::DecodeEstimate,
) -> Result<DecodeMetadata, DecodeError>;
) -> Result<DecodeMetadata, DecodeSliceError>;

/// Returns the config for this engine.
fn config(&self) -> &Self::Config;
Expand Down Expand Up @@ -253,7 +249,13 @@ pub trait Engine: Send + Sync {
let mut buffer = vec![0; estimate.decoded_len_estimate()];

let bytes_written = engine
.internal_decode(input_bytes, &mut buffer, estimate)?
.internal_decode(input_bytes, &mut buffer, estimate)
.map_err(|e| match e {
DecodeSliceError::DecodeError(e) => e,
DecodeSliceError::OutputSliceTooSmall => {
unreachable!("Vec is sized conservatively")
}
})?
.decoded_len;

buffer.truncate(bytes_written);
Expand Down Expand Up @@ -318,7 +320,13 @@ pub trait Engine: Send + Sync {
let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..];

let bytes_written = engine
.internal_decode(input_bytes, buffer_slice, estimate)?
.internal_decode(input_bytes, buffer_slice, estimate)
.map_err(|e| match e {
DecodeSliceError::DecodeError(e) => e,
DecodeSliceError::OutputSliceTooSmall => {
unreachable!("Vec is sized conservatively")
}
})?
.decoded_len;

buffer.truncate(starting_output_len + bytes_written);
Expand Down Expand Up @@ -354,15 +362,12 @@ pub trait Engine: Send + Sync {
where
E: Engine + ?Sized,
{
let estimate = engine.internal_decoded_len_estimate(input_bytes.len());

if output.len() < estimate.decoded_len_estimate() {
return Err(DecodeSliceError::OutputSliceTooSmall);
}

engine
.internal_decode(input_bytes, output, estimate)
.map_err(|e| e.into())
.internal_decode(
input_bytes,
output,
engine.internal_decoded_len_estimate(input_bytes.len()),
)
.map(|dm| dm.decoded_len)
}

Expand Down Expand Up @@ -400,6 +405,12 @@ pub trait Engine: Send + Sync {
engine.internal_decoded_len_estimate(input_bytes.len()),
)
.map(|dm| dm.decoded_len)
.map_err(|e| match e {
DecodeSliceError::DecodeError(e) => e,
DecodeSliceError::OutputSliceTooSmall => {
panic!("Output slice is too small")
}
})
}

inner(self, input.as_ref(), output)
Expand Down
Loading

0 comments on commit 9e9c7ab

Please sign in to comment.