Skip to content

Commit

Permalink
Improve Error Handling and Readability for downcasting Float32Array
Browse files Browse the repository at this point in the history
…, `Float64Array`, `StringArray` (apache#4244)

* improve error messages while downcasting Int32Array

* improve error messages while downcasting float and string array

* refactor arrows as_string_array functions

* fmt and clippy beautify
  • Loading branch information
retikulum committed Nov 17, 2022
1 parent 5de9709 commit f4996b9
Show file tree
Hide file tree
Showing 26 changed files with 133 additions and 160 deletions.
9 changes: 5 additions & 4 deletions benchmarks/src/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrayRef, Float64Array, StringArray};
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use std::fs;
Expand All @@ -25,7 +25,8 @@ use std::sync::Arc;
use std::time::Instant;

use datafusion::common::cast::{
as_date32_array, as_decimal128_array, as_int32_array, as_int64_array,
as_date32_array, as_decimal128_array, as_float64_array, as_int32_array,
as_int64_array, as_string_array,
};
use datafusion::common::ScalarValue;
use datafusion::logical_expr::Cast;
Expand Down Expand Up @@ -432,7 +433,7 @@ fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
ScalarValue::Int64(Some(array.value(row_index)))
}
DataType::Float64 => {
let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
let array = as_float64_array(column).unwrap();
ScalarValue::Float64(Some(array.value(row_index)))
}
DataType::Decimal128(p, s) => {
Expand All @@ -444,7 +445,7 @@ fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
ScalarValue::Date32(Some(array.value(row_index)))
}
DataType::Utf8 => {
let array = column.as_any().downcast_ref::<StringArray>().unwrap();
let array = as_string_array(column).unwrap();
ScalarValue::Utf8(Some(array.value(row_index).to_string()))
}
other => panic!("unexpected data type in benchmark: {}", other),
Expand Down
10 changes: 3 additions & 7 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
/// In this example we will declare a single-type, single return type UDAF that computes the geometric mean.
/// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean
use datafusion::arrow::{
array::ArrayRef, array::Float32Array, array::Float64Array, datatypes::DataType,
record_batch::RecordBatch,
array::ArrayRef, array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
};
use datafusion::from_slice::FromSlice;
use datafusion::logical_expr::AggregateState;
use datafusion::{error::Result, physical_plan::Accumulator};
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::create_udaf;
use std::sync::Arc;

Expand Down Expand Up @@ -187,11 +187,7 @@ async fn main() -> Result<()> {
let results = df.collect().await?;

// downcast the array to the expected type
let result = results[0]
.column(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let result = as_float64_array(results[0].column(0))?;

// verify that the calculation is correct
assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
Expand Down
11 changes: 3 additions & 8 deletions datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use datafusion::{
use datafusion::from_slice::FromSlice;
use datafusion::prelude::*;
use datafusion::{error::Result, physical_plan::functions::make_scalar_function};
use datafusion_common::cast::as_float64_array;
use std::sync::Arc;

// create local execution context with an in-memory table
Expand Down Expand Up @@ -70,14 +71,8 @@ async fn main() -> Result<()> {
assert_eq!(args.len(), 2);

// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
let base = &args[0]
.as_any()
.downcast_ref::<Float64Array>()
.expect("cast failed");
let exponent = &args[1]
.as_any()
.downcast_ref::<Float64Array>()
.expect("cast failed");
let base = as_float64_array(&args[0]).expect("cast failed");
let exponent = as_float64_array(&args[1]).expect("cast failed");

// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());
Expand Down
39 changes: 38 additions & 1 deletion datafusion/common/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

use crate::DataFusionError;
use arrow::array::{
Array, Date32Array, Decimal128Array, Int32Array, Int64Array, StructArray,
Array, Date32Array, Decimal128Array, Float32Array, Float64Array, Int32Array,
Int64Array, StringArray, StructArray,
};

// Downcast ArrayRef to Date32Array
Expand Down Expand Up @@ -79,3 +80,39 @@ pub fn as_decimal128_array(
))
})
}

/// Downcasts a dynamically typed Arrow array to a [`Float32Array`] reference.
///
/// # Errors
///
/// Returns [`DataFusionError::Internal`] when `array` is not a
/// `Float32Array`; the message includes the array's actual data type.
pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array, DataFusionError> {
    match array.as_any().downcast_ref::<Float32Array>() {
        Some(typed) => Ok(typed),
        // Build the message only on the failure path.
        None => Err(DataFusionError::Internal(format!(
            "Expected a Float32Array, got: {}",
            array.data_type()
        ))),
    }
}

