Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hugopendlebury committed Feb 3, 2024
1 parent d02d150 commit b817938
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 103 deletions.
10 changes: 6 additions & 4 deletions src/gribreader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ arrow::ArrayVector* GribReader::castColumn(std::shared_ptr<arrow::Table> locatio
auto converted = result.ValueOrDie();
chunkVector->emplace_back(converted);
} else{
std::string errMsg = "Unable to cast conversion column " + colName;
std::string errMsg = "Unable to cast conversion column " + colName + " " + result.status().message();
throw InvalidSchemaException(errMsg);
}
}
Expand Down Expand Up @@ -399,17 +399,19 @@ std::shared_ptr<arrow::Table> GribReader::getTableFromCsv(std::string path, arro
if (table.ok()) {
return table.ValueOrDie();
} else {
std::string errDetails = "Error reading results into arrow table is this a valid CSV ? ";
std::string errDetails = "Error reading results into arrow table is this a valid CSV ? "
+ " " + table.status().message();
throw InvalidCSVException(errDetails );
}

} else {
std::string errDetails = "Unable to create arrow CSV table reader for file " + path;
std::string errDetails = "Unable to create arrow CSV table reader for file " + path +
" " + csv_reader.status().message();
throw UnableToCreateArrowTableReaderException(errDetails);
}

} else {
throw NoSuchLocationsFileException(path);
throw NoSuchLocationsFileException(path + " " + infile.status().message());
}

}
Expand Down
195 changes: 96 additions & 99 deletions tests/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import os
import sys
import pyarrow
import pyarrow as pa

class TestConversions:
def get_2t_df(self, reader):
Expand All @@ -24,10 +24,35 @@ def get_tcc_df(self, reader):

df = pl.concat(payloads)
return df

