Skip to content

Commit

Permalink
Firestore: add support for CollectionGroup queries. (apache#7758)
Browse files Browse the repository at this point in the history
* Initial plumbing for collection group queries

* Don't assume direct children for collection group queries.

* trim document path to DocumentReference

* add unit tests

* ensure all_descendants is set after calling other query methods

* port test for node impl

* port tests from node impl

* Fix collection group test on Python 2.7.

Blacken.

* Use '_all_descendants' in 'Query.__eq__'.

* Ensure '_all_descendants' propagates when narrowing query.

* Refactor collection group system tests.

Skip the one for 'where', because it requires a custom index.

* Match node test's collection group ID / expected output.

See:
https://github.com/googleapis/nodejs-firestore/pull/578/files#diff-6b8febc8d51ea01205628091b3611eacR1188

* Match Node test for filter on '__name__'.

Note that this still doesn't pass, so remains skipped.

* Blacken.

* Fix / unskip systest for collection groups w/ filter on '__name__'.

* Blacken

* 100% coverage.

* Lint
  • Loading branch information
crwilcox committed Apr 30, 2019
1 parent 2012f59 commit 06df3ad
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 39 deletions.
33 changes: 33 additions & 0 deletions google/cloud/firestore_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.cloud.client import ClientWithProject

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1 import query
from google.cloud.firestore_v1 import types
from google.cloud.firestore_v1.batch import WriteBatch
from google.cloud.firestore_v1.collection import CollectionReference
Expand Down Expand Up @@ -179,6 +180,31 @@ def collection(self, *collection_path):

return CollectionReference(*path, client=self)

def collection_group(self, collection_id):
"""
Creates and returns a new Query that includes all documents in the
database that are contained in a collection or subcollection with the
given collection_id.
.. code-block:: python
>>> query = firestore.collection_group('mygroup')
@param {string} collectionId Identifies the collections to query over.
Every collection or subcollection with this ID as the last segment of its
path will be included. Cannot contain a slash.
@returns {Query} The created Query.
"""
if "/" in collection_id:
raise ValueError(
"Invalid collection_id "
+ collection_id
+ ". Collection IDs must not contain '/'."
)

collection = self.collection(collection_id)
return query.Query(collection, all_descendants=True)

def document(self, *document_path):
"""Get a reference to a document in a collection.
Expand Down Expand Up @@ -215,6 +241,13 @@ def document(self, *document_path):
else:
path = document_path

# DocumentReference takes a relative path. Strip the database string if present.
base_path = self._database_string + "/documents/"
joined_path = _helpers.DOCUMENT_PATH_DELIMITER.join(path)
if joined_path.startswith(base_path):
joined_path = joined_path[len(base_path) :]
path = joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER)

return DocumentReference(*path, client=self)

@staticmethod
Expand Down
55 changes: 51 additions & 4 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ class Query(object):
any matching documents will be included in the result set.
When the query is formed, the document values
will be used in the order given by ``orders``.
all_descendants (Optional[bool]): When false, selects only collections
that are immediate children of the `parent` specified in the
containing `RunQueryRequest`. When true, selects all descendant
collections.
"""

ASCENDING = "ASCENDING"
Expand All @@ -128,6 +132,7 @@ def __init__(
offset=None,
start_at=None,
end_at=None,
all_descendants=False,
):
self._parent = parent
self._projection = projection
Expand All @@ -137,6 +142,7 @@ def __init__(
self._offset = offset
self._start_at = start_at
self._end_at = end_at
self._all_descendants = all_descendants

def __eq__(self, other):
if not isinstance(other, self.__class__):
Expand All @@ -150,6 +156,7 @@ def __eq__(self, other):
and self._offset == other._offset
and self._start_at == other._start_at
and self._end_at == other._end_at
and self._all_descendants == other._all_descendants
)

@property
Expand Down Expand Up @@ -203,6 +210,7 @@ def select(self, field_paths):
offset=self._offset,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

def where(self, field_path, op_string, value):
Expand Down Expand Up @@ -270,6 +278,7 @@ def where(self, field_path, op_string, value):
offset=self._offset,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

@staticmethod
Expand Down Expand Up @@ -321,6 +330,7 @@ def order_by(self, field_path, direction=ASCENDING):
offset=self._offset,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

def limit(self, count):
Expand All @@ -346,6 +356,7 @@ def limit(self, count):
offset=self._offset,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

def offset(self, num_to_skip):
Expand All @@ -372,6 +383,7 @@ def offset(self, num_to_skip):
offset=num_to_skip,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

def _cursor_helper(self, document_fields, before, start):
Expand Down Expand Up @@ -418,6 +430,7 @@ def _cursor_helper(self, document_fields, before, start):
"orders": self._orders,
"limit": self._limit,
"offset": self._offset,
"all_descendants": self._all_descendants,
}
if start:
query_kwargs["start_at"] = cursor_pair
Expand Down Expand Up @@ -679,7 +692,7 @@ def _to_protobuf(self):
"select": projection,
"from": [
query_pb2.StructuredQuery.CollectionSelector(
collection_id=self._parent.id
collection_id=self._parent.id, all_descendants=self._all_descendants
)
],
"where": self._filters_pb(),
Expand Down Expand Up @@ -739,9 +752,14 @@ def stream(self, transaction=None):
)

