Skip to content

Commit

Permalink
Add "x-fauna-tags" and "traceparent" headers on queries to Fauna (#256)
Browse files Browse the repository at this point in the history
* Add optional params to client.query function

* Add tests

* Regex and test tweaks

* More regex tweaks

* Didn't like those regexes

* Try the compat versions

* String interpolation in 3.5

* Missed one

* More compatible assert

* Comment

* Aligning indents is super-important

* Remove client-side validation and fix tests

* Version and changelog
  • Loading branch information
adambollen committed Feb 13, 2023
1 parent 65b9df0 commit 94f35ec
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 7 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
## 4.3.1 [current]
## 4.4.0 [current]
- Add `tags` and `traceparent` headers. [#256](https://github.com/fauna/faunadb-python/pull/256)

## 4.3.1
- Fix the X-Last-Seen-Txn header. [#250](https://github.com/fauna/faunadb-python/pull/250)
- Fix the changelog link emitted from upgrade prompt. [#249])(https://github.com/fauna/faunadb-python/pull/249)

Expand Down
4 changes: 2 additions & 2 deletions faunadb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__title__ = "FaunaDB"
__version__ = "4.3.1"
__version__ = "4.4.0"
__api_version__ = "4"
__author__ = "Fauna, Inc"
__license__ = "MPL 2.0"
__copyright__ = "2020 Fauna, Inc"
__copyright__ = "2023 Fauna, Inc"
22 changes: 19 additions & 3 deletions faunadb/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import platform
import re
import sys
import threading
# pylint: disable=redefined-builtin
Expand Down Expand Up @@ -289,15 +290,15 @@ def __del__(self):
if self.counter.decrement() == 0:
self.session.close()

def query(self, expression, timeout_millis=None):
def query(self, expression, timeout_millis=None, tags=None, traceparent=None):
"""
Use the FaunaDB query API.
:param expression: A query. See :doc:`query` for information on queries.
:param timeout_millis: Query timeout in milliseconds.
:return: Converted JSON response.
"""
return self._execute("POST", "", _wrap(expression), with_txn_time=True, query_timeout_ms=timeout_millis)
return self._execute("POST", "", _wrap(expression), with_txn_time=True, query_timeout_ms=timeout_millis, tags=tags, traceparent=traceparent)

def stream(self, expression, options=None, on_start=None, on_error=None, on_version=None, on_history=None, on_set=None):
"""
Expand Down Expand Up @@ -359,7 +360,7 @@ def new_session_client(self, secret, observer=None):
raise UnexpectedError(
"Cannnot create a session client from a closed session", None)

def _execute(self, action, path, data=None, query=None, with_txn_time=False, query_timeout_ms=None):
def _execute(self, action, path, data=None, query=None, with_txn_time=False, query_timeout_ms=None, tags=None, traceparent=None):
"""Performs an HTTP action, logs it, and looks for errors."""
if query is not None:
query = {k: v for k, v in query.items() if v is not None}
Expand All @@ -372,6 +373,12 @@ def _execute(self, action, path, data=None, query=None, with_txn_time=False, que
if with_txn_time:
headers.update(self._last_txn_time.request_header)

if tags is not None:
headers["x-fauna-tags"] = self._get_tags_string(tags)

if traceparent is not None and self._is_valid_traceparent(traceparent):
headers["traceparent"] = traceparent

start_time = time()
response = self._perform_request(action, path, data, query, headers)
end_time = time()
Expand Down Expand Up @@ -404,3 +411,12 @@ def _perform_request(self, action, path, data, query, headers):
req = Request(action, url, params=query, data=to_json(
data), auth=self.auth, headers=headers)
return self.session.send(self.session.prepare_request(req))

def _get_tags_string(self, tags_dict):
if not isinstance(tags_dict, dict):
raise Exception("Tags must be a dictionary")

return ",".join(["=".join([k, tags_dict[k]]) for k in tags_dict])

def _is_valid_traceparent(self, traceparent):
return bool(re.match("^[\da-f]{2}-[\da-f]{32}-[\da-f]{16}-[\da-f]{2}$", traceparent))
74 changes: 73 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import sys
import os
import platform
import random
import string
from faunadb.client import FaunaClient
from faunadb.errors import UnexpectedError
from faunadb.errors import UnexpectedError, BadRequest
from tests.helpers import FaunaTestCase
from faunadb import __version__ as pkg_version, __api_version__ as api_version

Expand Down Expand Up @@ -80,3 +82,73 @@ def test_recognized_runtime_env_headers(self):
)

os.environ["PATH"] = originalPath

def test_tags_header(self):
self.client.observer = lambda rr: self.assertCountEqual(rr.response_headers["x-fauna-tags"].split(","), ["foo=bar", "baz=biz"])
test_tags = {
"foo": "bar",
"baz": "biz",
}
self.client.query({}, tags=test_tags)

self.client.observer = None

def test_invalid_tags_keys(self):
invalid_keys = [
"foo bar",
"foo*bar",
''.join(random.choice(string.ascii_lowercase) for _ in range(41)),
]
for key in invalid_keys:
self.assertRaisesRegexCompat(BadRequest,
"invalid (tags|key)",
lambda: self.client.query({}, tags={ key: "value" }))

def test_invalid_tags_values(self):
invalid_values = [
"foo bar",
"foo*bar",
''.join(random.choice(string.ascii_lowercase) for _ in range(81)),
]
for value in invalid_values:
self.assertRaisesRegexCompat(BadRequest,
"invalid (tags|value)",
lambda: self.client.query({}, tags={ "key": value }))

def test_too_many_tags(self):
too_many_keys = [ (''.join(random.choice(string.ascii_lowercase) for _ in range(10))) for _ in range(30) ]
too_many_tags = { k: "value" for k in too_many_keys }
self.assertRaisesRegexCompat(BadRequest,
"too many tags",
lambda: self.client.query({}, tags=too_many_tags))

def test_traceparent_header(self):
token = ''.join(random.choice(string.hexdigits.lower()) for _ in range(32))
token2 = ''.join(random.choice(string.hexdigits.lower()) for _ in range(16))
req_tp = "00-%s-%s-01"%(token, token2)
self.client.observer = lambda rr: self.assertRegexCompat(rr.response_headers["traceparent"], "^00-%s-\w{16}-\d{2}$"%(token))
self.client.query({}, traceparent=req_tp)

self.client.observer = None

def test_invalid_traceparent_header(self):
self.client.observer = lambda rr: self.assertIsNotNone(rr.response_headers["traceparent"]) and not self.assertRegexCompat(".*foo.*", rr.response_headers["traceparent"])
self.client.query({}, traceparent="foo")

def test_empty_traceparent_header(self):
tp_header = None
tp_part = None

def _test_and_save_traceparent(rr):
self.assertIsNotNone(rr.response_headers["traceparent"])
nonlocal tp_header, tp_part
tp_header = rr.response_headers["traceparent"]
tp_part = tp_header.split('-')[1]

self.client.observer = _test_and_save_traceparent
self.client.query({}, traceparent=None)

self.client.observer = lambda rr: self.assertRegexCompat(rr.response_headers["traceparent"], "^00-%s-\w{16}-\d{2}$"%(tp_part))
self.client.query({}, traceparent=tp_header)

self.client.observer = None

0 comments on commit 94f35ec

Please sign in to comment.