def getConversionFields(self, replacements: "dict | None" = None) -> dict:
    """Return the column data for a single-row conversions table.

    Every column defaults to a one-element list: ``parameterId`` is
    ``[167]`` and each ``*_value`` column is ``[None]``.

    Args:
        replacements: Optional mapping of column name to a replacement
            one-element list. Keys that match a default override it;
            any other key is appended after the defaults. ``None`` (the
            default) or an empty mapping returns the defaults unchanged.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    defaults = {
        "parameterId": [167],
        "addition_value": [None],
        "subtraction_value": [None],
        "multiplication_value": [None],
        "division_value": [None],
        "ceiling_value": [None],
    }
    # Dict union keeps the default key order and lets the caller's
    # values win on conflicts.
    return defaults | (replacements or {})

def getCastConversions(self, replacements: dict) -> pa.Table:
    """Build a conversions table whose value columns are all Float64.

    The column data comes from ``getConversionFields(replacements)``.
    ``parameterId`` is passed through with whatever type polars infers;
    each ``*_value`` column is cast to ``pl.Float64`` so that all-``None``
    columns still get a concrete numeric type.

    Args:
        replacements: Column overrides forwarded to ``getConversionFields``.

    Returns:
        The resulting frame converted with ``DataFrame.to_arrow()``. Only
        the six columns selected below are kept.
    """
    fields = self.getConversionFields(replacements)

    # Columns that hold the numeric operands of each conversion.
    value_columns = (
        "addition_value",
        "subtraction_value",
        "multiplication_value",
        "division_value",
        "ceiling_value",
    )

    return (
        pl.DataFrame(fields)
        .select(
            pl.col("parameterId"),
            *(pl.col(name).cast(pl.Float64) for name in value_columns),
        )
        .to_arrow()
    )

@pytest.mark.skip(reason="Not sure how to test this yet")
def test_validation(self, resource):
from gribtoarrow import GribReader
from gribtoarrow import GribReader, InvalidSchemaException

locations = pl.DataFrame({"lat": [51.5054], "lon": [-0.027176]}).to_arrow()

Expand All @@ -52,11 +77,26 @@ def test_validation(self, resource):
)
).to_arrow()

(
GribReader(f"{resource}{os.sep}gep01.t00z.pgrb2a.0p50.f003")
.withLocations(locations)
.withConversions(conversions)
)
with pytest.raises(InvalidSchemaException):
(
GribReader(f"{resource}{os.sep}gep01.t00z.pgrb2a.0p50.f003")
.withLocations(locations)
.withConversions(conversions)
)


def test_uncastable_field(self, resource):
    """A string in a conversion value column must raise InvalidSchemaException."""
    from gribtoarrow import GribReader, InvalidSchemaException

    # "hugo" is deliberately non-numeric; the frame is built without any
    # explicit casts so the string reaches the reader untouched.
    bad_fields = self.getConversionFields({"subtraction_value": ["hugo"]})
    conversions = pl.DataFrame(data=bad_fields).to_arrow()

    grib_path = f"{resource}{os.sep}gep01.t00z.pgrb2a.0p50.f003"

    with pytest.raises(InvalidSchemaException):
        GribReader(grib_path).withConversions(conversions)

def test_addition(self, resource):
# In this test we will read a Grib file which was downloaded from
Expand All @@ -82,25 +122,7 @@ def test_addition(self, resource):

# In reality the user would probably store these in a config file / CSV

conversions = (
pl.DataFrame(
{
"parameterId": [167],
"addition_value": [-273.15],
"subtraction_value": [None],
"multiplication_value": [None],
"division_value": [None],
"ceiling_value": [None],
}
).select(
pl.col("parameterId"),
pl.col("addition_value").cast(pl.Float64),
pl.col("subtraction_value").cast(pl.Float64),
pl.col("multiplication_value").cast(pl.Float64),
pl.col("division_value").cast(pl.Float64),
pl.col("ceiling_value").cast(pl.Float64),
)
).to_arrow()
conversions = self.getCastConversions({"addition_value": [-273.15]})

reader = (
GribReader(f"{resource}{os.sep}gep01.t00z.pgrb2a.0p50.f003")
Expand Down Expand Up @@ -134,27 +156,7 @@ def test_subtraction(self, resource):

assert round(raw_df["value"].to_list()[0], 2) == round(280.128, 2)

# In reality the use would probably store these in a config file / CSV

conversions = (
pl.DataFrame(
{
"parameterId": [167],
"addition_value": [None],
"subtraction_value": [273.15],
"multiplication_value": [None],
"division_value": [None],
"ceiling_value": [None],
}
).select(
pl.col("parameterId"),
pl.col("addition_value").cast(pl.Float64),
pl.col("subtraction_value").cast(pl.Float64),
pl.col("multiplication_value").cast(pl.Float64),
pl.col("division_value").cast(pl.Float64),
pl.col("ceiling_value").cast(pl.Float64),
)
).to_arrow()
conversions = self.getCastConversions({"subtraction_value": [273.15]})

reader = (
GribReader(str(resource) + "/gep01.t00z.pgrb2a.0p50.f003")
Expand Down Expand Up @@ -189,27 +191,10 @@ def test_division(self, resource):

assert raw_df["value"].to_list()[0] == 100.0

# In reality the use would probably store these in a config file / CSV

conversions = (
pl.DataFrame(
{
"parameterId": [228164],
"addition_value": [None],
"subtraction_value": [None],
"multiplication_value": [None],
"division_value": [100],
"ceiling_value": [None],
}
).select(
pl.col("parameterId"),
pl.col("addition_value").cast(pl.Float64),
pl.col("subtraction_value").cast(pl.Float64),
pl.col("multiplication_value").cast(pl.Float64),
pl.col("division_value").cast(pl.Float64),
pl.col("ceiling_value").cast(pl.Float64),
)
).to_arrow()
conversions = self.getCastConversions({
"parameterId": [228164],
"division_value": [100]}
)

reader = (
GribReader(f"{resource}{os.sep}gep01.t00z.pgrb2a.0p50.f003")
Expand Down Expand Up @@ -244,27 +229,10 @@ def test_multiplication(self, resource):

assert raw_df["value"].to_list()[0] == 100.0

# In reality the use would probably store these in a config file / CSV

conversions = (
pl.DataFrame(
{
"parameterId": [228164],
"addition_value": [None],
"subtraction_value": [None],
"multiplication_value": [0.01],
"division_value": [None],
"ceiling_value": [None],
}
).select(
pl.col("parameterId"),
pl.col("addition_value").cast(pl.Float64),
pl.col("subtraction_value").cast(pl.Float64),
pl.col("multiplication_value").cast(pl.Float64),
pl.col("division_value").cast(pl.Float64),
pl.col("ceiling_value").cast(pl.Float64),
)
).to_arrow()
conversions = self.getCastConversions({
"parameterId": [228164],
"multiplication_value": [0.01]}
)

reader = (
GribReader(str(resource) + "/gep01.t00z.pgrb2a.0p50.f003")
Expand All @@ -276,7 +244,6 @@ def test_multiplication(self, resource):

assert converted_df["value"].to_list()[0] == 1

pytest.mark.skip(reason="Complete validation")
def test_csv_multiplication(self, resource):
# In this test we will read a Grib file which was downloaded from
# NOAA.
Expand All @@ -302,15 +269,13 @@ def test_csv_multiplication(self, resource):

# In reality the user would probably store these in a config file / CSV

conversions = self.getConversionFields({
"parameterId": [228164],
"multiplication_value": [0.01]}
)

df = pl.DataFrame(
{
"parameterId": [228164],
"addition_value": [None],
"subtraction_value": [None],
"multiplication_value": [0.01],
"division_value": [None],
"ceiling_value": [None],
}
data = conversions
).write_csv(str(resource) + "/test_conversions.csv")

reader = (
Expand All @@ -323,3 +288,35 @@ def test_csv_multiplication(self, resource):

assert converted_df["value"].to_list()[0] == 1

def test_casting(self, resource):
    """Conversions built without explicit Float64 casts are still applied.

    Reads the 2t value once without conversions, then again with a
    ``division_value`` of 10 supplied through a frame whose column types
    are left to polars' inference, and checks the value is divided by ten.
    """
    from gribtoarrow import GribReader

    # This is the latitude / longitude of Canary wharf
    canary_wharf = pl.DataFrame({"lat": [51.5054], "lon": [-0.027176]}).to_arrow()

    # Baseline: no conversions applied.
    baseline_reader = GribReader(
        str(resource) + "/gep01.t00z.pgrb2a.0p50.f003"
    ).withLocations(canary_wharf)
    baseline = self.get_2t_df(baseline_reader)

    assert round(baseline["value"].to_list()[0], 2) == round(280.128, 2)

    # Built straight from the raw field dict (not via getCastConversions),
    # so no column is cast before it is handed to the reader.
    uncast_fields = self.getConversionFields({"division_value": [10]})
    conversions = pl.DataFrame(uncast_fields).to_arrow()

    converting_reader = (
        GribReader(f"{resource}{os.sep}gep01.t00z.pgrb2a.0p50.f003")
        .withLocations(canary_wharf)
        .withConversions(conversions)
    )
    converted = self.get_2t_df(converting_reader)

    assert round(converted["value"].to_list()[0], 2) == round(28.012, 2)

0 comments on commit b817938

Please sign in to comment.