/// Downcasts a dynamically typed Arrow array to a [`Float64Array`] reference.
///
/// # Errors
///
/// Returns [`DataFusionError::Internal`] when `array` is not a
/// `Float64Array`; the message includes the array's actual data type.
pub fn as_float64_array(array: &dyn Array) -> Result<&Float64Array, DataFusionError> {
    match array.as_any().downcast_ref::<Float64Array>() {
        Some(typed) => Ok(typed),
        // Build the message only on the failure path.
        None => Err(DataFusionError::Internal(format!(
            "Expected a Float64Array, got: {}",
            array.data_type()
        ))),
    }
}

/// Downcasts a dynamically typed Arrow array to a [`StringArray`] reference.
///
/// # Errors
///
/// Returns [`DataFusionError::Internal`] when `array` is not a
/// `StringArray`; the message includes the array's actual data type.
pub fn as_string_array(array: &dyn Array) -> Result<&StringArray, DataFusionError> {
    match array.as_any().downcast_ref::<StringArray>() {
        Some(typed) => Ok(typed),
        // Build the message only on the failure path.
        None => Err(DataFusionError::Internal(format!(
            "Expected a StringArray, got: {}",
            array.data_type()
        ))),
    }
}
3 changes: 2 additions & 1 deletion datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2657,6 +2657,7 @@ mod tests {
use arrow::compute::kernels;
use arrow::datatypes::ArrowPrimitiveType;

use crate::cast::as_string_array;
use crate::from_slice::FromSlice;

use super::*;
Expand Down Expand Up @@ -3020,7 +3021,7 @@ mod tests {

let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap();
let array = as_dictionary_array::<Int32Type>(&array);
let values_array = as_string_array(array.values());
let values_array = as_string_array(array.values()).unwrap();

let values = array
.keys_iter()
Expand Down
18 changes: 4 additions & 14 deletions datafusion/core/src/datasource/file_format/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,8 @@ mod tests {
use crate::datasource::file_format::test_util::scan_format;
use crate::physical_plan::collect;
use crate::prelude::{SessionConfig, SessionContext};
use arrow::array::{
BinaryArray, BooleanArray, Float32Array, Float64Array, TimestampMicrosecondArray,
};
use datafusion_common::cast::as_int32_array;
use arrow::array::{BinaryArray, BooleanArray, TimestampMicrosecondArray};
use datafusion_common::cast::{as_float32_array, as_float64_array, as_int32_array};
use futures::StreamExt;

#[tokio::test]
Expand Down Expand Up @@ -279,11 +277,7 @@ mod tests {
assert_eq!(1, batches[0].num_columns());
assert_eq!(8, batches[0].num_rows());

let array = batches[0]
.column(0)
.as_any()
.downcast_ref::<Float32Array>()
.unwrap();
let array = as_float32_array(batches[0].column(0))?;
let mut values: Vec<f32> = vec![];
for i in 0..batches[0].num_rows() {
values.push(array.value(i));
Expand All @@ -309,11 +303,7 @@ mod tests {
assert_eq!(1, batches[0].num_columns());
assert_eq!(8, batches[0].num_rows());

let array = batches[0]
.column(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let array = as_float64_array(batches[0].column(0))?;
let mut values: Vec<f64> = vec![];
for i in 0..batches[0].num_rows() {
values.push(array.value(i));
Expand Down
9 changes: 2 additions & 7 deletions datafusion/core/src/datasource/file_format/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,11 @@ impl FileFormat for CsvFormat {

#[cfg(test)]
mod tests {
use arrow::array::StringArray;

use super::super::test_util::scan_format;
use super::*;
use crate::physical_plan::collect;
use crate::prelude::{SessionConfig, SessionContext};
use datafusion_common::cast::as_string_array;
use futures::StreamExt;

#[tokio::test]
Expand Down Expand Up @@ -270,11 +269,7 @@ mod tests {
assert_eq!(1, batches[0].num_columns());
assert_eq!(100, batches[0].num_rows());

let array = batches[0]
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let array = as_string_array(batches[0].column(0))?;
let mut values: Vec<&str> = vec![];
for i in 0..5 {
values.push(array.value(i));
Expand Down
17 changes: 4 additions & 13 deletions datafusion/core/src/datasource/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,13 +586,12 @@ mod tests {
use crate::physical_plan::metrics::MetricValue;
use crate::prelude::{SessionConfig, SessionContext};
use arrow::array::{
Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
StringArray, TimestampNanosecondArray,
Array, ArrayRef, BinaryArray, BooleanArray, StringArray, TimestampNanosecondArray,
};
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use bytes::Bytes;
use datafusion_common::cast::as_int32_array;
use datafusion_common::cast::{as_float32_array, as_float64_array, as_int32_array};
use datafusion_common::ScalarValue;
use futures::stream::BoxStream;
use futures::StreamExt;
Expand Down Expand Up @@ -1026,11 +1025,7 @@ mod tests {
assert_eq!(1, batches[0].num_columns());
assert_eq!(8, batches[0].num_rows());

let array = batches[0]
.column(0)
.as_any()
.downcast_ref::<Float32Array>()
.unwrap();
let array = as_float32_array(batches[0].column(0))?;
let mut values: Vec<f32> = vec![];
for i in 0..batches[0].num_rows() {
values.push(array.value(i));
Expand All @@ -1056,11 +1051,7 @@ mod tests {
assert_eq!(1, batches[0].num_columns());
assert_eq!(8, batches[0].num_rows());

let array = batches[0]
.column(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let array = as_float64_array(batches[0].column(0))?;
let mut values: Vec<f64> = vec![];
for i in 0..batches[0].num_rows() {
values.push(array.value(i));
Expand Down
12 changes: 4 additions & 8 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use std::sync::Arc;

use arrow::{
array::{
Array, ArrayBuilder, ArrayRef, Date64Array, Date64Builder, StringArray,
StringBuilder, UInt64Array, UInt64Builder,
Array, ArrayBuilder, ArrayRef, Date64Array, Date64Builder, StringBuilder,
UInt64Array, UInt64Builder,
},
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
Expand All @@ -38,7 +38,7 @@ use crate::{

use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use datafusion_common::{Column, DataFusionError};
use datafusion_common::{cast::as_string_array, Column, DataFusionError};
use datafusion_expr::{
expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion},
Expr, Volatility,
Expand Down Expand Up @@ -299,11 +299,7 @@ fn batches_to_paths(batches: &[RecordBatch]) -> Result<Vec<PartitionedFile>> {
batches
.iter()
.flat_map(|batch| {
let key_array = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let key_array = as_string_array(batch.column(0)).unwrap();
let length_array = batch
.column(1)
.as_any()
Expand Down
25 changes: 16 additions & 9 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ use ahash::RandomState;

use arrow::{
array::{
as_dictionary_array, as_string_array, ArrayData, ArrayRef, BooleanArray,
Date32Array, Date64Array, Decimal128Array, DictionaryArray, LargeStringArray,
PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampSecondArray, UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder,
UInt64Builder,
as_dictionary_array, ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array,
Decimal128Array, DictionaryArray, LargeStringArray, PrimitiveArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray,
UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder,
},
compute,
datatypes::{
Expand All @@ -53,6 +52,8 @@ use arrow::array::{
UInt8Array,
};

use datafusion_common::cast::as_string_array;

use hashbrown::raw::RawTable;

use crate::physical_plan::{
Expand Down Expand Up @@ -1122,9 +1123,12 @@ macro_rules! equal_rows_elem_with_string_dict {
.to_usize()
.expect("Can not convert index to usize in dictionary");

(as_string_array(left_array.values()), Some(values_index))
(
as_string_array(left_array.values()).unwrap(),
Some(values_index),
)
} else {
(as_string_array(left_array.values()), None)
(as_string_array(left_array.values()).unwrap(), None)
}
};
let (right_values, right_values_index) = {
Expand All @@ -1135,9 +1139,12 @@ macro_rules! equal_rows_elem_with_string_dict {
.to_usize()
.expect("Can not convert index to usize in dictionary");

(as_string_array(right_array.values()), Some(values_index))
(
as_string_array(right_array.values()).unwrap(),
Some(values_index),
)
} else {
(as_string_array(right_array.values()), None)
(as_string_array(right_array.values()).unwrap(), None)
}
};

Expand Down
6 changes: 2 additions & 4 deletions datafusion/core/src/physical_plan/repartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ mod tests {
array::{ArrayRef, StringArray},
error::ArrowError,
};
use datafusion_common::cast::as_string_array;
use futures::FutureExt;
use std::collections::HashSet;

Expand Down Expand Up @@ -962,10 +963,7 @@ mod tests {
.iter()
.flat_map(|batch| {
assert_eq!(batch.columns().len(), 1);
let string_array = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
let string_array = as_string_array(batch.column(0))
.expect("Unexpected type for repartitoned batch");

string_array
Expand Down
5 changes: 3 additions & 2 deletions datafusion/core/src/physical_plan/sorts/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,7 @@ mod tests {
use arrow::array::*;
use arrow::compute::SortOptions;
use arrow::datatypes::*;
use datafusion_common::cast::as_string_array;
use futures::FutureExt;
use std::collections::{BTreeMap, HashMap};

Expand Down Expand Up @@ -990,7 +991,7 @@ mod tests {

let columns = result[0].columns();

let c1 = as_string_array(&columns[0]);
let c1 = as_string_array(&columns[0])?;
assert_eq!(c1.value(0), "a");
assert_eq!(c1.value(c1.len() - 1), "e");

Expand Down Expand Up @@ -1062,7 +1063,7 @@ mod tests {

let columns = result[0].columns();

let c1 = as_string_array(&columns[0]);
let c1 = as_string_array(&columns[0])?;
assert_eq!(c1.value(0), "a");
assert_eq!(c1.value(c1.len() - 1), "e");

Expand Down
Loading

0 comments on commit f4996b9

Please sign in to comment.