Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-3729: [C++][Parquet] Use logical annotations in Arrow Parquet reader/writer #4421

Closed
wants to merge 9 commits into from
Next Next commit
Use logical annotations in Arrow Parquet reader/writer
  • Loading branch information
tpboudreau committed Jun 18, 2019
commit f9379435244a3c6e3c79f763e084c13efaad4a9e
159 changes: 115 additions & 44 deletions cpp/src/parquet/arrow/arrow-reader-writer-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,58 +87,84 @@ static constexpr int LARGE_SIZE = 10000;

static constexpr uint32_t kDefaultSeed = 0;

LogicalType::type get_logical_type(const ::DataType& type) {
std::shared_ptr<const LogicalAnnotation> get_logical_annotation(const ::DataType& type,
int32_t precision,
int32_t scale) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the precision/scale params here should come from the Arrow type object?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was all was pretty sloppy; thanks for pointing it out. I've cleaned it up.

switch (type.id()) {
case ArrowId::UINT8:
return LogicalType::UINT_8;
return LogicalAnnotation::Int(8, false);
case ArrowId::INT8:
return LogicalType::INT_8;
return LogicalAnnotation::Int(8, true);
case ArrowId::UINT16:
return LogicalType::UINT_16;
return LogicalAnnotation::Int(16, false);
case ArrowId::INT16:
return LogicalType::INT_16;
return LogicalAnnotation::Int(16, true);
case ArrowId::UINT32:
return LogicalType::UINT_32;
return LogicalAnnotation::Int(32, false);
case ArrowId::INT32:
return LogicalType::INT_32;
return LogicalAnnotation::Int(32, true);
case ArrowId::UINT64:
return LogicalType::UINT_64;
return LogicalAnnotation::Int(64, false);
case ArrowId::INT64:
return LogicalType::INT_64;
return LogicalAnnotation::Int(64, true);
case ArrowId::STRING:
return LogicalType::UTF8;
return LogicalAnnotation::String();
case ArrowId::DATE32:
return LogicalType::DATE;
return LogicalAnnotation::Date();
case ArrowId::DATE64:
return LogicalType::DATE;
return LogicalAnnotation::Date();
case ArrowId::TIMESTAMP: {
const auto& ts_type = static_cast<const ::arrow::TimestampType&>(type);
const bool adjusted_to_utc = (ts_type.timezone() == "UTC");
switch (ts_type.unit()) {
case TimeUnit::MILLI:
return LogicalType::TIMESTAMP_MILLIS;
return LogicalAnnotation::Timestamp(adjusted_to_utc,
LogicalAnnotation::TimeUnit::MILLIS);
case TimeUnit::MICRO:
return LogicalType::TIMESTAMP_MICROS;
return LogicalAnnotation::Timestamp(adjusted_to_utc,
LogicalAnnotation::TimeUnit::MICROS);
case TimeUnit::NANO:
return LogicalAnnotation::Timestamp(adjusted_to_utc,
LogicalAnnotation::TimeUnit::NANOS);
default:
DCHECK(false) << "Only MILLI and MICRO units supported for Arrow timestamps "
"with Parquet.";
DCHECK(false)
<< "Only MILLI, MICRO, and NANO units supported for Arrow TIMESTAMP.";
}
break;
}
case ArrowId::TIME32:
return LogicalType::TIME_MILLIS;
case ArrowId::TIME64:
return LogicalType::TIME_MICROS;
return LogicalAnnotation::Time(false, LogicalAnnotation::TimeUnit::MILLIS);
case ArrowId::TIME64: {
const auto& tm_type = static_cast<const ::arrow::TimeType&>(type);
switch (tm_type.unit()) {
case TimeUnit::MICRO:
return LogicalAnnotation::Time(false, LogicalAnnotation::TimeUnit::MICROS);
case TimeUnit::NANO:
return LogicalAnnotation::Time(false, LogicalAnnotation::TimeUnit::NANOS);
default:
DCHECK(false) << "Only MICRO and NANO units supported for Arrow TIME64.";
}
break;
}
case ArrowId::DICTIONARY: {
const ::arrow::DictionaryType& dict_type =
static_cast<const ::arrow::DictionaryType&>(type);
return get_logical_type(*dict_type.value_type());
const ::DataType& ty = *dict_type.value_type();
int32_t pr = -1;
int32_t sc = -1;
if (ty.id() == ArrowId::DECIMAL) {
const auto& dt = static_cast<const ::arrow::Decimal128Type&>(ty);
pr = dt.precision();
sc = dt.scale();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You use the type's precision/scale here but not in the next block

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleaned up now.

}
return get_logical_annotation(ty, pr, sc);
}
case ArrowId::DECIMAL:
return LogicalType::DECIMAL;
return LogicalAnnotation::Decimal(precision, scale);
default:
break;
}
return LogicalType::NONE;
return LogicalAnnotation::None();
}