for response in response_iterator:
snapshot = _query_response_to_snapshot(
response, self._parent, expected_prefix
)
if self._all_descendants:
snapshot = _collection_group_query_response_to_snapshot(
response, self._parent
)
else:
snapshot = _query_response_to_snapshot(
response, self._parent, expected_prefix
)
if snapshot is not None:
yield snapshot

Expand Down Expand Up @@ -968,3 +986,32 @@ def _query_response_to_snapshot(response_pb, collection, expected_prefix):
update_time=response_pb.document.update_time,
)
return snapshot


def _collection_group_query_response_to_snapshot(response_pb, collection):
"""Parse a query response protobuf to a document snapshot.
Args:
response_pb (google.cloud.proto.firestore.v1.\
firestore_pb2.RunQueryResponse): A
collection (~.firestore_v1.collection.CollectionReference): A
reference to the collection that initiated the query.
Returns:
Optional[~.firestore.document.DocumentSnapshot]: A
snapshot of the data returned in the query. If ``response_pb.document``
is not set, the snapshot will be :data:`None`.
"""
if not response_pb.HasField("document"):
return None
reference = collection._client.document(response_pb.document.name)
data = _helpers.decode_dict(response_pb.document.fields, collection._client)
snapshot = document.DocumentSnapshot(
reference,
data,
exists=True,
read_time=response_pb.read_time,
create_time=response_pb.document.create_time,
update_time=response_pb.document.update_time,
)
return snapshot
114 changes: 114 additions & 0 deletions tests/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,120 @@ def test_query_unary(client, cleanup):
assert math.isnan(data1[field_name])


def test_collection_group_queries(client, cleanup):
collection_group = "b" + unique_resource_id("-")

doc_paths = [
"abc/123/" + collection_group + "/cg-doc1",
"abc/123/" + collection_group + "/cg-doc2",
collection_group + "/cg-doc3",
collection_group + "/cg-doc4",
"def/456/" + collection_group + "/cg-doc5",
collection_group + "/virtual-doc/nested-coll/not-cg-doc",
"x" + collection_group + "/not-cg-doc",
collection_group + "x/not-cg-doc",
"abc/123/" + collection_group + "x/not-cg-doc",
"abc/123/x" + collection_group + "/not-cg-doc",
"abc/" + collection_group,
]

batch = client.batch()
for doc_path in doc_paths:
doc_ref = client.document(doc_path)
batch.set(doc_ref, {"x": 1})

batch.commit()

