[Data] Add Dataset.write_sql (ray-project#38544)
Writing data back to databases is common in many applications, such as LLM workloads. For example, you might want to write vector indices back to a database like https://github.com/pgvector/pgvector. To support this use case, this PR adds an API for writing Datasets to SQL databases.

Signed-off-by: Balaji Veeramani <[email protected]>
Signed-off-by: e428265 <[email protected]>
bveeramani authored and arvind-chandra committed Aug 31, 2023
1 parent b299ea2 commit 2c8036c
Showing 4 changed files with 142 additions and 18 deletions.
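To illustrate the vector-index use case from the description, here is a hedged sketch of pointing write_sql at a PostgreSQL table that uses the pgvector extension. The table name, column layout, and connection string are assumptions for illustration only, and psycopg2 is just one possible DB API2 driver:

# Sketch only: assumes a PostgreSQL database with the pgvector extension enabled,
# an existing table `items(id int, embedding vector(3))`, and the psycopg2 driver.
import psycopg2
import ray

dataset = ray.data.from_items([
    {"id": 1, "embedding": "[1, 2, 3]"},  # pgvector accepts vectors written as text literals
    {"id": 2, "embedding": "[4, 5, 6]"},
])

dataset.write_sql(
    # psycopg2 uses the %s paramstyle rather than sqlite3's ?
    "INSERT INTO items (id, embedding) VALUES (%s, %s)",
    lambda: psycopg2.connect("dbname=example"),  # hypothetical connection string
)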
64 changes: 64 additions & 0 deletions python/ray/data/dataset.py
@@ -115,6 +115,7 @@
)
from ray.data.datasource import (
    BlockWritePathProvider,
    Connection,
    CSVDatasource,
    Datasource,
    DefaultBlockWritePathProvider,
@@ -123,6 +124,7 @@
    NumpyDatasource,
    ParquetDatasource,
    ReadTask,
    SQLDatasource,
    TFRecordDatasource,
    WriteResult,
)
@@ -3215,6 +3217,68 @@ def write_numpy(
            block_path_provider=block_path_provider,
        )

    @ConsumptionAPI
    def write_sql(
        self,
        sql: str,
        connection_factory: Callable[[], Connection],
        ray_remote_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Write to a database that provides a
        `Python DB API2-compliant <https://peps.python.org/pep-0249/>`_ connector.

        .. note::

            This method writes data in parallel using the DB API2 ``executemany``
            method. To learn more about this method, see
            `PEP 249 <https://peps.python.org/pep-0249/#executemany>`_.

        Examples:

            .. testcode::

                import sqlite3
                import ray

                connection = sqlite3.connect("example.db")
                connection.cursor().execute("CREATE TABLE movie(title, year, score)")
                dataset = ray.data.from_items([
                    {"title": "Monty Python and the Holy Grail", "year": 1975, "score": 8.2},
                    {"title": "And Now for Something Completely Different", "year": 1971, "score": 7.5}
                ])

                dataset.write_sql(
                    "INSERT INTO movie VALUES(?, ?, ?)", lambda: sqlite3.connect("example.db")
                )

                result = connection.cursor().execute("SELECT * FROM movie ORDER BY year")
                print(result.fetchall())

            .. testoutput::

                [('And Now for Something Completely Different', 1971, 7.5), ('Monty Python and the Holy Grail', 1975, 8.2)]

            .. testcode::
                :hide:

                import os
                os.remove("example.db")

        Arguments:
            sql: An ``INSERT INTO`` statement that specifies the table to write to. The
                number of parameters must match the number of columns in the table.
            connection_factory: A function that takes no arguments and returns a
                Python DB API2
                `Connection object <https://peps.python.org/pep-0249/#connection-objects>`_.
            ray_remote_args: Keyword arguments passed to :meth:`~ray.remote` in the
                write tasks.
        """  # noqa: E501
        self.write_datasource(
            SQLDatasource(connection_factory),
            ray_remote_args=ray_remote_args,
            sql=sql,
        )
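Since ray_remote_args is forwarded to ray.remote, the write tasks can be tuned like any other Ray task. A small hedged example (the table, file name, and resource value are arbitrary, not part of this PR):

import sqlite3
import ray

sqlite3.connect("example.db").execute("CREATE TABLE IF NOT EXISTS test(id)")

# Hypothetical tuning: give each write task half a CPU; any valid ray.remote()
# keyword (num_cpus, resources, ...) can be forwarded this way.
ray.data.range(100).write_sql(
    "INSERT INTO test VALUES(?)",
    lambda: sqlite3.connect("example.db"),
    ray_remote_args={"num_cpus": 0.5},
)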

@ConsumptionAPI
def write_mongo(
self,
44 changes: 35 additions & 9 deletions python/ray/data/datasource/sql_datasource.py
@@ -2,8 +2,9 @@
from contextlib import contextmanager
from typing import Any, Callable, Iterable, Iterator, List, Optional

from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor, BlockMetadata
from ray.data.datasource.datasource import Datasource, Reader, ReadTask
from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
from ray.util.annotations import PublicAPI

Connection = Any # A Python DB API2-compliant `Connection` object.
@@ -23,12 +24,38 @@ def _cursor_to_block(cursor) -> Block:

@PublicAPI(stability="alpha")
class SQLDatasource(Datasource):

    _MAX_ROWS_PER_WRITE = 128

    def __init__(self, connection_factory: Callable[[], Connection]):
        self.connection_factory = connection_factory

    def create_reader(self, sql: str) -> "Reader":
        return _SQLReader(sql, self.connection_factory)

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
        sql: str,
    ) -> WriteResult:
        with _connect(self.connection_factory) as cursor:
            for block in blocks:
                block_accessor = BlockAccessor.for_block(block)

                values = []
                for row in block_accessor.iter_rows(public_row_format=False):
                    values.append(tuple(row.values()))
                    assert len(values) <= self._MAX_ROWS_PER_WRITE, len(values)
                    if len(values) == self._MAX_ROWS_PER_WRITE:
                        cursor.executemany(sql, values)
                        values = []

                if values:
                    cursor.executemany(sql, values)

        return "ok"
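For context on the batching in write above: rows are buffered into groups of at most _MAX_ROWS_PER_WRITE and each group is flushed with one executemany call, which bounds the size of any single database round trip. A standalone sqlite3 snippet showing the DB API2 call issued per batch (table and values are made up for illustration):

import sqlite3

connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE t(x, y)")

# executemany runs the same parameterized statement once per parameter tuple.
connection.executemany("INSERT INTO t VALUES(?, ?)", [(1, "a"), (2, "b"), (3, "c")])

print(connection.execute("SELECT * FROM t").fetchall())
# -> [(1, 'a'), (2, 'b'), (3, 'c')]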


def _check_connection_is_dbapi2_compliant(connection) -> None:
    for attr in "close", "commit", "cursor":
@@ -44,7 +71,7 @@ def _check_connection_is_dbapi2_compliant(connection) -> None:
def _check_cursor_is_dbapi2_compliant(cursor) -> None:
    # These aren't all the methods required by the specification, but it's all the ones
    # we care about.
    for attr in "execute", "fetchone", "fetchall", "description":
    for attr in "execute", "executemany", "fetchone", "fetchall", "description":
        if not hasattr(cursor, attr):
            raise ValueError(
                "Your database connector created a `Cursor` object without a "
@@ -63,26 +90,25 @@ def _connect(connection_factory: Callable[[], Connection]) -> Iterator[Cursor]:
cursor = connection.cursor()
_check_cursor_is_dbapi2_compliant(cursor)
yield cursor

finally:
connection.commit()
except Exception:
# `rollback` is optional since not all databases provide transaction support.
try:
connection.rollback()
except Exception as e:
# Each connector implements its own `NotSupportedError` class, so we check
# the exception's name instead of using `isinstance`.
if not (
if (
isinstance(e, AttributeError)
or e.__class__.__name__ == "NotSupportedError"
):
raise e from None

connection.commit()
pass
raise
finally:
connection.close()
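Read outside the diff, the updated _connect follows a commit-on-success pattern: the connection is committed only if the caller's block completes, a best-effort rollback is attempted on failure (and tolerated when the driver doesn't implement it) before the error is re-raised, and the connection is always closed. A minimal, generic sketch of that pattern, not the exact Ray code:

from contextlib import contextmanager

@contextmanager
def _transaction(connect):
    # `connect` is any zero-argument callable returning a DB API2 connection.
    connection = connect()
    try:
        cursor = connection.cursor()
        yield cursor
        connection.commit()  # commit only if the caller's block succeeded
    except Exception:
        try:
            connection.rollback()  # optional in DB API2; some drivers omit it
        except Exception:
            pass
        raise
    finally:
        connection.close()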


class _SQLReader(Reader):

    NUM_SAMPLE_ROWS = 100
    MIN_ROWS_PER_READ_TASK = 50

15 changes: 6 additions & 9 deletions python/ray/data/read_api.py
@@ -1680,15 +1680,6 @@ def read_sql(
    For examples of reading from larger databases like MySQL and PostgreSQL, see
    :ref:`Reading from SQL Databases <reading_sql>`.

    .. testcode::
        :hide:

        import os
        try:
            os.remove("example.db")
        except OSError:
            pass

    .. testcode::

        import sqlite3
@@ -1724,6 +1715,12 @@ def create_connection():
"SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
        )

    .. testcode::
        :hide:

        import os
        os.remove("example.db")

    Args:
        sql: The SQL query to execute.
        connection_factory: A function that takes no arguments and returns a
37 changes: 37 additions & 0 deletions python/ray/data/tests/test_sql.py
@@ -33,3 +33,40 @@ def test_read_sql(temp_database: str, parallelism: int):
    actual_values = [tuple(record.values()) for record in dataset.take_all()]

    assert sorted(actual_values) == sorted(expected_values)


def test_write_sql(temp_database: str):
    connection = sqlite3.connect(temp_database)
    connection.cursor().execute("CREATE TABLE test(string, number)")
    dataset = ray.data.from_items(
        [{"string": "spam", "number": 0}, {"string": "ham", "number": 1}]
    )

    dataset.write_sql(
        "INSERT INTO test VALUES(?, ?)", lambda: sqlite3.connect(temp_database)
    )

    result = connection.cursor().execute("SELECT * FROM test ORDER BY number")
    assert result.fetchall() == [("spam", 0), ("ham", 1)]


@pytest.mark.parametrize("num_blocks", (1, 20))
def test_write_sql_many_rows(num_blocks: int, temp_database: str):
    connection = sqlite3.connect(temp_database)
    connection.cursor().execute("CREATE TABLE test(id)")
    dataset = ray.data.range(1000).repartition(num_blocks)

    dataset.write_sql(
        "INSERT INTO test VALUES(?)", lambda: sqlite3.connect(temp_database)
    )

    result = connection.cursor().execute("SELECT * FROM test ORDER BY id")
    assert result.fetchall() == [(i,) for i in range(1000)]


def test_write_sql_nonexistant_table(temp_database: str):
    dataset = ray.data.range(1)
    with pytest.raises(sqlite3.OperationalError):
        dataset.write_sql(
            "INSERT INTO test VALUES(?)", lambda: sqlite3.connect(temp_database)
        )
