Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python] Append columns #318

Merged
merged 5 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/include/lance/arrow/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ class LanceDataset : public ::arrow::dataset::Dataset {
///
/// \param new_field the new field / column to be updated.
/// \return a builder for `Updater`.
::arrow::Result<UpdaterBuilder> NewUpdate(const std::shared_ptr<::arrow::Field>& new_field) const;
::arrow::Result<std::shared_ptr<UpdaterBuilder>> NewUpdate(
const std::shared_ptr<::arrow::Field>& new_field) const;

::arrow::Result<std::shared_ptr<::arrow::dataset::Dataset>> ReplaceSchema(
std::shared_ptr<::arrow::Schema> schema) const override;
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/lance/arrow/dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,10 @@ ::arrow::Result<DatasetVersion> LanceDataset::latest_version() const {

DatasetVersion LanceDataset::version() const { return impl_->manifest->GetDatasetVersion(); }

::arrow::Result<UpdaterBuilder> LanceDataset::NewUpdate(
::arrow::Result<std::shared_ptr<UpdaterBuilder>> LanceDataset::NewUpdate(
const std::shared_ptr<::arrow::Field>& new_field) const {
return UpdaterBuilder{std::make_shared<LanceDataset>(*this), std::move(new_field)};
return std::make_shared<UpdaterBuilder>(std::make_shared<LanceDataset>(*this),
std::move(new_field));
}

::arrow::Result<std::shared_ptr<::arrow::dataset::Dataset>> LanceDataset::ReplaceSchema(
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/lance/arrow/updater_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ TEST_CASE("Use updater to update one column") {

auto updater = lance_dataset->NewUpdate(::arrow::field("strs", arrow::utf8()))
.ValueOrDie()
.Finish()
->Finish()
.ValueOrDie();
int cnt = 0;
while (true) {
Expand Down
78 changes: 77 additions & 1 deletion python/lance/_lib.pyx
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# distutils: language = c++

from typing import Optional, List, Dict
from typing import Callable, Optional, List, Dict
from pathlib import Path

import pyarrow
from cython.operator cimport dereference as deref
from libcpp cimport bool
from libcpp.memory cimport shared_ptr, static_pointer_cast
Expand All @@ -28,10 +29,17 @@ from pyarrow.includes.libarrow_dataset cimport (
)
from pyarrow.includes.libarrow_fs cimport CFileSystem
from pyarrow.lib cimport (
CArray,
CExpression,
CField,
CRecordBatch,
Field,
GetResultValue,
RecordBatchReader,
check_status,
pyarrow_wrap_batch,
pyarrow_unwrap_field,
pyarrow_unwrap_array,
)
from pyarrow.lib import tobytes
from pyarrow.util import _stringify_path
Expand Down Expand Up @@ -138,6 +146,50 @@ cdef class LanceFileFormat(FileFormat):
return LanceFileWriteOptions.wrap(self.format.DefaultWriteOptions())


cdef extern from "lance/arrow/updater.h" namespace "lance::arrow" nogil:
    # Declarations mirroring the C++ updater API so Cython can call into it.
    cdef cppclass CUpdater "::lance::arrow::Updater":
        # Fetch the next batch of existing data; a NULL batch signals exhaustion
        # (see __next__ below, which maps that to StopIteration).
        CResult[shared_ptr[CRecordBatch]] Next();

        # Supply the new column's values for the batch most recently returned by Next().
        CStatus UpdateBatch(const shared_ptr[CArray] arr);

        # Finalize all updates and return the resulting dataset.
        CResult[shared_ptr[CLanceDataset]] Finish();

    cdef cppclass CUpdaterBuilder "::lance::arrow::UpdaterBuilder":
        # Build the configured Updater instance.
        CResult[shared_ptr[CUpdater]] Finish();


cdef class Updater:
    # Python-level wrapper around the C++ ::lance::arrow::Updater.
    # NOTE(review): the class continues further down with __iter__/__next__.
    cdef shared_ptr[CUpdater] sp_updater

    @staticmethod
    cdef wrap(const shared_ptr[CUpdater]& up):
        # Wrap an existing C++ updater handle without running __init__.
        cdef Updater self = Updater.__new__(Updater)
        self.sp_updater = move(up)
        return self

    def update_batch(self, data: pyarrow.Array):
        # Push the new column's values for the current batch down to C++.
        # Releases the GIL for the native call; check_status raises on a non-OK status.
        cdef shared_ptr[CArray] arr = pyarrow_unwrap_array(data)
        with nogil:
            check_status(self.sp_updater.get().UpdateBatch(arr))

    def finish(self):
        # Commit all accumulated updates and return the new dataset,
        # upcast to CDataset so it can be wrapped as a FileSystemDataset.
        cdef shared_ptr[CLanceDataset] new_dataset = GetResultValue(
            self.sp_updater.get().Finish()
        )
        return FileSystemDataset.wrap(static_pointer_cast[CDataset, CLanceDataset](new_dataset))

def __iter__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need both iter and next?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean next() or __next__() here?

return self

def __next__(self) -> pyarrow.Table:
    # Return the next batch of existing data wrapped in a single-batch
    # pyarrow.Table; a NULL batch from C++ means the dataset is exhausted.
    cdef shared_ptr[CRecordBatch] c_batch
    c_batch = GetResultValue(self.sp_updater.get().Next())
    if c_batch.get() == NULL:
        raise StopIteration
    batch = pyarrow_wrap_batch(c_batch)
    return pyarrow.Table.from_batches([batch])


cdef extern from "lance/arrow/dataset.h" namespace "lance::arrow" nogil:
cdef cppclass CDatasetVersion "::lance::arrow::DatasetVersion":
uint64_t version() const;
Expand Down Expand Up @@ -167,6 +219,7 @@ cdef extern from "lance/arrow/dataset.h" namespace "lance::arrow" nogil:

CResult[vector[CDatasetVersion]] versions() const;

CResult[shared_ptr[CUpdaterBuilder]] NewUpdate(const shared_ptr[CField]& field) const;

cdef _dataset_version_to_json(CDatasetVersion cdv):
return {
Expand Down Expand Up @@ -220,6 +273,29 @@ cdef class FileSystemDataset(Dataset):
c_version = GetResultValue(self.lance_dataset.latest_version())
return _dataset_version_to_json(c_version)

def append_column(self, field: Field, func: Callable[[pyarrow.Table], pyarrow.Array]) -> FileSystemDataset:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User needs to know that func input is just a chunk of data. Eg they can't do something like normalization in this function unless they explicitly compute the mean / std globally.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k, will do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually a good argument of whether func should take in a RecordBatch instead of Table

"""Append a new column.

Parameters
----------
field : pyarrow.Field
The name and schema of the newly added column.
func : Callable[[pyarrow.Table], pyarrow.Array]
    A function / callback that takes one chunk of the dataset at a time, as a
    pyarrow.Table, and produces a pyarrow.Array of the same length as that chunk.
    Each call sees only a single chunk, so any global statistics (e.g., the
    mean / std needed for normalization) must be computed beforehand.
"""
cdef:
shared_ptr[CUpdater] c_updater
shared_ptr[CField] c_field

c_field = pyarrow_unwrap_field(field)
c_updater = move(GetResultValue(GetResultValue(move(self.lance_dataset.NewUpdate(c_field))).get().Finish()))
updater = Updater.wrap(c_updater)
for table in updater:
arr = func(table)
updater.update_batch(arr)
return updater.finish()

def _lance_dataset_write(
Dataset data,
object base_dir not None,
Expand Down
32 changes: 32 additions & 0 deletions python/lance/tests/test_schema_evolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2022. Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import lance
import pyarrow as pa
import pandas as pd


def test_write_versioned_dataset(tmp_path: Path):
    """Appending a column via append_column() produces a dataset containing the new data."""
    base_dir = tmp_path / "test"
    original = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}])
    lance.write_dataset(original, base_dir)

    def make_strings(chunk):
        # One "a<i>" string per row of the incoming chunk.
        return pa.array([f"a{i}" for i in range(len(chunk))])

    new_dataset = lance.dataset(base_dir).append_column(pa.field("c", pa.utf8()), make_strings)

    expected = pd.DataFrame({"a": [1, 10], "b": [2, 20], "c": ["a0", "a1"]})
    pd.testing.assert_frame_equal(expected, new_dataset.to_table().to_pandas())