query = client.collection_group(collection_group)
snapshots = list(query.stream())
found = [snapshot.id for snapshot in snapshots]
expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"]
assert found == expected


def test_collection_group_queries_startat_endat(client, cleanup):
collection_group = "b" + unique_resource_id("-")

doc_paths = [
"a/a/" + collection_group + "/cg-doc1",
"a/b/a/b/" + collection_group + "/cg-doc2",
"a/b/" + collection_group + "/cg-doc3",
"a/b/c/d/" + collection_group + "/cg-doc4",
"a/c/" + collection_group + "/cg-doc5",
collection_group + "/cg-doc6",
"a/b/nope/nope",
]

batch = client.batch()
for doc_path in doc_paths:
doc_ref = client.document(doc_path)
batch.set(doc_ref, {"x": doc_path})

batch.commit()

query = (
client.collection_group(collection_group)
.order_by("__name__")
.start_at([client.document("a/b")])
.end_at([client.document("a/b0")])
)
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])

query = (
client.collection_group(collection_group)
.order_by("__name__")
.start_after([client.document("a/b")])
.end_before([client.document("a/b/" + collection_group + "/cg-doc3")])
)
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2"])


def test_collection_group_queries_filters(client, cleanup):
collection_group = "b" + unique_resource_id("-")

doc_paths = [
"a/a/" + collection_group + "/cg-doc1",
"a/b/a/b/" + collection_group + "/cg-doc2",
"a/b/" + collection_group + "/cg-doc3",
"a/b/c/d/" + collection_group + "/cg-doc4",
"a/c/" + collection_group + "/cg-doc5",
collection_group + "/cg-doc6",
"a/b/nope/nope",
]

batch = client.batch()

for index, doc_path in enumerate(doc_paths):
doc_ref = client.document(doc_path)
batch.set(doc_ref, {"x": index})

batch.commit()

query = (
client.collection_group(collection_group)
.where("__name__", ">=", client.document("a/b"))
.where("__name__", "<=", client.document("a/b0"))
)
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])

query = (
client.collection_group(collection_group)
.where("__name__", ">", client.document("a/b"))
.where(
"__name__", "<", client.document("a/b/{}/cg-doc3".format(collection_group))
)
)
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2"])


def test_get_all(client, cleanup):
collection_name = "get-all" + unique_resource_id("-")

Expand Down
30 changes: 29 additions & 1 deletion tests/unit/v1/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,21 @@ def test_collection_factory_nested(self):
self.assertIs(collection2._client, client)
self.assertIsInstance(collection2, CollectionReference)

def test_collection_group(self):
client = self._make_default_one()
query = client.collection_group("collectionId").where("foo", "==", u"bar")

assert query._all_descendants
assert query._field_filters[0].field.field_path == "foo"
assert query._field_filters[0].value.string_value == u"bar"
assert query._field_filters[0].op == query._field_filters[0].EQUAL
assert query._parent.id == "collectionId"

def test_collection_group_no_slashes(self):
client = self._make_default_one()
with self.assertRaises(ValueError):
client.collection_group("foo/bar")

def test_document_factory(self):
from google.cloud.firestore_v1.document import DocumentReference

Expand All @@ -148,7 +163,20 @@ def test_document_factory(self):
self.assertIs(document2._client, client)
self.assertIsInstance(document2, DocumentReference)

def test_document_factory_nested(self):
def test_document_factory_w_absolute_path(self):
from google.cloud.firestore_v1.document import DocumentReference

parts = ("rooms", "roomA")
client = self._make_default_one()
doc_path = "/".join(parts)
to_match = client.document(doc_path)
document1 = client.document(to_match._document_path)

self.assertEqual(document1._path, parts)
self.assertIs(document1._client, client)
self.assertIsInstance(document1, DocumentReference)

def test_document_factory_w_nested_path(self):
from google.cloud.firestore_v1.document import DocumentReference

client = self._make_default_one()
Expand Down
Loading

0 comments on commit 06df3ad

Please sign in to comment.