ParquetType::type get_physical_type(const ::DataType& type) {
Expand Down Expand Up @@ -383,6 +409,8 @@ void CheckSimpleRoundtrip(const std::shared_ptr<Table>& table, int64_t row_group
std::shared_ptr<Table> result;
DoSimpleRoundtrip(table, false /* use_threads */, row_group_size, {}, &result,
arrow_properties);
ASSERT_NO_FATAL_FAILURE(
::arrow::AssertSchemaEqual(*table->schema(), *result->schema()));
ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*table, *result, false));
}

Expand Down Expand Up @@ -424,8 +452,9 @@ static std::shared_ptr<GroupNode> MakeSimpleSchema(const ::DataType& type,
default:
break;
}
auto pnode = PrimitiveNode::Make("column1", repetition, get_physical_type(type),
get_logical_type(type), byte_width, precision, scale);
auto pnode = PrimitiveNode::Make("column1", repetition,
get_logical_annotation(type, precision, scale),
get_physical_type(type), byte_width);
NodePtr node_ =
GroupNode::Make("schema", Repetition::REQUIRED, std::vector<NodePtr>({pnode}));
return std::static_pointer_cast<GroupNode>(node_);
Expand Down Expand Up @@ -1195,7 +1224,7 @@ TYPED_TEST(TestPrimitiveParquetIO, SingleColumnRequiredChunkedTableRead) {
ASSERT_NO_FATAL_FAILURE(this->CheckSingleColumnRequiredTableRead(4));
}

