[KED-1475] Fixed bug: SparkDataSet fails reading from DBFS in Windows using Databricks connect (kedro-org#525)
andrii-ivaniuk committed Apr 6, 2020
1 parent 44fd0ec commit e996b07
Showing 3 changed files with 24 additions and 3 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
@@ -22,6 +22,7 @@
* Documented installation of development version of Kedro in the [FAQ section](https://kedro.readthedocs.io/en/stable/06_resources/01_faq.html#how-can-i-use-development-version-of-kedro).
* Implemented custom glob function for `SparkDataSet` when running on Databricks.
* Added the option for contributors to run Kedro tests locally without Spark installation with `make test-no-spark`.
+* Fixed a bug in `SparkDataSet` that prevented loading data from DBFS on a Windows machine using Databricks-connect.

## Breaking changes to the API
* Made `invalidate_cache` method on datasets private.
6 changes: 4 additions & 2 deletions kedro/extras/datasets/spark/spark_dataset.py
@@ -33,7 +33,7 @@
from copy import deepcopy
from fnmatch import fnmatch
from functools import partial
-from pathlib import PurePosixPath
+from pathlib import PurePath, PurePosixPath
from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

@@ -268,9 +268,11 @@ def __init__(  # pylint: disable=too-many-arguments
            path = PurePosixPath(filepath)

        else:
-           path = PurePosixPath(filepath)
+           path = PurePath(filepath)  # type: ignore

            if filepath.startswith("/dbfs"):
+               # Use PosixPath if the filepath references DBFS
+               path = PurePosixPath(filepath)
                dbutils = _get_dbutils(self._get_spark())
                if dbutils:
                    glob_function = partial(_dbfs_glob, dbutils=dbutils)
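The net effect of this hunk: the path class is now chosen from the filepath itself rather than hard-coded to POSIX. A path under "/dbfs" is always treated as POSIX, while any other path follows the host OS. A minimal, self-contained sketch of that selection logic (the helper name select_path and the sample paths are illustrative, not part of the commit):

from pathlib import PurePath, PurePosixPath


def select_path(filepath: str) -> PurePath:
    """Pick the path class the way the constructor above does."""
    if filepath.startswith("/dbfs"):
        # DBFS is a POSIX-style filesystem, whatever the client OS is
        return PurePosixPath(filepath)
    # PurePath() dispatches at call time: PureWindowsPath when os.name
    # is "nt", PurePosixPath everywhere else
    return PurePath(filepath)


# On Windows: PureWindowsPath, then PurePosixPath.
# On a POSIX system: PurePosixPath both times.
print(type(select_path(r"data\01_raw\iris.csv")).__name__)
print(type(select_path("/dbfs/mnt/data/iris.csv")).__name__)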
20 changes: 19 additions & 1 deletion tests/extras/datasets/spark/test_spark_dataset.py
@@ -28,7 +28,7 @@

# pylint: disable=import-error
import tempfile
-from pathlib import Path
+from pathlib import Path, PurePosixPath, PureWindowsPath

import pandas as pd
import pytest
@@ -542,6 +542,24 @@ def test_get_dbutils_no_modules(self, mocker):
        mocker.patch.dict("sys.modules", {})
        assert _get_dbutils("spark") is None

+    @pytest.mark.parametrize(
+        "os_name,path_class", [("nt", PureWindowsPath), ("posix", PurePosixPath)]
+    )
+    def test_regular_path_in_different_os(self, os_name, path_class, mocker):
+        """Check that the class of the filepath depends on the OS for a regular path."""
+        mocker.patch("os.name", os_name)
+        data_set = SparkDataSet(filepath="/some/path")
+        assert isinstance(data_set._filepath, path_class)
+
+    @pytest.mark.parametrize(
+        "os_name,path_class", [("nt", PurePosixPath), ("posix", PurePosixPath)]
+    )
+    def test_dbfs_path_in_different_os(self, os_name, path_class, mocker):
+        """Check that the class of the filepath doesn't depend on the OS when it references DBFS."""
+        mocker.patch("os.name", os_name)
+        data_set = SparkDataSet(filepath="/dbfs/some/path")
+        assert isinstance(data_set._filepath, path_class)


class TestSparkDataSetVersionedS3:
    def test_no_version(self, versioned_dataset_s3):
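Both tests lean on the fact that pathlib.PurePath picks its concrete class from os.name at call time, so patching os.name with mocker is enough to simulate either platform. A short standalone check of that dispatch (plain CPython behaviour, not code from the commit):

import os
from pathlib import PurePath, PurePosixPath, PureWindowsPath

# PurePath() reads os.name when called: PureWindowsPath on "nt",
# PurePosixPath on anything else.
expected = PureWindowsPath if os.name == "nt" else PurePosixPath
assert isinstance(PurePath("/some/path"), expected)

# A "/dbfs" path, by contrast, goes through the explicit PurePosixPath
# branch in SparkDataSet, so its class never depends on the OS -- which
# is exactly what test_dbfs_path_in_different_os pins down.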
