diff --git a/hvac/v1/__init__.py b/hvac/v1/__init__.py index 6c14075a3..78b798bad 100644 --- a/hvac/v1/__init__.py +++ b/hvac/v1/__init__.py @@ -1,4 +1,7 @@ import os +import typing as t + +from warnings import warn from hvac import adapters, api, exceptions, utils from hvac.constants.client import ( @@ -18,6 +21,48 @@ has_hcl_parser = False +# TODO(v4.0.0): remove _sentinel and _smart_pop when write no longer has deprecated behavior: +# https://github.com/hvac/hvac/issues/1034 +_sentinel = object() + + +def _smart_pop( + dict: dict, + member: str, + default: t.Any = _sentinel, + *, + posvalue: t.Any = _sentinel, + method: str = "write", + replacement_method: str = "write_data", +): + try: + value = dict.pop(member) + except KeyError: + if posvalue is not _sentinel: + return posvalue + elif default is not _sentinel: + return default + else: + raise TypeError( + f"{method}() missing one required positional argument: '{member}'" + ) + else: + if posvalue is not _sentinel: + raise TypeError(f"{method}() got multiple values for argument '{member}'") + + warn( + ( + f"{method}() argument '{member}' was supplied as a keyword argument and will not be written as data." + f" To write this data with a '{member}' key, use the {replacement_method}() method." + f" To continue using {method}() and suppress this warning, supply this argument positionally." + f" For more information see: https://github.com/hvac/hvac/issues/1034" + ), + DeprecationWarning, + stacklevel=3, + ) + return value + + class Client: """The hvac Client class for HashiCorp's Vault.""" @@ -251,35 +296,72 @@ def list(self, path): except exceptions.InvalidPath: return None - def write(self, path, wrap_ttl=None, **kwargs): + # TODO(v4.0.0): remove overload when write doesn't use args and kwargs anymore + @t.overload + def write(self, path: str, wrap_ttl: t.Optional[str], **kwargs: t.Dict[str, t.Any]): + pass + + def write(self, *args: list, **kwargs: t.Dict[str, t.Any]): """POST / Write data to a path. Because this method uses kwargs for the data to write, "path" and "wrap_ttl" data keys cannot be used. If these names are needed, or if the key names are not known at design time, consider using the write_data method. :param path: - :type path: + :type path: str :param wrap_ttl: - :type wrap_ttl: + :type wrap_ttl: str | None :param kwargs: - :type kwargs: + :type kwargs: dict :return: :rtype: """ - return self._adapter.post(f"/v1/{path}", json=kwargs, wrap_ttl=wrap_ttl) - def write_data(self, path, *, data={}, wrap_ttl=None): + try: + path = args[0] + except IndexError: + path = _sentinel + + path = _smart_pop(kwargs, "path", posvalue=path) + + try: + wrap_ttl = args[1] + except IndexError: + wrap_ttl = _sentinel + + wrap_ttl = _smart_pop(kwargs, "wrap_ttl", default=None, posvalue=wrap_ttl) + + if "data" in kwargs: + warn( + ( + "write() argument 'data' was supplied as a keyword argument." + " In v3.0.0 the 'data' key will be treated specially. Consider using the write_data() method instead." + " For more information see: https://github.com/hvac/hvac/issues/1034" + ), + PendingDeprecationWarning, + stacklevel=2, + ) + + return self.write_data(path, wrap_ttl=wrap_ttl, data=kwargs) + + def write_data( + self, + path: str, + *, + data: t.Dict[str, t.Any] = {}, + wrap_ttl: t.Optional[str] = None, + ): """Write data to a path. Similar to write() without restrictions on data keys. Supported methods: POST / :param path: - :type path: + :type path: str :param data: - :type dict: + :type data: dict :param wrap_ttl: - :type wrap_ttl: + :type wrap_ttl: str | None :return: :rtype: """ diff --git a/poetry.lock b/poetry.lock index 8801bfa14..191e047c0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -999,6 +999,23 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +[[package]] +name = "pytest-mock" +version = "3.11.1" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-mock-3.11.1.tar.gz", hash = "sha256:7f6b125602ac6d743e523ae0bfa71e1a697a2f5534064528c6ff84c2f7c2fc7f"}, + {file = "pytest_mock-3.11.1-py3-none-any.whl", hash = "sha256:21c279fff83d70763b05f8874cc9cfb3fcacd6d354247a976f9529d19f9acf39"}, +] + +[package.dependencies] +pytest = ">=5.0" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "python-ldap-test" version = "0.3.1" @@ -1515,4 +1532,4 @@ parser = ["pyhcl"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "c3c4da7554ffd30640f2d805301f6d39106686af523a2b3652e3167f3da2b4dc" +content-hash = "9e4929bbf116a4f7b50561a7356414449eee106963bfd968d0314198f84a2b9c" diff --git a/pyproject.toml b/pyproject.toml index cde4a3868..a893296ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ jinja2 = "<3.1.0" greenlet = "^3.0.0" jwcrypto = "^1.5.0" typos = "^1.16.11" +pytest-mock = "^3.11.1" [tool.typos.default.extend-words] Hashi = "Hashi" diff --git a/tests/integration_tests/v1/test_system_backend.py b/tests/integration_tests/v1/test_system_backend.py index 0ac505356..358d6953e 100644 --- a/tests/integration_tests/v1/test_system_backend.py +++ b/tests/integration_tests/v1/test_system_backend.py @@ -70,7 +70,7 @@ def test_wrap_write(self): self.client.sys.enable_auth_method("approle") self.client.write("auth/approle/role/testrole") - result = self.client.write( + result = self.client.write_data( "auth/approle/role/testrole/secret-id", wrap_ttl="10s" ) self.assertIn("token", result["wrap_info"]) @@ -383,9 +383,9 @@ def test_tune_auth_backend(self): def test_read_lease(self): # Set up a test pki backend and issue a cert against some role so we. utils.configure_pki(client=self.client) - pki_issue_response = self.client.write( + pki_issue_response = self.client.write_data( path="pki/issue/my-role", - common_name="test.hvac.com", + data=dict(common_name="test.hvac.com"), ) # Read the lease of our test cert that was just issued. diff --git a/tests/unit_tests/v1/test_system_backend_methods.py b/tests/unit_tests/v1/test_system_backend_methods.py index 4576bfca0..c0519c28d 100644 --- a/tests/unit_tests/v1/test_system_backend_methods.py +++ b/tests/unit_tests/v1/test_system_backend_methods.py @@ -1,9 +1,131 @@ -from unittest import TestCase +import pytest +from pytest_mock import MockFixture +from mock import MagicMock +from unittest import TestCase import requests_mock from parameterized import parameterized from hvac import Client +from hvac.v1 import _sentinel, _smart_pop + + +class TestSmartPop: + def test_smart_pop_duplicate(self): + with pytest.raises(TypeError, match=r"got multiple values for argument"): + _smart_pop(dict(a=5), "a", posvalue=9) + + def test_smart_pop_missing(self): + with pytest.raises( + TypeError, match=r"missing one required positional argument" + ): + _smart_pop(dict(a=5), "z") + + @pytest.mark.parametrize("dict", [{}, {"a": 2}]) + @pytest.mark.parametrize("default", [_sentinel, "other"]) + def test_smart_pop_pos_only(self, default, dict, mocker: MockFixture): + result = _smart_pop( + dict, "z", default=default, posvalue=mocker.sentinel.pos_only + ) + assert result is mocker.sentinel.pos_only + assert "z" not in dict + + @pytest.mark.parametrize("dict", [{}, {"a": 2}]) + def test_smart_pop_default_only(self, dict, mocker: MockFixture): + result = _smart_pop(dict, "z", default=mocker.sentinel.default_only) + assert result is mocker.sentinel.default_only + assert "z" not in dict + + @pytest.mark.parametrize("dict", [{"a": 4, "b": 9}, {"a": 2}]) + def test_smart_pop_warns(self, dict): + original = dict.copy() + with pytest.warns( + DeprecationWarning, match=r"https://github.com/hvac/hvac/issues/1034" + ): + result = _smart_pop(dict, "a") + assert result == original["a"] + assert "a" not in dict + + +class TestClientWriteData: + test_url = "https://vault.example.com" + test_path = "whatever/fake" + response = dict(a=1, b="two") + + @pytest.fixture(autouse=True) + def write_mock(self, requests_mock: requests_mock.Mocker): + yield requests_mock.register_uri( + method="POST", + url=f"{self.test_url}/v1/{self.test_path}", + json=self.response, + ) + + @pytest.fixture + def client(self) -> Client: + return Client(url=self.test_url) + + @pytest.mark.parametrize("wrap_ttl", [None, "3m"]) + def test_write_data(self, client: Client, wrap_ttl: str): + response = client.write_data(self.test_path, data="cool", wrap_ttl=wrap_ttl) + assert response == self.response + + +class TestOldClientWrite: + test_url = "https://vault.example.com" + test_path = "whatever/fake" + + @pytest.fixture(autouse=True) + def mock_write_data(self, mocker: MockFixture) -> MagicMock: + yield mocker.patch.object(Client, "write_data") + + @pytest.fixture + def client(self) -> Client: + return Client(url=self.test_url) + + @pytest.mark.parametrize("kwargs", [{}, {"wrap_ttl": "5m"}, {"other": 5}]) + def test_client_write_no_path( + self, + client: Client, + mocker: MockFixture, + kwargs: dict, + mock_write_data: MagicMock, + ): + popper = mocker.patch("hvac.v1._smart_pop", new=mocker.Mock(wraps=_smart_pop)) + with pytest.raises(TypeError): + client.write(**kwargs) + popper.assert_called_once_with(mocker.ANY, "path", posvalue=_sentinel) + mock_write_data.assert_not_called() + + @pytest.mark.parametrize("kwargs", [{}, {"other": 5}]) + def test_client_write_no_wrap_ttl( + self, + client: Client, + mocker: MockFixture, + kwargs: dict, + mock_write_data: MagicMock, + ): + popper = mocker.patch("hvac.v1._smart_pop", new=mocker.Mock(wraps=_smart_pop)) + client.write(self.test_path, **kwargs) + assert popper.call_count == 2 + expected_call = mocker.call( + mocker.ANY, "wrap_ttl", default=None, posvalue=_sentinel + ) + popper.assert_has_calls([expected_call]) + mock_write_data.assert_called_once_with( + self.test_path, wrap_ttl=None, data=kwargs + ) + + def test_client_write_data_field( + self, client: Client, mocker: MockFixture, mock_write_data: MagicMock + ): + with pytest.warns( + PendingDeprecationWarning, + match=r"argument 'data' was supplied as a keyword argument", + ): + client.write(self.test_path, data="thing") + mock_write_data.assert_called_once_with( + self.test_path, wrap_ttl=None, data=dict(data="thing") + ) class TestSystemBackendMethods(TestCase): diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index c7b57a6bc..c0612e47a 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -253,22 +253,28 @@ def configure_pki( client.sys.enable_secrets_engine(backend_type="pki", path=mount_point) - client.write( + client.write_data( path=f"{mount_point}/root/generate/internal", - common_name=common_name, - ttl="8760h", + data=dict( + common_name=common_name, + ttl="8760h", + ), ) - client.write( + client.write_data( path=f"{mount_point}/config/urls", - issuing_certificates="http://127.0.0.1:8200/v1/pki/ca", - crl_distribution_points="http://127.0.0.1:8200/v1/pki/crl", + data=dict( + issuing_certificates="http://127.0.0.1:8200/v1/pki/ca", + crl_distribution_points="http://127.0.0.1:8200/v1/pki/crl", + ), ) - client.write( + client.write_data( path=f"{mount_point}/roles/{role_name}", - allowed_domains=common_name, - allow_subdomains=True, - generate_lease=True, - max_ttl="72h", + data=dict( + allowed_domains=common_name, + allow_subdomains=True, + generate_lease=True, + max_ttl="72h", + ), )