[Core] Support Arrow zerocopy serialization in object store (#35110)
Support Arrow data in the object store with zero-copy deserialization and improved cross-language performance.

We benchmarked with the [NYC TAXI FARE](https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data) dataset, which has 8 columns and 55,423,855 rows in CSV, 5.4 GB on disk.

Here are the results:
| Java to Java | Java Write(ms) | Java Read(ms) |
| :-----: | :----: | :----: |
| Before | 23,637 | 3,162 |
| After | 23,320 | 226 |

| Java to Python | Java Write(ms) | Python Read(ms) |
| :---: | :---: | :---: |
| Before | 28,771 | 2,645 |
| After | 25,864 | 8 |

| Python to Java | Python Write(ms) | Java Read(ms) |
| :---: | :---: | :---: |
| Before | 10,597 | 3,386 |
| After | 5,271 | 3,251 |

| Python to Python | Python Write(ms) | Python Read(ms) |
| :---: | :---: | :---: |
| Before | 9,113 | 988 |
| After | 5,636 | 66 |


Benchmark code:

```python
import ray, raydp, time
from pyarrow import csv
import sys

file_path = "FilePath_/train.csv"
# file_path = "FilePath_/train_tiny.csv"

if __name__ == '__main__':
  ray.init()
  write, read = sys.argv[1], sys.argv[2]
  assert write in ("java", "python") and read in ("java", "python"), "Illegal arguments. Please use java or python"

  spark = raydp.init_spark('benchmark', 10, 5, '2G', configs={"spark.default.parallelism": 50})

  if write == "java":
    df = spark.read.format("csv").option("header", "true") \
            .option("inferSchema", "true") \
            .load(f"file:https://{file_path}")
    print(df.count())
    start = time.time()
    blocks, _ = raydp.spark.dataset._save_spark_df_to_object_store(df, False)
    end = time.time()
    ds = ray.data.from_arrow_refs(blocks)
  elif write == "python":
    table = csv.read_csv(file_path)
    start = time.time()
    ds = ray.data.from_arrow(table)
    end = time.time()
    print(ds.num_blocks())
    ds = ds.repartition(50)

  print(f"{write} writing takes {end - start} seconds.")

  if read == "java":
    start = time.time()
    df = ds.to_spark(spark)
    end = time.time()
    print(df.count())
  elif read == "python":
    start = time.time()
    ray.get(ds.get_internal_block_refs())
    end = time.time()

  print(f"{read} reading takes {end - start} seconds.")

  raydp.stop_spark()
  ray.shutdown()
```
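The feature itself can be exercised without Spark. A minimal sketch of the new path (assumes pyarrow is installed and a local `ray.init()`; the tiny table is only illustrative):

```python
import pyarrow as pa
import ray

ray.init()

table = pa.Table.from_arrays([pa.array([0, 1, 2, 3, 4])], names=["x"])
ref = ray.put(table)      # stored with b"ARROW" metadata as an Arrow IPC stream
restored = ray.get(ref)   # read back via pa.ipc.open_stream(...).read_all()
assert restored.equals(table)

ray.shutdown()
```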
Deegue committed Jun 1, 2023
1 parent 1332133 commit 158c2bf
Showing 11 changed files with 289 additions and 4 deletions.
4 changes: 4 additions & 0 deletions java/BUILD.bazel
@@ -91,6 +91,9 @@ define_java_module(
        "@maven//:commons_io_commons_io",
        "@maven//:de_ruedigermoeller_fst",
        "@maven//:net_java_dev_jna_jna",
        "@maven//:org_apache_arrow_arrow_memory_core",
        "@maven//:org_apache_arrow_arrow_memory_unsafe",
        "@maven//:org_apache_arrow_arrow_vector",
        "@maven//:org_apache_commons_commons_lang3",
        "@maven//:org_apache_logging_log4j_log4j_api",
        "@maven//:org_apache_logging_log4j_log4j_core",
@@ -117,6 +120,7 @@ define_java_module(
        "@maven//:com_sun_xml_bind_jaxb_impl",
        "@maven//:commons_io_commons_io",
        "@maven//:javax_xml_bind_jaxb_api",
        "@maven//:org_apache_arrow_arrow_vector",
        "@maven//:org_apache_commons_commons_lang3",
        "@maven//:org_apache_logging_log4j_log4j_api",
        "@maven//:org_apache_logging_log4j_log4j_core",
3 changes: 3 additions & 0 deletions java/dependencies.bzl
@@ -27,6 +27,9 @@ def gen_java_deps():
        "org.slf4j:slf4j-api:1.7.25",
        "com.lmax:disruptor:3.3.4",
        "net.java.dev.jna:jna:5.8.0",
        "org.apache.arrow:arrow-memory-core:5.0.0",
        "org.apache.arrow:arrow-memory-unsafe:5.0.0",
        "org.apache.arrow:arrow-vector:5.0.0",
        "org.apache.httpcomponents.client5:httpclient5:5.0.3",
        "org.apache.httpcomponents.core5:httpcore5:5.0.2",
        "org.apache.httpcomponents.client5:httpclient5-fluent:5.0.3",
ObjectSerializer.java
@@ -14,13 +14,15 @@
import io.ray.runtime.generated.Common.ErrorType;
import io.ray.runtime.serializer.RayExceptionSerializer;
import io.ray.runtime.serializer.Serializer;
import io.ray.runtime.util.ArrowUtil;
import io.ray.runtime.util.IdUtil;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.commons.lang3.tuple.Pair;

@@ -49,6 +51,7 @@ public class ObjectSerializer {
  private static final byte[] TASK_EXECUTION_EXCEPTION_META =
      String.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();

  public static final byte[] OBJECT_METADATA_TYPE_ARROW = "ARROW".getBytes();
  public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes();
  public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
  public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
@@ -80,7 +83,9 @@ public static Object deserialize(

    if (meta != null && meta.length > 0) {
      // If meta is not null, deserialize the object from meta.
-      if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_RAW) == 0) {
+      if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_ARROW) == 0) {
+        return ArrowUtil.deserialize(data);
+      } else if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_RAW) == 0) {
        if (objectType == ByteBuffer.class) {
          return ByteBuffer.wrap(data);
        }
@@ -136,6 +141,10 @@ public static NativeRayObject serialize(Object object) {
      // If the object is a byte array, skip serializing it and use a special metadata to
      // indicate it's raw binary. So that this object can also be read by Python.
      return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
    } else if (object instanceof VectorSchemaRoot) {
      // Serialize Arrow data using the IPC stream format.
      byte[] bytes = ArrowUtil.serialize((VectorSchemaRoot) object);
      return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_ARROW);
    } else if (object instanceof ByteBuffer) {
      // Serialize ByteBuffer to raw bytes.
      ByteBuffer buffer = (ByteBuffer) object;
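Read together, the Java changes establish a cross-language wire contract: every object-store entry is a (metadata, data) pair, and a metadata prefix of b"ARROW" marks the data as an Arrow IPC stream. A sketch of that dispatch in Python (helper shape hypothetical; the real Python-side code appears in serialization.py below):

```python
import pyarrow as pa

OBJECT_METADATA_TYPE_ARROW = b"ARROW"
OBJECT_METADATA_TYPE_RAW = b"RAW"

def deserialize_entry(metadata: bytes, data: bytes):
    # Dispatch on the metadata prefix, mirroring ObjectSerializer.deserialize.
    if metadata.startswith(OBJECT_METADATA_TYPE_ARROW):
        # The payload is a standard Arrow IPC stream, readable from any language.
        return pa.ipc.open_stream(pa.BufferReader(data)).read_all()
    if metadata.startswith(OBJECT_METADATA_TYPE_RAW):
        return data
    raise NotImplementedError("other metadata types elided from this sketch")
```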
61 changes: 61 additions & 0 deletions java/runtime/src/main/java/io/ray/runtime/util/ArrowUtil.java
@@ -0,0 +1,61 @@
package io.ray.runtime.util;

import io.ray.api.exception.RayException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.channels.Channels;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageChannelReader;
import org.apache.arrow.vector.ipc.message.MessageResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;

/** Helper methods for serializing and deserializing Arrow data. */
public class ArrowUtil {
  public static final RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE);

  /**
   * Deserialize a byte array in Arrow IPC stream format.
   *
   * @return The Arrow vector schema root.
   */
  public static VectorSchemaRoot deserialize(byte[] data) {
    try (MessageChannelReader reader =
        new MessageChannelReader(
            new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data))), rootAllocator)) {
      MessageResult result = reader.readNext();
      Schema schema = MessageSerializer.deserializeSchema(result.getMessage());
      VectorSchemaRoot root = VectorSchemaRoot.create(schema, rootAllocator);
      VectorLoader loader = new VectorLoader(root);
      result = reader.readNext();
      ArrowRecordBatch batch =
          MessageSerializer.deserializeRecordBatch(result.getMessage(), result.getBodyBuffer());
      loader.load(batch);
      return root;
    } catch (Exception e) {
      throw new RayException("Failed to deserialize Arrow data", e.getCause());
    }
  }

  /**
   * Serialize Arrow data to a byte array in IPC stream format.
   *
   * @return The serialized bytes.
   */
  public static byte[] serialize(VectorSchemaRoot root) {
    try (ByteArrayOutputStream sink = new ByteArrayOutputStream();
        ArrowStreamWriter writer = new ArrowStreamWriter(root, null, sink)) {
      writer.start();
      writer.writeBatch();
      writer.end();
      return sink.toByteArray();
    } catch (Exception e) {
      throw new RayException("Failed to serialize Arrow data", e.getCause());
    }
  }
}
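Since ArrowUtil frames its bytes with the standard Arrow IPC stream format (a schema message followed by a single record batch), the payload is readable from pyarrow directly, and vice versa. A sketch of the equivalent round trip in Python; it assumes single-batch tables, since the Java reader above consumes exactly one batch:

```python
import pyarrow as pa

def serialize(table: pa.Table) -> bytes:
    # Same framing ArrowUtil.serialize emits: schema message, then record batches.
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)
    return sink.getvalue().to_pybytes()

def deserialize(data: bytes) -> pa.Table:
    # Same framing ArrowUtil.deserialize consumes.
    return pa.ipc.open_stream(pa.BufferReader(data)).read_all()

table = pa.Table.from_arrays([pa.array([1, 2, 3])], names=["x"])
assert deserialize(serialize(table)).equals(table)
```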
java/test/src/main/java/io/ray/test/CrossLanguageObjectStoreTest.java
@@ -0,0 +1,92 @@
package io.ray.test;

import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.function.PyFunction;
import io.ray.runtime.util.ArrowUtil;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

@Test(groups = {"cluster"})
public class CrossLanguageObjectStoreTest extends BaseTest {

  private static final String PYTHON_MODULE = "test_cross_language_invocation";
  private static final int vecSize = 5;

  @BeforeClass
  public void beforeClass() {
    // Delete and re-create the temp dir.
    File tempDir =
        new File(
            System.getProperty("java.io.tmpdir")
                + File.separator
                + "ray_cross_language_object_store_test");
    FileUtils.deleteQuietly(tempDir);
    tempDir.mkdirs();
    tempDir.deleteOnExit();

    // Write the test Python file to the temp dir.
    InputStream in =
        CrossLanguageObjectStoreTest.class.getResourceAsStream(
            File.separator + PYTHON_MODULE + ".py");
    File pythonFile = new File(tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py");
    try {
      FileUtils.copyInputStreamToFile(in, pythonFile);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }

    System.setProperty(
        "ray.job.code-search-path",
        System.getProperty("java.class.path") + File.pathSeparator + tempDir.getAbsolutePath());
  }

  @Test
  public void testPythonPutAndJavaGet() {
    ObjectRef<VectorSchemaRoot> res =
        Ray.task(PyFunction.of(PYTHON_MODULE, "py_put_into_object_store", VectorSchemaRoot.class))
            .remote();
    VectorSchemaRoot root = res.get();
    BigIntVector newVector = (BigIntVector) root.getVector(0);
    for (int i = 0; i < vecSize; i++) {
      Assert.assertEquals(i, newVector.get(i));
    }
  }

  @Test
  public void testJavaPutAndPythonGet() {
    BigIntVector vector = new BigIntVector("ArrowBigIntVector", ArrowUtil.rootAllocator);
    vector.setValueCount(vecSize);
    for (int i = 0; i < vecSize; i++) {
      vector.setSafe(i, i);
    }
    List<Field> fields = Arrays.asList(vector.getField());
    List<FieldVector> vectors = Arrays.asList(vector);
    VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors);
    ObjectRef<VectorSchemaRoot> obj = Ray.put(root);

    ObjectRef<VectorSchemaRoot> res =
        Ray.task(
                PyFunction.of(
                    PYTHON_MODULE, "py_object_store_get_and_check", VectorSchemaRoot.class),
                obj)
            .remote();

    VectorSchemaRoot newRoot = res.get();
    BigIntVector newVector = (BigIntVector) newRoot.getVector(0);
    for (int i = 0; i < vecSize; i++) {
      Assert.assertEquals(i, newVector.get(i));
    }
  }
}
25 changes: 25 additions & 0 deletions java/test/src/main/java/io/ray/test/ObjectStoreTest.java
@@ -6,8 +6,14 @@
import io.ray.api.Ray;
import io.ray.api.exception.RayTaskException;
import io.ray.api.exception.UnreconstructableException;
import io.ray.runtime.util.ArrowUtil;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.testng.Assert;
import org.testng.annotations.Test;

@@ -75,6 +81,25 @@ public void testGetMultipleObjects() {
    Assert.assertEquals(ints, Ray.get(refs));
  }

  @Test
  public void testArrowObjects() {
    final int vecSize = 10;
    IntVector vector = new IntVector("ArrowIntVector", ArrowUtil.rootAllocator);
    vector.setValueCount(vecSize);
    for (int i = 0; i < vecSize; i++) {
      vector.setSafe(i, i);
    }
    List<Field> fields = Arrays.asList(vector.getField());
    List<FieldVector> vectors = Arrays.asList(vector);
    VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors);
    ObjectRef<VectorSchemaRoot> obj = Ray.put(root);
    VectorSchemaRoot newRoot = obj.get();
    IntVector newVector = (IntVector) newRoot.getVector(0);
    for (int i = 0; i < vecSize; i++) {
      Assert.assertEquals(i, newVector.get(i));
    }
  }

  @Test(groups = {"cluster"})
  public void testOwnerAssignWhenPut() throws Exception {
    // This test should align with test_owner_assign_when_put in Python
34 changes: 34 additions & 0 deletions java/test/src/main/resources/test_cross_language_invocation.py
@@ -3,6 +3,8 @@

import asyncio

import pyarrow as pa

import ray


@@ -189,3 +191,35 @@ def py_func_call_java_overloaded_method():
    result = ray.get([ref1, ref2])
    assert result == ["first", "firstsecond"]
    return True


@ray.remote
def py_put_into_object_store():
    column_values = [0, 1, 2, 3, 4]
    column_array = pa.array(column_values)
    table = pa.Table.from_arrays([column_array], names=["ArrowBigIntVector"])
    return table


@ray.remote
def py_object_store_get_and_check(table):
    column_values = [0, 1, 2, 3, 4]
    column_array = pa.array(column_values)
    expected_table = pa.Table.from_arrays([column_array], names=["ArrowBigIntVector"])

    for column_name in table.column_names:
        column1 = table[column_name]
        column2 = expected_table[column_name]

        indices = pa.compute.equal(column1, column2).to_pylist()
        differing_rows = [i for i, index in enumerate(indices) if not index]

        if differing_rows:
            print(f"Differences in column '{column_name}':")
            for row in differing_rows:
                value1 = column1[row].as_py()
                value2 = column2[row].as_py()
                print(f"Row {row}: {value1} != {value2}")
            raise RuntimeError("Check failed, two tables are not equal!")

    return table
2 changes: 2 additions & 0 deletions python/ray/_private/ray_constants.py
@@ -301,6 +301,8 @@ def env_set_by_user(key):
OBJECT_METADATA_TYPE_PYTHON = b"PYTHON"
# A constant used as object metadata to indicate the object is raw bytes.
OBJECT_METADATA_TYPE_RAW = b"RAW"
# A constant used as object metadata to indicate the object is arrow data.
OBJECT_METADATA_TYPE_ARROW = b"ARROW"

# A constant used as object metadata to indicate the object is an actor handle.
# This value should be synchronized with the Java definition in
23 changes: 21 additions & 2 deletions python/ray/_private/serialization.py
@@ -8,6 +8,7 @@
import ray.cloudpickle as pickle
from ray._private import ray_constants
from ray._raylet import (
    ArrowSerializedObject,
    MessagePackSerializedObject,
    MessagePackSerializer,
    ObjectRefGenerator,
@@ -47,6 +48,11 @@
from ray.util import serialization_addons
from ray.util import inspect_serializability

try:
    import pyarrow as pa
except ImportError:
    pa = None

logger = logging.getLogger(__name__)

@@ -270,6 +276,12 @@ def _deserialize_object(self, data, metadata, object_ref):
                if data is None:
                    return b""
                return data.to_pybytes()
            elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW:
                assert (
                    pa is not None
                ), "pyarrow should be imported while deserializing arrow objects"
                reader = pa.BufferReader(data)
                return pa.ipc.open_stream(reader).read_all()
            elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE:
                obj = self._deserialize_msgpack_data(data, metadata_fields)
                return _actor_handle_deserializer(obj)
@@ -461,5 +473,12 @@ def serialize(self, value):
            # use a special metadata to indicate it's raw binary. So
            # that this object can also be read by Java.
            return RawSerializedObject(value)
-        else:
-            return self._serialize_to_msgpack(value)
+
+        # Check whether pyarrow is installed. If so, use the Arrow IPC format
+        # to serialize this object, so that it can also be read by Java.
+        if pa is not None and (
+            isinstance(value, pa.Table) or isinstance(value, pa.RecordBatch)
+        ):
+            return ArrowSerializedObject(value)
+
+        return self._serialize_to_msgpack(value)
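The Python read path is where the zero-copy benefit shows up: `pa.BufferReader` wraps the existing object-store buffer, and `read_all()` returns a table whose column buffers alias that memory rather than copying it. A quick way to observe the aliasing with plain pyarrow (not Ray-specific; relies on pyarrow's zero-copy IPC reads over an in-memory buffer):

```python
import pyarrow as pa

table = pa.Table.from_arrays([pa.array(range(1_000_000))], names=["x"])
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
buf = sink.getvalue()

restored = pa.ipc.open_stream(pa.BufferReader(buf)).read_all()
# The data buffer of the restored column lies inside `buf`: no copy was made.
data_buf = restored.column("x").chunk(0).buffers()[1]
assert buf.address <= data_buf.address < buf.address + buf.size
```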
4 changes: 3 additions & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -278,7 +278,9 @@ def combine_chunks(table: "pyarrow.Table") -> "pyarrow.Table":
    cols = table.columns
    new_cols = []
    for col in cols:
-        if _is_column_extension_type(col):
+        if col.num_chunks == 0:
+            arr = pyarrow.chunked_array([], type=col.type)
+        elif _is_column_extension_type(col):
            # Extension arrays don't support concatenation.
            arr = _concatenate_extension_column(col)
        else:
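The new first branch handles the degenerate case of a column with zero chunks, which previously fell through to chunk concatenation and failed. A small standalone illustration of the case (plain pyarrow, outside Ray):

```python
import pyarrow

# A ChunkedArray may legitimately contain zero chunks, e.g. when a block
# has been filtered down to nothing.
col = pyarrow.chunked_array([], type=pyarrow.int64())
assert col.num_chunks == 0 and len(col) == 0

# What the new branch does: rebuild an empty chunked array of the same type
# instead of trying to concatenate zero chunks.
arr = pyarrow.chunked_array([], type=col.type)
assert arr.type == pyarrow.int64()
```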
