[SPARK-47137][PYTHON][CONNECT] Add getAll to spark.conf for feature parity with Scala

### What changes were proposed in this pull request?

Adds `getAll` to `spark.conf` for feature parity with Scala.

```py
>>> spark.conf.getAll
{'spark.sql.warehouse.dir': ...}
```

### Why are the changes needed?

The Scala API provides `spark.conf.getAll`, whereas the Python API doesn't.

```scala
scala> spark.conf.getAll
val res0: Map[String,String] = HashMap(spark.sql.warehouse.dir -> ...
```

### Does this PR introduce _any_ user-facing change?

Yes, `spark.conf.getAll` will be available in PySpark.
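
As a usage sketch (hedged: assumes a running `SparkSession`; the `spark.sql` prefix filter is only an example):

```py
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# `getAll` is a property, not a method, so it is accessed without parentheses.
confs = spark.conf.getAll
for key, value in sorted(confs.items()):
    if key.startswith("spark.sql"):
        print(key, "=", value)
```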

### How was this patch tested?

Added the related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#45222 from ueshin/issues/SPARK-47137/getAll.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
ueshin authored and dongjoon-hyun committed Feb 23, 2024
1 parent b90514c commit 511839b
Showing 4 changed files with 75 additions and 25 deletions.
16 changes: 15 additions & 1 deletion python/pyspark/sql/conf.py
@@ -16,7 +16,7 @@
#

import sys
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union

from py4j.java_gateway import JavaObject

@@ -93,6 +93,20 @@ def get(
        self._check_type(default, "default")
        return self._jconf.get(key, default)

+    @property
+    def getAll(self) -> Dict[str, str]:
+        """
+        Returns all properties set in this conf.
+
+        .. versionadded:: 4.0.0
+
+        Returns
+        -------
+        dict
+            A dictionary containing all properties set in this conf.
+        """
+        return dict(self._jconf.getAllAsJava())
+
    def unset(self, key: str) -> None:
        """
        Resets the configuration property for the given key.
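
A behavioral note on the classic (py4j) implementation above: `dict(self._jconf.getAllAsJava())` copies the Java map into a plain Python `dict`, so the result is a snapshot rather than a live view. A minimal sketch (the `foo` key is hypothetical, for illustration only):

```py
before = spark.conf.getAll      # plain dict: a snapshot, not a live view
spark.conf.set("foo", "bar")    # hypothetical key, for illustration
after = spark.conf.getAll       # a fresh snapshot reflects the new setting

assert "foo" not in before
assert after["foo"] == "bar"
spark.conf.unset("foo")         # clean up
```
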
15 changes: 14 additions & 1 deletion python/pyspark/sql/connect/conf.py
@@ -19,7 +19,7 @@

check_dependencies(__name__)

-from typing import Any, Optional, Union, cast
+from typing import Any, Dict, Optional, Union, cast
import warnings

from pyspark import _NoValue
@@ -68,6 +68,19 @@ def get(

    get.__doc__ = PySparkRuntimeConfig.get.__doc__

+    @property
+    def getAll(self) -> Dict[str, str]:
+        op_get_all = proto.ConfigRequest.GetAll()
+        operation = proto.ConfigRequest.Operation(get_all=op_get_all)
+        result = self._client.config(operation)
+        confs: Dict[str, str] = dict()
+        for key, value in result.pairs:
+            assert value is not None
+            confs[key] = value
+        return confs
+
+    getAll.__doc__ = PySparkRuntimeConfig.getAll.__doc__
+
    def unset(self, key: str) -> None:
        op_unset = proto.ConfigRequest.Unset(keys=[key])
        operation = proto.ConfigRequest.Operation(unset=op_unset)
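
On the Spark Connect side, each access of the `getAll` property issues a `ConfigRequest` round trip to the server, so code that consults the map repeatedly may want to fetch it once. A hedged sketch of that pattern:

```py
# One RPC on Spark Connect; reuse the local dict instead of
# touching the property inside a loop.
confs = spark.conf.getAll
warehouse = confs.get("spark.sql.warehouse.dir")
```
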
63 changes: 40 additions & 23 deletions python/pyspark/sql/tests/test_conf.py
@@ -50,32 +50,49 @@ def test_conf(self):
    def test_conf_with_python_objects(self):
        spark = self.spark

-        for value, expected in [(True, "true"), (False, "false")]:
-            spark.conf.set("foo", value)
-            self.assertEqual(spark.conf.get("foo"), expected)
-
-        spark.conf.set("foo", 1)
-        self.assertEqual(spark.conf.get("foo"), "1")
-
-        with self.assertRaises(IllegalArgumentException):
-            spark.conf.set("foo", None)
-
-        with self.assertRaises(Exception):
-            spark.conf.set("foo", Decimal(1))
-
-        with self.assertRaises(PySparkTypeError) as pe:
-            spark.conf.get(123)
-
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_STR",
-            message_parameters={
-                "arg_name": "key",
-                "arg_type": "int",
-            },
-        )
-
-        spark.conf.unset("foo")
+        try:
+            for value, expected in [(True, "true"), (False, "false")]:
+                spark.conf.set("foo", value)
+                self.assertEqual(spark.conf.get("foo"), expected)
+
+            spark.conf.set("foo", 1)
+            self.assertEqual(spark.conf.get("foo"), "1")
+
+            with self.assertRaises(IllegalArgumentException):
+                spark.conf.set("foo", None)
+
+            with self.assertRaises(Exception):
+                spark.conf.set("foo", Decimal(1))
+
+            with self.assertRaises(PySparkTypeError) as pe:
+                spark.conf.get(123)
+
+            self.check_error(
+                exception=pe.exception,
+                error_class="NOT_STR",
+                message_parameters={
+                    "arg_name": "key",
+                    "arg_type": "int",
+                },
+            )
+        finally:
+            spark.conf.unset("foo")
+
+    def test_get_all(self):
+        spark = self.spark
+        all_confs = spark.conf.getAll
+
+        self.assertTrue(len(all_confs) > 0)
+        self.assertNotIn("foo", all_confs)
+
+        try:
+            spark.conf.set("foo", "bar")
+            updated = spark.conf.getAll
+
+            self.assertEquals(len(updated), len(all_confs) + 1)
+            self.assertIn("foo", updated)
+        finally:
+            spark.conf.unset("foo")


class ConfTests(ConfTestsMixin, ReusedSQLTestCase):
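
The reworked tests wrap mutations of shared session state in `try`/`finally` so the test key is unset even if an assertion fails midway, keeping the reused session clean for later tests. The idiom in isolation (a sketch; `foo` is again a hypothetical key):

```py
try:
    spark.conf.set("foo", "bar")
    assert spark.conf.get("foo") == "bar"
finally:
    spark.conf.unset("foo")  # runs even if the assertion above fails
```
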
6 changes: 6 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql

+import scala.jdk.CollectionConverters._
+
import org.apache.spark.SPARK_DOC_ROOT
import org.apache.spark.annotation.Stable
import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry}
@@ -118,6 +120,10 @@ class RuntimeConfig private[sql](val sqlConf: SQLConf = new SQLConf) {
    sqlConf.getAllConfs
  }

+  private[sql] def getAllAsJava: java.util.Map[String, String] = {
+    getAll.asJava
+  }
+
  /**
   * Returns the value of Spark runtime configuration property for the given key.
   *
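
The `private[sql]` `getAllAsJava` helper gives the py4j bridge a `java.util.Map` to hand across, which the Python side can pass straight to `dict(...)`; Scala's immutable `Map` has no such py4j-friendly shape. A quick parity check from Python, assuming a classic (non-Connect) session:

```py
# The Python dict mirrors what Scala's spark.conf.getAll exposes.
confs = spark.conf.getAll
assert isinstance(confs, dict)
assert "spark.sql.warehouse.dir" in confs  # shown in the PR description's output
```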
