From 0c7005fd93d0aaf158ec1161cbe99ab293aec2a2 Mon Sep 17 00:00:00 2001 From: Andrew Gallant Date: Wed, 26 Jun 2019 18:32:57 -0400 Subject: [PATCH] idioms: use pub(crate), use Error::source --- src/byte_record.rs | 149 +++++++++++++++++++++---------------------- src/deserializer.rs | 6 +- src/error.rs | 78 ++++++++-------------- src/reader.rs | 20 +++--- src/serializer.rs | 4 +- src/string_record.rs | 94 +++++++++++++-------------- src/writer.rs | 8 +-- 7 files changed, 159 insertions(+), 200 deletions(-) diff --git a/src/byte_record.rs b/src/byte_record.rs index d22a105..db5e511 100644 --- a/src/byte_record.rs +++ b/src/byte_record.rs @@ -12,79 +12,6 @@ use crate::deserializer::deserialize_byte_record; use crate::error::{new_utf8_error, Result, Utf8Error}; use crate::string_record::StringRecord; -/// Retrieve the underlying parts of a byte record. -#[inline] -pub fn as_parts(record: &mut ByteRecord) -> (&mut Vec, &mut Vec) { - // TODO(burntsushi): Use `pub(crate)` when it stabilizes. - // (&mut record.fields, &mut record.bounds.ends) - let inner = &mut *record.0; - (&mut inner.fields, &mut inner.bounds.ends) -} - -/// Set the number of fields in the given record record. -#[inline] -pub fn set_len(record: &mut ByteRecord, len: usize) { - // TODO(burntsushi): Use `pub(crate)` when it stabilizes. - record.0.bounds.len = len; -} - -/// Expand the capacity for storing fields. -#[inline] -pub fn expand_fields(record: &mut ByteRecord) { - // TODO(burntsushi): Use `pub(crate)` when it stabilizes. - let new_len = record.0.fields.len().checked_mul(2).unwrap(); - record.0.fields.resize(cmp::max(4, new_len), 0); -} - -/// Expand the capacity for storing field ending positions. -#[inline] -pub fn expand_ends(record: &mut ByteRecord) { - // TODO(burntsushi): Use `pub(crate)` when it stabilizes. - record.0.bounds.expand(); -} - -/// Validate the given record as UTF-8. -/// -/// If it's not UTF-8, return an error. 
-#[inline] -pub fn validate(record: &ByteRecord) -> result::Result<(), Utf8Error> { - // TODO(burntsushi): Use `pub(crate)` when it stabilizes. - - // If the entire buffer is ASCII, then we have nothing to fear. - if record.0.fields[..record.0.bounds.end()].iter().all(|&b| b <= 0x7F) { - return Ok(()); - } - // Otherwise, we must check each field individually to ensure that - // it's valid UTF-8. - for (i, field) in record.iter().enumerate() { - if let Err(err) = str::from_utf8(field) { - return Err(new_utf8_error(i, err.valid_up_to())); - } - } - Ok(()) -} - -/// Compare the given byte record with the iterator of fields for equality. -pub fn eq(record: &ByteRecord, other: I) -> bool -where - I: IntoIterator, - T: AsRef<[u8]>, -{ - let mut it_record = record.iter(); - let mut it_other = other.into_iter(); - loop { - match (it_record.next(), it_other.next()) { - (None, None) => return true, - (None, Some(_)) | (Some(_), None) => return false, - (Some(x), Some(y)) => { - if x != y.as_ref() { - return false; - } - } - } - } -} - /// A single CSV record stored as raw bytes. /// /// A byte record permits reading or writing CSV rows that are not UTF-8. 
@@ -118,25 +45,25 @@ impl PartialEq for ByteRecord { impl> PartialEq> for ByteRecord { fn eq(&self, other: &Vec) -> bool { - eq(self, other) + self.iter_eq(other) } } impl<'a, T: AsRef<[u8]>> PartialEq> for &'a ByteRecord { fn eq(&self, other: &Vec) -> bool { - eq(self, other) + self.iter_eq(other) } } impl> PartialEq<[T]> for ByteRecord { fn eq(&self, other: &[T]) -> bool { - eq(self, other) + self.iter_eq(other) } } impl<'a, T: AsRef<[u8]>> PartialEq<[T]> for &'a ByteRecord { fn eq(&self, other: &[T]) -> bool { - eq(self, other) + self.iter_eq(other) } } @@ -488,7 +415,7 @@ impl ByteRecord { pub fn push_field(&mut self, field: &[u8]) { let (s, e) = (self.0.bounds.end(), self.0.bounds.end() + field.len()); while e > self.0.fields.len() { - expand_fields(self); + self.expand_fields(); } self.0.fields[s..e].copy_from_slice(field); self.0.bounds.add(e); @@ -593,6 +520,72 @@ impl ByteRecord { pub fn as_slice(&self) -> &[u8] { &self.0.fields[..self.0.bounds.end()] } + + /// Retrieve the underlying parts of a byte record. + #[inline] + pub(crate) fn as_parts(&mut self) -> (&mut Vec, &mut Vec) { + let inner = &mut *self.0; + (&mut inner.fields, &mut inner.bounds.ends) + } + + /// Set the number of fields in the given record. + #[inline] + pub(crate) fn set_len(&mut self, len: usize) { + self.0.bounds.len = len; + } + + /// Expand the capacity for storing fields. + #[inline] + pub(crate) fn expand_fields(&mut self) { + let new_len = self.0.fields.len().checked_mul(2).unwrap(); + self.0.fields.resize(cmp::max(4, new_len), 0); + } + + /// Expand the capacity for storing field ending positions. + #[inline] + pub(crate) fn expand_ends(&mut self) { + self.0.bounds.expand(); + } + + /// Validate the given record as UTF-8. + /// + /// If it's not UTF-8, return an error. + #[inline] + pub(crate) fn validate(&self) -> result::Result<(), Utf8Error> { + // If the entire buffer is ASCII, then we have nothing to fear. 
+ if self.0.fields[..self.0.bounds.end()].iter().all(|&b| b <= 0x7F) { + return Ok(()); + } + // Otherwise, we must check each field individually to ensure that + // it's valid UTF-8. + for (i, field) in self.iter().enumerate() { + if let Err(err) = str::from_utf8(field) { + return Err(new_utf8_error(i, err.valid_up_to())); + } + } + Ok(()) + } + + /// Compare the given byte record with the iterator of fields for equality. + pub(crate) fn iter_eq(&self, other: I) -> bool + where + I: IntoIterator, + T: AsRef<[u8]>, + { + let mut it_record = self.iter(); + let mut it_other = other.into_iter(); + loop { + match (it_record.next(), it_other.next()) { + (None, None) => return true, + (None, Some(_)) | (Some(_), None) => return false, + (Some(x), Some(y)) => { + if x != y.as_ref() { + return false; + } + } + } + } + } } /// A position in CSV data. diff --git a/src/deserializer.rs b/src/deserializer.rs index 2f180bd..47f5c5f 100644 --- a/src/deserializer.rs +++ b/src/deserializer.rs @@ -12,7 +12,7 @@ use serde::de::{ }; use crate::byte_record::{ByteRecord, ByteRecordIter}; -use crate::error::{new_error, Error, ErrorKind}; +use crate::error::{Error, ErrorKind}; use crate::string_record::{StringRecord, StringRecordIter}; use self::DeserializeErrorKind as DEK; @@ -27,7 +27,7 @@ pub fn deserialize_string_record<'de, D: Deserialize<'de>>( field: 0, }); D::deserialize(&mut deser).map_err(|err| { - new_error(ErrorKind::Deserialize { + Error::new(ErrorKind::Deserialize { pos: record.position().map(Clone::clone), err: err, }) @@ -44,7 +44,7 @@ pub fn deserialize_byte_record<'de, D: Deserialize<'de>>( field: 0, }); D::deserialize(&mut deser).map_err(|err| { - new_error(ErrorKind::Deserialize { + Error::new(ErrorKind::Deserialize { pos: record.position().map(Clone::clone), err: err, }) diff --git a/src/error.rs b/src/error.rs index 426b37e..e6a3b17 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,17 +2,10 @@ use std::error::Error as StdError; use std::fmt; use std::io; use 
std::result; -use std::str; use crate::byte_record::{ByteRecord, Position}; use crate::deserializer::DeserializeError; -/// A crate private constructor for `Error`. -pub fn new_error(kind: ErrorKind) -> Error { - // TODO(burntsushi): Use `pub(crate)` when it stabilizes. - Error(Box::new(kind)) -} - /// A type alias for `Result`. pub type Result = result::Result; @@ -28,6 +21,11 @@ pub type Result = result::Result; pub struct Error(Box); impl Error { + /// A crate private constructor for `Error`. + pub(crate) fn new(kind: ErrorKind) -> Error { + Error(Box::new(kind)) + } + /// Return the specific type of this error. pub fn kind(&self) -> &ErrorKind { &self.0 @@ -103,7 +101,7 @@ pub enum ErrorKind { impl From for Error { fn from(err: io::Error) -> Error { - new_error(ErrorKind::Io(err)) + Error::new(ErrorKind::Io(err)) } } @@ -114,21 +112,7 @@ impl From for io::Error { } impl StdError for Error { - fn description(&self) -> &str { - match *self.0 { - ErrorKind::Io(ref err) => err.description(), - ErrorKind::Utf8 { ref err, .. } => err.description(), - ErrorKind::UnequalLengths { .. } => { - "record of different length found" - } - ErrorKind::Seek => "headers unavailable on seeked CSV reader", - ErrorKind::Serialize(ref err) => err, - ErrorKind::Deserialize { ref err, .. } => err.description(), - _ => unreachable!(), - } - } - - fn cause(&self) -> Option<&dyn StdError> { + fn source(&self) -> Option<&(dyn StdError + 'static)> { match *self.0 { ErrorKind::Io(ref err) => Some(err), ErrorKind::Utf8 { ref err, .. } => Some(err), @@ -218,12 +202,12 @@ pub struct FromUtf8Error { err: Utf8Error, } -/// Create a new FromUtf8Error. -pub fn new_from_utf8_error(rec: ByteRecord, err: Utf8Error) -> FromUtf8Error { - FromUtf8Error { record: rec, err: err } -} - impl FromUtf8Error { + /// Create a new FromUtf8Error. 
+ pub(crate) fn new(rec: ByteRecord, err: Utf8Error) -> FromUtf8Error { + FromUtf8Error { record: rec, err: err } + } + /// Access the underlying `ByteRecord` that failed UTF-8 validation. pub fn into_byte_record(self) -> ByteRecord { self.record @@ -242,10 +226,7 @@ impl fmt::Display for FromUtf8Error { } impl StdError for FromUtf8Error { - fn description(&self) -> &str { - self.err.description() - } - fn cause(&self) -> Option<&dyn StdError> { + fn source(&self) -> Option<&(dyn StdError + 'static)> { Some(&self.err) } } @@ -281,11 +262,7 @@ impl Utf8Error { } } -impl StdError for Utf8Error { - fn description(&self) -> &str { - "invalid utf-8 in CSV record" - } -} +impl StdError for Utf8Error {} impl fmt::Display for Utf8Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -309,15 +286,15 @@ pub struct IntoInnerError { err: io::Error, } -/// Creates a new `IntoInnerError`. -/// -/// (This is a visibility hack. It's public in this module, but not in the -/// crate.) -pub fn new_into_inner_error(wtr: W, err: io::Error) -> IntoInnerError { - IntoInnerError { wtr: wtr, err: err } -} - impl IntoInnerError { + /// Creates a new `IntoInnerError`. + /// + /// (This constructor is crate-private: it is visible within this crate, but + /// is not part of the public API.) + pub(crate) fn new(wtr: W, err: io::Error) -> IntoInnerError { + IntoInnerError { wtr: wtr, err: err } + } + /// Returns the error which caused the call to `into_inner` to fail. /// /// This error was returned when attempting to flush the internal buffer. 
@@ -334,14 +311,9 @@ impl IntoInnerError { } } -impl StdError for IntoInnerError { - fn description(&self) -> &str { - self.err.description() - } - - #[allow(deprecated)] - fn cause(&self) -> Option<&dyn StdError> { - self.err.cause() +impl StdError for IntoInnerError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.err.source() } } diff --git a/src/reader.rs b/src/reader.rs index 9697d41..dce47b4 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -7,9 +7,9 @@ use std::result; use csv_core::{Reader as CoreReader, ReaderBuilder as CoreReaderBuilder}; use serde::de::DeserializeOwned; -use crate::byte_record::{self, ByteRecord, Position}; -use crate::error::{new_error, ErrorKind, Result, Utf8Error}; -use crate::string_record::{self, StringRecord}; +use crate::byte_record::{ByteRecord, Position}; +use crate::error::{Error, ErrorKind, Result, Utf8Error}; +use crate::string_record::StringRecord; use crate::{Terminator, Trim}; /// Builds a CSV reader with various configuration knobs. @@ -1304,7 +1304,7 @@ impl Reader { let headers = self.state.headers.as_ref().unwrap(); match headers.string_record { Ok(ref record) => Ok(record), - Err(ref err) => Err(new_error(ErrorKind::Utf8 { + Err(ref err) => Err(Error::new(ErrorKind::Utf8 { pos: headers.byte_record.position().map(Clone::clone), err: err.clone(), })), @@ -1503,7 +1503,7 @@ impl Reader { /// } /// ``` pub fn read_record(&mut self, record: &mut StringRecord) -> Result { - let result = string_record::read(self, record); + let result = record.read(self); // We need to trim again because trimming string records includes // Unicode whitespace. (ByteRecord trimming only includes ASCII // whitespace.) 
@@ -1605,7 +1605,7 @@ impl Reader { loop { let (res, nin, nout, nend) = { let input = self.rdr.fill_buf()?; - let (fields, ends) = byte_record::as_parts(record); + let (fields, ends) = record.as_parts(); self.core.read_record( input, &mut fields[outlen..], @@ -1623,15 +1623,15 @@ impl Reader { match res { InputEmpty => continue, OutputFull => { - byte_record::expand_fields(record); + record.expand_fields(); continue; } OutputEndsFull => { - byte_record::expand_ends(record); + record.expand_ends(); continue; } Record => { - byte_record::set_len(record, endlen); + record.set_len(endlen); self.state.add_record(record)?; return Ok(true); } @@ -1860,7 +1860,7 @@ impl ReaderState { None => self.first_field_count = Some(record.len() as u64), Some(expected) => { if record.len() as u64 != expected { - return Err(new_error(ErrorKind::UnequalLengths { + return Err(Error::new(ErrorKind::UnequalLengths { pos: record.position().map(Clone::clone), expected_len: expected, len: record.len() as u64, diff --git a/src/serializer.rs b/src/serializer.rs index 960e7a7..ed8c367 100644 --- a/src/serializer.rs +++ b/src/serializer.rs @@ -10,7 +10,7 @@ use serde::ser::{ SerializeTupleStruct, SerializeTupleVariant, Serializer, }; -use crate::error::{new_error, Error, ErrorKind}; +use crate::error::{Error, ErrorKind}; use crate::writer::Writer; /// Serialize the given value to the given writer, and return an error if @@ -342,7 +342,7 @@ impl<'a, 'w, W: io::Write> SerializeStructVariant for &'a mut SeRecord<'w, W> { impl SerdeError for Error { fn custom(msg: T) -> Error { - new_error(ErrorKind::Serialize(msg.to_string())) + Error::new(ErrorKind::Serialize(msg.to_string())) } } diff --git a/src/string_record.rs b/src/string_record.rs index 86af88c..5cadff1 100644 --- a/src/string_record.rs +++ b/src/string_record.rs @@ -7,50 +7,11 @@ use std::str; use serde::de::Deserialize; -use crate::byte_record::{self, ByteRecord, ByteRecordIter, Position}; +use crate::byte_record::{ByteRecord, 
ByteRecordIter, Position}; use crate::deserializer::deserialize_string_record; -use crate::error::{ - new_error, new_from_utf8_error, ErrorKind, FromUtf8Error, Result, -}; +use crate::error::{Error, ErrorKind, FromUtf8Error, Result}; use crate::reader::Reader; -/// A safe function for reading CSV data into a `StringRecord`. -/// -/// This relies on the internal representation of `StringRecord`. -#[inline(always)] -pub fn read( - rdr: &mut Reader, - record: &mut StringRecord, -) -> Result { - // TODO(burntsushi): Define this as a method using `pub(crate)` when that - // stabilizes. - - // SAFETY: Note that despite the absence of `unsafe` in this function, this - // code is critical to upholding the safety of other `unsafe` blocks in - // this module. Namely, after calling `read_byte_record`, it is possible - // for `record` to contain invalid UTF-8. We check for this in the - // `validate` method, and if it does have invalid UTF-8, we clear the - // record. (It is bad for `record` to contain invalid UTF-8 because other - // accessor methods, like `get`, assume that every field is valid UTF-8.) - let pos = rdr.position().clone(); - let read_res = rdr.read_byte_record(&mut record.0); - let utf8_res = match byte_record::validate(&record.0) { - Ok(()) => Ok(()), - Err(err) => { - // If this record isn't valid UTF-8, then completely wipe it. - record.0.clear(); - Err(err) - } - }; - match (read_res, utf8_res) { - (Err(err), _) => Err(err), - (Ok(_), Err(err)) => { - Err(new_error(ErrorKind::Utf8 { pos: Some(pos), err: err })) - } - (Ok(eof), Ok(())) => Ok(eof), - } -} - /// A single CSV record stored as valid UTF-8 bytes. /// /// A string record permits reading or writing CSV rows that are valid UTF-8. 
@@ -76,31 +37,31 @@ pub struct StringRecord(ByteRecord); impl PartialEq for StringRecord { fn eq(&self, other: &StringRecord) -> bool { - byte_record::eq(&self.0, &other.0) + self.0.iter_eq(&other.0) } } impl> PartialEq> for StringRecord { fn eq(&self, other: &Vec) -> bool { - byte_record::eq(&self.0, other) + self.0.iter_eq(other) } } impl<'a, T: AsRef<[u8]>> PartialEq> for &'a StringRecord { fn eq(&self, other: &Vec) -> bool { - byte_record::eq(&self.0, other) + self.0.iter_eq(other) } } impl> PartialEq<[T]> for StringRecord { fn eq(&self, other: &[T]) -> bool { - byte_record::eq(&self.0, other) + self.0.iter_eq(other) } } impl<'a, T: AsRef<[u8]>> PartialEq<[T]> for &'a StringRecord { fn eq(&self, other: &[T]) -> bool { - byte_record::eq(&self.0, other) + self.0.iter_eq(other) } } @@ -193,9 +154,9 @@ impl StringRecord { pub fn from_byte_record( record: ByteRecord, ) -> result::Result { - match byte_record::validate(&record) { + match record.validate() { Ok(()) => Ok(StringRecord(record)), - Err(err) => Err(new_from_utf8_error(record, err)), + Err(err) => Err(FromUtf8Error::new(record, err)), } } @@ -231,7 +192,7 @@ impl StringRecord { #[inline] pub fn from_byte_record_lossy(record: ByteRecord) -> StringRecord { // If the record is valid UTF-8, then take the easy path. - if let Ok(()) = byte_record::validate(&record) { + if let Ok(()) = record.validate() { return StringRecord(record); } // TODO: We can be faster here. Not sure if it's worth it. @@ -645,6 +606,41 @@ impl StringRecord { pub fn into_byte_record(self) -> ByteRecord { self.0 } + + /// A safe function for reading CSV data into a `StringRecord`. + /// + /// This relies on the internal representation of `StringRecord`. + #[inline(always)] + pub(crate) fn read( + &mut self, + rdr: &mut Reader, + ) -> Result { + // SAFETY: Note that despite the absence of `unsafe` in this function, + // this code is critical to upholding the safety of other `unsafe` + // blocks in this module. 
Namely, after calling `read_byte_record`, + // it is possible for `record` to contain invalid UTF-8. We check for + // this in the `validate` method, and if it does have invalid UTF-8, we + // clear the record. (It is bad for `record` to contain invalid UTF-8 + // because other accessor methods, like `get`, assume that every field + // is valid UTF-8.) + let pos = rdr.position().clone(); + let read_res = rdr.read_byte_record(&mut self.0); + let utf8_res = match self.0.validate() { + Ok(()) => Ok(()), + Err(err) => { + // If this record isn't valid UTF-8, then completely wipe it. + self.0.clear(); + Err(err) + } + }; + match (read_res, utf8_res) { + (Err(err), _) => Err(err), + (Ok(_), Err(err)) => { + Err(Error::new(ErrorKind::Utf8 { pos: Some(pos), err: err })) + } + (Ok(eof), Ok(())) => Ok(eof), + } + } } impl ops::Index for StringRecord { diff --git a/src/writer.rs b/src/writer.rs index a4649df..5d72469 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -10,9 +10,7 @@ use csv_core::{ use serde::Serialize; use crate::byte_record::ByteRecord; -use crate::error::{ - new_error, new_into_inner_error, ErrorKind, IntoInnerError, Result, -}; +use crate::error::{Error, ErrorKind, IntoInnerError, Result}; use crate::serializer::{serialize, serialize_header}; use crate::{QuoteStyle, Terminator}; @@ -1073,7 +1071,7 @@ impl Writer { ) -> result::Result>> { match self.flush() { Ok(()) => Ok(self.wtr.take().unwrap()), - Err(err) => Err(new_into_inner_error(self, err)), + Err(err) => Err(IntoInnerError::new(self, err)), } } @@ -1134,7 +1132,7 @@ impl Writer { Some(self.state.fields_written); } Some(expected) if expected != self.state.fields_written => { - return Err(new_error(ErrorKind::UnequalLengths { + return Err(Error::new(ErrorKind::UnequalLengths { pos: None, expected_len: expected, len: self.state.fields_written,