void MakeDateTimeTypesTable(std::shared_ptr<Table>* out, bool nanos_as_micros = false) {
void MakeDateTimeTypesTable(std::shared_ptr<Table>* out) {
using ::arrow::ArrayFromVector;

std::vector<bool> is_valid = {true, true, true, false, true, true};
Expand All @@ -1204,12 +1233,13 @@ void MakeDateTimeTypesTable(std::shared_ptr<Table>* out, bool nanos_as_micros =
auto f0 = field("f0", ::arrow::date32());
auto f1 = field("f1", ::arrow::timestamp(TimeUnit::MILLI));
auto f2 = field("f2", ::arrow::timestamp(TimeUnit::MICRO));
auto f3_unit = nanos_as_micros ? TimeUnit::MICRO : TimeUnit::NANO;
auto f3 = field("f3", ::arrow::timestamp(f3_unit));
auto f3 = field("f3", ::arrow::timestamp(TimeUnit::NANO));
auto f4 = field("f4", ::arrow::time32(TimeUnit::MILLI));
auto f5 = field("f5", ::arrow::time64(TimeUnit::MICRO));
auto f6 = field("f6", ::arrow::time64(TimeUnit::NANO));

std::shared_ptr<::arrow::Schema> schema(new ::arrow::Schema({f0, f1, f2, f3, f4, f5}));
std::shared_ptr<::arrow::Schema> schema(
new ::arrow::Schema({f0, f1, f2, f3, f4, f5, f6}));

std::vector<int32_t> t32_values = {1489269000, 1489270000, 1489271000,
1489272000, 1489272000, 1489273000};
Expand All @@ -1220,34 +1250,37 @@ void MakeDateTimeTypesTable(std::shared_ptr<Table>* out, bool nanos_as_micros =
std::vector<int64_t> t64_ms_values = {1489269, 1489270, 1489271,
1489272, 1489272, 1489273};

std::shared_ptr<Array> a0, a1, a2, a3, a4, a5;
std::shared_ptr<Array> a0, a1, a2, a3, a4, a5, a6;
ArrayFromVector<::arrow::Date32Type, int32_t>(f0->type(), is_valid, t32_values, &a0);
ArrayFromVector<::arrow::TimestampType, int64_t>(f1->type(), is_valid, t64_ms_values,
&a1);
ArrayFromVector<::arrow::TimestampType, int64_t>(f2->type(), is_valid, t64_us_values,
&a2);
auto f3_data = nanos_as_micros ? t64_us_values : t64_ns_values;
ArrayFromVector<::arrow::TimestampType, int64_t>(f3->type(), is_valid, f3_data, &a3);
ArrayFromVector<::arrow::TimestampType, int64_t>(f3->type(), is_valid, t64_ns_values,
&a3);
ArrayFromVector<::arrow::Time32Type, int32_t>(f4->type(), is_valid, t32_values, &a4);
ArrayFromVector<::arrow::Time64Type, int64_t>(f5->type(), is_valid, t64_us_values, &a5);
ArrayFromVector<::arrow::Time64Type, int64_t>(f6->type(), is_valid, t64_ns_values, &a6);

std::vector<std::shared_ptr<::arrow::Column>> columns = {
std::make_shared<Column>("f0", a0), std::make_shared<Column>("f1", a1),
std::make_shared<Column>("f2", a2), std::make_shared<Column>("f3", a3),
std::make_shared<Column>("f4", a4), std::make_shared<Column>("f5", a5)};
std::make_shared<Column>("f4", a4), std::make_shared<Column>("f5", a5),
std::make_shared<Column>("f6", a6)};

*out = Table::Make(schema, columns);
}

TEST(TestArrowReadWrite, DateTimeTypes) {
std::shared_ptr<Table> table, result;
MakeDateTimeTypesTable(&table);

// Cast nanaoseconds to microseconds and use INT64 physical type
MakeDateTimeTypesTable(&table);
ASSERT_NO_FATAL_FAILURE(
DoSimpleRoundtrip(table, false /* use_threads */, table->num_rows(), {}, &result));
MakeDateTimeTypesTable(&table, true);

MakeDateTimeTypesTable(&table);
ASSERT_NO_FATAL_FAILURE(
::arrow::AssertSchemaEqual(*table->schema(), *result->schema()));
ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*table, *result));
}

Expand Down Expand Up @@ -1300,6 +1333,8 @@ TEST(TestArrowReadWrite, UseDeprecatedInt96) {
input, false /* use_threads */, input->num_rows(), {}, &result,
ArrowWriterProperties::Builder().enable_deprecated_int96_timestamps()->build()));

ASSERT_NO_FATAL_FAILURE(
::arrow::AssertSchemaEqual(*ex_result->schema(), *result->schema()));
ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_result, *result));

// Ensure enable_deprecated_int96_timestamps as precedence over
Expand All @@ -1311,14 +1346,15 @@ TEST(TestArrowReadWrite, UseDeprecatedInt96) {
->coerce_timestamps(TimeUnit::MILLI)
->build()));

ASSERT_NO_FATAL_FAILURE(
::arrow::AssertSchemaEqual(*ex_result->schema(), *result->schema()));
ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_result, *result));
}

TEST(TestArrowReadWrite, CoerceTimestamps) {
using ::arrow::ArrayFromVector;
using ::arrow::field;

// PARQUET-1078, coerce Arrow timestamps to either TIMESTAMP_MILLIS or TIMESTAMP_MICROS
std::vector<bool> is_valid = {true, true, true, false, true, true};

auto t_s = ::arrow::timestamp(TimeUnit::SECOND);
Expand Down Expand Up @@ -1363,6 +1399,8 @@ TEST(TestArrowReadWrite, CoerceTimestamps) {
ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(
input, false /* use_threads */, input->num_rows(), {}, &milli_result,
ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MILLI)->build()));
ASSERT_NO_FATAL_FAILURE(
::arrow::AssertSchemaEqual(*ex_milli_result->schema(), *milli_result->schema()));
ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_milli_result, *milli_result));

// Result when coercing to microseconds
Expand All @@ -1378,7 +1416,26 @@ TEST(TestArrowReadWrite, CoerceTimestamps) {
ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(
input, false /* use_threads */, input->num_rows(), {}, &micro_result,
ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MICRO)->build()));
ASSERT_NO_FATAL_FAILURE(
::arrow::AssertSchemaEqual(*ex_micro_result->schema(), *micro_result->schema()));
ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_micro_result, *micro_result));

