Skip to content

Commit

Permalink
add surrogate to locations from arrow
Browse files Browse the repository at this point in the history
  • Loading branch information
hugopendlebury committed Feb 22, 2024
1 parent 0ee2d7d commit 6b2b57b
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 21 deletions.
10 changes: 10 additions & 0 deletions src/gribhelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,14 @@ std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> getConversion

return fieldTypes;

}

std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> getLocationFieldDefinitions() {

std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> fieldTypes;
fieldTypes.emplace(make_pair("lat", arrow::float64()));
fieldTypes.emplace(make_pair("lon", arrow::float64()));

return fieldTypes;

}
3 changes: 2 additions & 1 deletion src/gribhelpers.hpp
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> getConversionFieldDefinitions() ;
std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> getConversionFieldDefinitions() ;
std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> getLocationFieldDefinitions() ;
67 changes: 49 additions & 18 deletions src/gribreader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,58 @@ GribReader::GribReader(string filepath) : filepath(filepath) {

GribReader GribReader::withLocations(std::shared_ptr<arrow::Table> locations) {
//TODO - add some validation
validateLocationFields(locations, " passed conversions via arrow");
locations = GribReader::enrichLocationsWithSurrogateKey(locations);
locations = castTableFields(locations, " passed conversions via arrow", getLocationFieldDefinitions());
this->shared_locations = locations;
return *this;
}

void GribReader::validateLocationFields(std::shared_ptr<arrow::Table> locations, std::string table_name) {
auto table = locations.get();
auto columns = table->ColumnNames();
std::set<std::string> columnsSet(std::make_move_iterator(columns.begin()),
std::make_move_iterator(columns.end()));


std::vector<std::string> required_columns = {"lat",
"lon"};

for (auto col : required_columns) {
const bool is_in = columnsSet.find(col) != columnsSet.end();
if (!is_in){
std::string errDetail = "Column " + col + " is not present in schema of table " + table_name;
throw InvalidSchemaException(errDetail);
}
}
}

std::shared_ptr<arrow::Table> GribReader::enrichLocationsWithSurrogateKey(std::shared_ptr<arrow::Table> locations) {
//Append an additional column to the table called surrogate key
auto numberOfRows = locations.get()->num_rows();
auto surrogate_columns = createSurrogateKeyCol(numberOfRows);
auto skField = arrow::field("surrogate_key", arrow::uint16());
auto chunkedArray = std::make_shared<arrow::ChunkedArray>(arrow::ChunkedArray(surrogate_columns.ValueOrDie()));
auto locationsResult = locations.get()->AddColumn(0, skField, chunkedArray);
if (!locationsResult.ok()) {
std::string errDetails = "Error adding surrogate key "
" " + locationsResult.status().message();
throw UnableToCreateArrowTableReaderException(errDetails);

}
return locationsResult.ValueOrDie();

}


GribReader GribReader::withRepeatableIterator(bool repeatable) {
//TODO - add some validation
this->isRepeatable = repeatable;
return *this;
}

void GribReader::validateConversionFields(std::shared_ptr<arrow::Table> locations, std::string table_name) {
auto table = locations.get();
void GribReader::validateConversionFields(std::shared_ptr<arrow::Table> conversions, std::string table_name) {
auto table = conversions.get();
auto columns = table->ColumnNames();
std::set<std::string> columnsSet(std::make_move_iterator(columns.begin()),
std::make_move_iterator(columns.end()));
Expand Down Expand Up @@ -107,27 +147,28 @@ arrow::ArrayVector* GribReader::castColumn(std::shared_ptr<arrow::Table> locatio
return chunkVector;
}

std::shared_ptr<arrow::Table> GribReader::castConversionFields(std::shared_ptr<arrow::Table> locations, std::string table_name) {
std::shared_ptr<arrow::Table> GribReader::castTableFields(std::shared_ptr<arrow::Table> arrow_table,
std::string table_name,
std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> fieldTypes) {

//Ok this is a PITA - Although there is a .swap method of a column it doesn't work if we want to
//swap the column with a new data type or had an issue with
//TODO see if we can use swap - was it a problem with mixing datatypes which was resolved on 3rd Feb ?
//trying to remove and add wasn't working either so we basically create a new table

auto table = locations.get();
auto table = arrow_table.get();

std::vector<std::shared_ptr<arrow::ChunkedArray>> resultsArray;
arrow::FieldVector fieldVector;

auto fieldTypes = getConversionFieldDefinitions();

for (auto colDetails: fieldTypes) {

auto colName = colDetails.first;
auto colType = colDetails.second;
fieldVector.push_back(arrow::field(colName, colType));

auto chunkVector = castColumn(locations, colName, colType);
auto chunkVector = castColumn(arrow_table, colName, colType);
auto col = table->GetColumnByName(colName).get();
auto chunkedArrayResult = col->Make(*chunkVector, colType);
if(chunkedArrayResult.ok()) {
Expand All @@ -149,19 +190,9 @@ GribReader GribReader::withLocations(std::string path){

//Reads a CSV with the location data and enriches it with a row_number / surrogate_key


std::shared_ptr<arrow::Table> locations = getTableFromCsv(path, arrow::csv::ConvertOptions::Defaults());
return GribReader::withLocations(locations);

//Append an additional column to the table called surrogate key
auto numberOfRows = locations.get()->num_rows();
auto surrogate_columns = createSurrogateKeyCol(numberOfRows);
auto skField = arrow::field("surrogate_key", arrow::uint16());
auto chunkedArray = std::make_shared<arrow::ChunkedArray>(arrow::ChunkedArray(surrogate_columns.ValueOrDie()));
locations = locations.get()->AddColumn(0, skField, chunkedArray).ValueOrDie();


this->shared_locations = locations;
return *this;
}

enum conversionMethods {
Expand Down Expand Up @@ -198,7 +229,7 @@ GribReader GribReader::withConversions(std::shared_ptr<arrow::Table> conversions
//the table should contain 2 columns "lat" and "lon"
validateConversionFields(conversions, " passed conversions via arrow");
std::cout << "Fields validated" << std::endl;
conversions = castConversionFields(conversions, " passed conversions via arrow");
conversions = castTableFields(conversions, " passed conversions via arrow", getConversionFieldDefinitions());

auto rowConversion = ColumnarTableToVector(conversions);

Expand Down
8 changes: 6 additions & 2 deletions src/gribreader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,12 @@ class GribReader
GribMessage* m_endMessage;
std::shared_ptr<arrow::Table> getTableFromCsv(std::string path, arrow::csv::ConvertOptions convertOptions);
arrow::Result<std::shared_ptr<arrow::Array>> createSurrogateKeyCol(long numberOfRows);
void validateConversionFields(std::shared_ptr<arrow::Table> locations, std::string table_name);
std::shared_ptr<arrow::Table> castConversionFields(std::shared_ptr<arrow::Table> locations, std::string table_name);
void validateConversionFields(std::shared_ptr<arrow::Table> conversions, std::string table_name);
std::shared_ptr<arrow::Table> GribReader::castTableFields(std::shared_ptr<arrow::Table> arrow_table,
std::string table_name,
std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> fieldTypes);
void validateLocationFields(std::shared_ptr<arrow::Table> locations, std::string table_name) ;
std::shared_ptr<arrow::Table> enrichLocationsWithSurrogateKey(std::shared_ptr<arrow::Table> locations) ;
arrow::ArrayVector* castColumn(std::shared_ptr<arrow::Table> locations,
std::string colName,
std::shared_ptr<arrow::DataType> fieldType) ;
Expand Down
2 changes: 2 additions & 0 deletions tests/test_with_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_default_fields_present(self, resource):
assert 'modelNo' in df.columns
assert 'lat' in df.columns
assert 'lon' in df.columns
assert 'surrogate_key' in df.columns
assert 'distance' in df.columns
assert 'nearestlatitude' in df.columns
assert 'nearestlongitude' in df.columns
Expand Down Expand Up @@ -89,6 +90,7 @@ def test_passthrough_columns(self, resource):
print(f"YO I got the following columns {columns}")
assert 'lat' in columns
assert 'lon' in columns
assert 'surrogate_key' in columns
assert 'name' in columns
assert 'awesome_factor' in columns
assert 'beer' in columns
Expand Down

0 comments on commit 6b2b57b

Please sign in to comment.