[FLINK-13811][python] Support converting flink Table to pandas DataFrame
This closes apache#12148.
dianfu committed May 15, 2020
1 parent 2e6acb6 commit d417889
Showing 7 changed files with 243 additions and 4 deletions.
23 changes: 22 additions & 1 deletion docs/dev/table/python/conversion_of_pandas.md
@@ -29,7 +29,7 @@ It supports to convert between PyFlink Table and Pandas DataFrame.

## Convert Pandas DataFrame to PyFlink Table

It supports to create a PyFlink Table from a Pandas DataFrame. Internally, it will serialize the Pandas DataFrame
It supports creating a PyFlink Table from a Pandas DataFrame. Internally, it will serialize the Pandas DataFrame
using the Arrow columnar format at the client side, and the serialized data will be processed and deserialized in the Arrow
source during execution. The Arrow source can also be used in streaming jobs, and it will properly handle checkpointing
and provide exactly-once guarantees.
@@ -57,3 +57,24 @@ table = t_env.from_pandas(pdf,
DataTypes.ROW([DataTypes.FIELD("f0", DataTypes.DOUBLE()),
DataTypes.FIELD("f1", DataTypes.DOUBLE())])
{% endhighlight %}

## Convert PyFlink Table to Pandas DataFrame

It also supports converting a PyFlink Table to a Pandas DataFrame. Internally, it will materialize the results of the
table and serialize them into multiple Arrow batches in the Arrow columnar format at the client side. The maximum Arrow batch size
is determined by the config option [python.fn-execution.arrow.batch.size]({{ site.baseurl }}/dev/table/python/python_config.html#python-fn-execution-arrow-batch-size).
The serialized data will then be converted to a Pandas DataFrame.

The following example shows how to convert a PyFlink Table to a Pandas DataFrame:

{% highlight python %}
import pandas as pd
import numpy as np

# Create a PyFlink Table
pdf = pd.DataFrame(np.random.rand(1000, 2))
table = t_env.from_pandas(pdf, ["a", "b"]).filter("a > 0.5")

# Convert the PyFlink Table to a Pandas DataFrame
pdf = table.to_pandas()
{% endhighlight %}
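
The Arrow batch size mentioned above can also be set explicitly before triggering the conversion. The following is a minimal sketch (not part of the documentation in this commit) that assumes a table environment named `t_env` and a table named `table` have already been created:

{% highlight python %}
# Sketch: cap each Arrow batch at 5000 rows before collecting the table.
# The option key is the one documented on the Python configuration page linked above.
t_env.get_config().get_configuration().set_string(
    "python.fn-execution.arrow.batch.size", "5000")

pdf = table.to_pandas()
{% endhighlight %}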
23 changes: 22 additions & 1 deletion docs/dev/table/python/conversion_of_pandas.zh.md
@@ -29,7 +29,7 @@ It supports to convert between PyFlink Table and Pandas DataFrame.

## Convert Pandas DataFrame to PyFlink Table

It supports to create a PyFlink Table from a Pandas DataFrame. Internally, it will serialize the Pandas DataFrame
It supports creating a PyFlink Table from a Pandas DataFrame. Internally, it will serialize the Pandas DataFrame
using the Arrow columnar format at the client side, and the serialized data will be processed and deserialized in the Arrow
source during execution. The Arrow source can also be used in streaming jobs, and it will properly handle checkpointing
and provide exactly-once guarantees.
@@ -57,3 +57,24 @@ table = t_env.from_pandas(pdf,
DataTypes.ROW([DataTypes.FIELD("f0", DataTypes.DOUBLE()),
DataTypes.FIELD("f1", DataTypes.DOUBLE())])
{% endhighlight %}

## Convert PyFlink Table to Pandas DataFrame

It also supports converting a PyFlink Table to a Pandas DataFrame. Internally, it will materialize the results of the
table and serialize them into multiple Arrow batches in the Arrow columnar format at the client side. The maximum Arrow batch size
is determined by the config option [python.fn-execution.arrow.batch.size]({{ site.baseurl }}/zh/dev/table/python/python_config.html#python-fn-execution-arrow-batch-size).
The serialized data will then be converted to a Pandas DataFrame.

The following example shows how to convert a PyFlink Table to a Pandas DataFrame:

{% highlight python %}
import pandas as pd
import numpy as np

# Create a PyFlink Table
pdf = pd.DataFrame(np.random.rand(1000, 2))
table = t_env.from_pandas(pdf, ["a", "b"]).filter("a > 0.5")

# Convert the PyFlink Table to a Pandas DataFrame
pdf = table.to_pandas()
{% endhighlight %}
26 changes: 26 additions & 0 deletions flink-python/pyflink/table/serializers.py
@@ -15,6 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import io

from pyflink.serializers import Serializer
from pyflink.table.utils import arrow_to_pandas, pandas_to_arrow

@@ -51,3 +53,27 @@ def load_from_stream(self, stream):
        reader = pa.ipc.open_stream(stream)
        for batch in reader:
            yield arrow_to_pandas(self._timezone, self._field_types, [batch])

    def load_from_iterator(self, itor):
        class IteratorIO(io.RawIOBase):
            def __init__(self, itor):
                super(IteratorIO, self).__init__()
                self.itor = itor
                self.leftover = None

            def readable(self):
                return True

            def readinto(self, b):
                output_buffer_len = len(b)
                input = self.leftover or (self.itor.next() if self.itor.hasNext() else None)
                if input is None:
                    return 0
                output, self.leftover = input[:output_buffer_len], input[output_buffer_len:]
                b[:len(output)] = output
                return len(output)
        import pyarrow as pa
        reader = pa.ipc.open_stream(
            io.BufferedReader(IteratorIO(itor), buffer_size=io.DEFAULT_BUFFER_SIZE))
        for batch in reader:
            yield batch
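
The nested `IteratorIO` class above adapts a Java-style `hasNext()`/`next()` iterator of byte chunks into a file-like stream that pyarrow can read. A self-contained sketch of the same pattern, using only the standard library and independent of the classes in this commit:

{% highlight python %}
import io


class ChunkIteratorIO(io.RawIOBase):
    """Adapts an iterator of byte chunks into a readable raw stream."""

    def __init__(self, chunks):
        super(ChunkIteratorIO, self).__init__()
        self._chunks = iter(chunks)
        self._leftover = b""

    def readable(self):
        return True

    def readinto(self, b):
        # Refill the leftover buffer from the iterator once it runs dry.
        if not self._leftover:
            self._leftover = next(self._chunks, b"")
        if not self._leftover:
            return 0  # end of stream
        n = min(len(b), len(self._leftover))
        b[:n] = self._leftover[:n]
        self._leftover = self._leftover[n:]
        return n


# The buffered reader hides the chunk boundaries from the consumer.
stream = io.BufferedReader(ChunkIteratorIO([b"hello ", b"world"]))
assert stream.read() == b"hello world"
{% endhighlight %}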
44 changes: 44 additions & 0 deletions flink-python/pyflink/table/table.py
@@ -19,8 +19,12 @@
import warnings

from py4j.java_gateway import get_method

from pyflink.java_gateway import get_gateway
from pyflink.table.serializers import ArrowSerializer
from pyflink.table.table_schema import TableSchema
from pyflink.table.types import create_arrow_schema
from pyflink.table.utils import tz_convert_from_internal

from pyflink.util.utils import to_jarray
from pyflink.util.utils import to_j_explain_detail_arr
@@ -692,6 +696,46 @@ def insert_into(self, table_path):
DeprecationWarning)
self._j_table.insertInto(table_path)

    def to_pandas(self):
        """
        Converts the table to a pandas DataFrame.

        Example:
        ::

            >>> pdf = pd.DataFrame(np.random.rand(1000, 2))
            >>> table = table_env.from_pandas(pdf, ["a", "b"])
            >>> table.filter("a > 0.5").to_pandas()

        :return: the result pandas DataFrame.
        """
        gateway = get_gateway()
        max_arrow_batch_size = self._j_table.getTableEnvironment().getConfig().getConfiguration()\
            .getInteger(gateway.jvm.org.apache.flink.python.PythonOptions.MAX_ARROW_BATCH_SIZE)
        batches = gateway.jvm.org.apache.flink.table.runtime.arrow.ArrowUtils\
            .collectAsPandasDataFrame(self._j_table, max_arrow_batch_size)
        if batches.hasNext():
            import pytz
            timezone = pytz.timezone(
                self._j_table.getTableEnvironment().getConfig().getLocalTimeZone().getId())
            serializer = ArrowSerializer(
                create_arrow_schema(self.get_schema().get_field_names(),
                                    self.get_schema().get_field_data_types()),
                self.get_schema().to_row_data_type(),
                timezone)
            import pyarrow as pa
            table = pa.Table.from_batches(serializer.load_from_iterator(batches))
            pdf = table.to_pandas()

            schema = self.get_schema()
            for field_name in schema.get_field_names():
                pdf[field_name] = tz_convert_from_internal(
                    pdf[field_name], schema.get_field_data_type(field_name), timezone)
            return pdf
        else:
            import pandas as pd
            return pd.DataFrame.from_records([], columns=self.get_schema().get_field_names())

    def get_schema(self):
        """
        Returns the :class:`~pyflink.table.TableSchema` of this table.
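
The core of `to_pandas` above is pyarrow's batch assembly: the Arrow batches collected from the JVM are combined into a `pyarrow.Table`, which is then converted to pandas. A tiny standalone sketch of just that step (plain pyarrow, no Flink involved):

{% highlight python %}
import pyarrow as pa

# Two small record batches stand in for the Arrow batches collected from the Flink table.
batch = pa.RecordBatch.from_arrays(
    [pa.array([1.0, 2.0]), pa.array([3.0, 4.0])], names=["a", "b"])

# Combining the batches and converting the result mirrors what to_pandas() does internally.
pdf = pa.Table.from_batches([batch, batch]).to_pandas()
print(pdf)
{% endhighlight %}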
13 changes: 13 additions & 0 deletions flink-python/pyflink/table/tests/test_pandas_conversion.py
@@ -18,6 +18,8 @@
import datetime
import decimal

from pandas.util.testing import assert_frame_equal

from pyflink.table.types import DataTypes, Row
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkBlinkBatchTableTestCase, \
@@ -130,6 +132,17 @@ def test_from_pandas(self):
"1970-01-01 00:00:00.123,[hello, 中文],1,hello,"
"1970-01-01 00:00:00.123,[1, 2]"])

    def test_to_pandas(self):
        table = self.t_env.from_pandas(self.pdf, self.data_type)
        result_pdf = table.to_pandas()
        self.assertEqual(2, len(result_pdf))
        assert_frame_equal(self.pdf, result_pdf)

    def test_empty_to_pandas(self):
        table = self.t_env.from_pandas(self.pdf, self.data_type)
        pdf = table.filter("f1 < 0").to_pandas()
        self.assertTrue(pdf.empty)


class StreamPandasConversionTests(PandasConversionITTests,
                                  PyFlinkStreamTableTestCase):
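
`assert_frame_equal`, imported above from `pandas.util.testing`, compares values, dtypes and the index, so `test_to_pandas` can check the whole round-tripped frame at once. A standalone illustration with plain pandas:

{% highlight python %}
import pandas as pd
from pandas.util.testing import assert_frame_equal

left = pd.DataFrame({"f1": [1, 2], "f2": ["hello", "flink"]})
right = pd.DataFrame({"f1": [1, 2], "f2": ["hello", "flink"]})

# Raises an AssertionError with a detailed diff if values, dtypes or the index differ.
assert_frame_equal(left, right)
{% endhighlight %}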
@@ -20,9 +20,19 @@

import org.apache.flink.annotation.Internal;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.internal.TableEnvImpl;
import org.apache.flink.table.api.internal.TableEnvironmentImpl;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.data.ArrayData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.util.DataFormatConverters;
import org.apache.flink.table.data.vector.ColumnVector;
import org.apache.flink.table.delegation.Planner;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.sinks.SelectTableSinkSchemaConverter;
import org.apache.flink.table.runtime.arrow.readers.ArrayFieldReader;
import org.apache.flink.table.runtime.arrow.readers.ArrowFieldReader;
import org.apache.flink.table.runtime.arrow.readers.BigIntFieldReader;
@@ -113,6 +123,7 @@
import org.apache.flink.types.Row;

import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
@@ -135,6 +146,7 @@
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageMetadataResult;
@@ -146,7 +158,10 @@
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.IOException;
@@ -157,6 +172,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

@@ -166,6 +182,8 @@
@Internal
public final class ArrowUtils {

private static final Logger LOG = LoggerFactory.getLogger(ArrowUtils.class);

private static RootAllocator rootAllocator;

public static synchronized RootAllocator getRootAllocator() {
@@ -588,6 +606,102 @@ private static void readFully(ReadableByteChannel channel, ByteBuffer dst) throw
}
}

    /**
     * Convert Flink table to Pandas DataFrame.
     */
    public static CustomIterator<byte[]> collectAsPandasDataFrame(Table table, int maxArrowBatchSize) throws Exception {
        BufferAllocator allocator = getRootAllocator().newChildAllocator("collectAsPandasDataFrame", 0, Long.MAX_VALUE);
        RowType rowType = (RowType) table.getSchema().toRowDataType().getLogicalType();
        VectorSchemaRoot root = VectorSchemaRoot.create(ArrowUtils.toArrowSchema(rowType), allocator);
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(root, null, baos);
        arrowStreamWriter.start();

        ArrowWriter arrowWriter;
        Iterator<Row> results = table.execute().collect();
        Iterator convertedResults;
        if (isBlinkPlanner(table)) {
            arrowWriter = createRowDataArrowWriter(root, rowType);
            convertedResults = new Iterator<RowData>() {
                @Override
                public boolean hasNext() {
                    return results.hasNext();
                }

                @Override
                public RowData next() {
                    // The SelectTableSink of the blink planner will convert the table schema and we
                    // need to keep the table schema used here consistent with the converted table schema.
                    TableSchema convertedTableSchema =
                        SelectTableSinkSchemaConverter.changeDefaultConversionClass(table.getSchema());
                    DataFormatConverters.DataFormatConverter converter =
                        DataFormatConverters.getConverterForDataType(convertedTableSchema.toRowDataType());
                    return (RowData) converter.toInternal(results.next());
                }
            };
        } else {
            arrowWriter = createRowArrowWriter(root, rowType);
            convertedResults = results;
        }

        return new CustomIterator<byte[]>() {
            @Override
            public boolean hasNext() {
                return convertedResults.hasNext();
            }

            @Override
            public byte[] next() {
                try {
                    int i = 0;
                    while (convertedResults.hasNext() && i < maxArrowBatchSize) {
                        i++;
                        arrowWriter.write(convertedResults.next());
                    }
                    arrowWriter.finish();
                    arrowStreamWriter.writeBatch();
                    return baos.toByteArray();
                } catch (Throwable t) {
                    String msg = "Failed to serialize the data of the table";
                    LOG.error(msg, t);
                    throw new RuntimeException(msg, t);
                } finally {
                    arrowWriter.reset();
                    baos.reset();

                    if (!hasNext()) {
                        root.close();
                        allocator.close();
                    }
                }
            }
        };
    }

    private static boolean isBlinkPlanner(Table table) {
        TableEnvironment tableEnv = ((TableImpl) table).getTableEnvironment();
        if (tableEnv instanceof TableEnvImpl) {
            return false;
        } else if (tableEnv instanceof TableEnvironmentImpl) {
            Planner planner = ((TableEnvironmentImpl) tableEnv).getPlanner();
            return planner instanceof PlannerBase;
        } else {
            throw new RuntimeException(String.format(
                "Could not determine the planner type for table environment class %s", tableEnv.getClass()));
        }
    }

    /**
     * A custom iterator to bypass the Py4J Java collection as the next method of
     * py4j.java_collections.JavaIterator will eat all the exceptions thrown in Java
     * which makes it difficult to debug in case of errors.
     */
    private interface CustomIterator<T> {
        boolean hasNext();

        T next();
    }

    private static class LogicalTypeToArrowTypeConverter extends LogicalTypeDefaultVisitor<ArrowType> {

        private static final LogicalTypeToArrowTypeConverter INSTANCE = new LogicalTypeToArrowTypeConverter();
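
On the Java side, `collectAsPandasDataFrame` writes up to `maxArrowBatchSize` rows per Arrow batch into a reusable byte buffer and returns each serialized chunk to the Python process through the `CustomIterator`. The write-then-reset idea can be sketched with pyarrow as well (a rough analogy assuming a reasonably recent pyarrow, not the Flink implementation itself):

{% highlight python %}
import io

import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["f0"])

sink = io.BytesIO()
writer = pa.ipc.new_stream(sink, batch.schema)  # roughly analogous to arrowStreamWriter.start()

writer.write_batch(batch)                       # roughly analogous to arrowStreamWriter.writeBatch()
payload = sink.getvalue()                       # roughly analogous to baos.toByteArray()

# Reuse the buffer for the next chunk, as the Java code does with baos.reset().
sink.seek(0)
sink.truncate(0)
{% endhighlight %}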
@@ -28,12 +28,12 @@
/**
* An utility class that provides abilities to change {@link TableSchema}.
*/
class SelectTableSinkSchemaConverter {
public class SelectTableSinkSchemaConverter {

/**
* Change to default conversion class and build a new {@link TableSchema}.
*/
static TableSchema changeDefaultConversionClass(TableSchema tableSchema) {
public static TableSchema changeDefaultConversionClass(TableSchema tableSchema) {
DataType[] oldTypes = tableSchema.getFieldDataTypes();
String[] fieldNames = tableSchema.getFieldNames();