// Result when coercing to nanoseconds
auto s4 = std::shared_ptr<::arrow::Schema>(
new ::arrow::Schema({field("f_s", t_ns), field("f_ms", t_ns), field("f_us", t_ns),
field("f_ns", t_ns)}));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can also use ::arrow::schema factory function for tighter code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed to it here and in a few other nearby places.

auto ex_nano_result = Table::Make(
s4,
{std::make_shared<Column>("f_s", a_ns), std::make_shared<Column>("f_ms", a_ns),
std::make_shared<Column>("f_us", a_ns), std::make_shared<Column>("f_ns", a_ns)});

std::shared_ptr<Table> nano_result;
ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(
input, false /* use_threads */, input->num_rows(), {}, &nano_result,
ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::NANO)->build()));
ASSERT_NO_FATAL_FAILURE(
::arrow::AssertSchemaEqual(*ex_nano_result->schema(), *nano_result->schema()));
ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_nano_result, *nano_result));
}

TEST(TestArrowReadWrite, CoerceTimestampsLosePrecision) {
Expand Down Expand Up @@ -1439,25 +1496,37 @@ TEST(TestArrowReadWrite, CoerceTimestampsLosePrecision) {
ASSERT_RAISES(Invalid, WriteTable(*t4, ::arrow::default_memory_pool(), sink, 10,
default_writer_properties(), coerce_millis));

// OK to lose precision if we explicitly allow it
auto allow_truncation = (ArrowWriterProperties::Builder()
.coerce_timestamps(TimeUnit::MILLI)
->allow_truncated_timestamps()
->build());
// OK to lose micros/nanos -> millis precision if we explicitly allow it
auto allow_truncation_to_millis = (ArrowWriterProperties::Builder()
.coerce_timestamps(TimeUnit::MILLI)
->allow_truncated_timestamps()
->build());
ASSERT_OK_NO_THROW(WriteTable(*t3, ::arrow::default_memory_pool(), sink, 10,
default_writer_properties(), allow_truncation));
default_writer_properties(), allow_truncation_to_millis));
ASSERT_OK_NO_THROW(WriteTable(*t4, ::arrow::default_memory_pool(), sink, 10,
default_writer_properties(), allow_truncation));
default_writer_properties(), allow_truncation_to_millis));

// OK to write micros to micros
// OK to write to micros
auto coerce_micros =
(ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MICRO)->build());
ASSERT_OK_NO_THROW(WriteTable(*t1, ::arrow::default_memory_pool(), sink, 10,
default_writer_properties(), coerce_micros));
ASSERT_OK_NO_THROW(WriteTable(*t2, ::arrow::default_memory_pool(), sink, 10,
default_writer_properties(), coerce_micros));
ASSERT_OK_NO_THROW(WriteTable(*t3, ::arrow::default_memory_pool(), sink, 10,
default_writer_properties(), coerce_micros));

// Loss of precision
ASSERT_RAISES(Invalid, WriteTable(*t4, ::arrow::default_memory_pool(), sink, 10,
default_writer_properties(), coerce_micros));

// OK to lose nanos -> micros precision if we explicitly allow it
auto allow_truncation_to_micros = (ArrowWriterProperties::Builder()
.coerce_timestamps(TimeUnit::MICRO)
->allow_truncated_timestamps()
->build());
ASSERT_OK_NO_THROW(WriteTable(*t4, ::arrow::default_memory_pool(), sink, 10,
default_writer_properties(), allow_truncation_to_micros));
}

TEST(TestArrowReadWrite, ConvertedDateTimeTypes) {
Expand Down Expand Up @@ -1515,6 +1584,8 @@ TEST(TestArrowReadWrite, ConvertedDateTimeTypes) {
ASSERT_NO_FATAL_FAILURE(
DoSimpleRoundtrip(table, false /* use_threads */, table->num_rows(), {}, &result));

ASSERT_NO_FATAL_FAILURE(
::arrow::AssertSchemaEqual(*ex_table->schema(), *result->schema()));
ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_table, *result));
}

Expand Down
Loading