Skip to content

Commit

Permalink
add pydocstring linting
Browse files Browse the repository at this point in the history
  • Loading branch information
LGro committed May 26, 2022
1 parent 0ede1e0 commit bd66ab4
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 47 deletions.
2 changes: 2 additions & 0 deletions apsi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
"""Python wrapper for labeled and unlabeled APSI."""

from .clients import UnlabeledClient, LabeledClient
from .servers import UnlabeledServer, LabeledServer
44 changes: 28 additions & 16 deletions apsi/clients.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""(Un-)labeled APSI client implementations."""

from typing import Dict, List

from _pyapsi import APSIClient as _Client
Expand All @@ -7,16 +9,18 @@ class _BaseClient(_Client):
queried_items: List[str]

def oprf_request(self, item: str) -> bytes:
"""Create an OPRF request for a given item. This is the first step when querying
a server for items.
"""Create an OPRF request for a given item.
This is the first step when querying a server for items.
"""
# TODO: Switch to a request multi with items: List[str]
self.queried_items = (item,)
return self._oprf_request(item)

def build_query(self, oprf_response: bytes) -> bytes:
"""Build a query based on the server's response to an initial OPRF request. This
is the second step when querying for items.
"""Build a query based on the server's response to an initial OPRF request.
This is the second step when querying for items.
"""
if not self.queried_items:
raise RuntimeError("You need to create an OPRF request first.")
Expand All @@ -25,13 +29,16 @@ def build_query(self, oprf_response: bytes) -> bytes:


class LabeledClient(_BaseClient):
def __init__(self, params_json: str):
"""A client for labeled asynchronous private set intersection (APSI).
"""A client for labeled asynchronous private set intersection (APSI).
For a complete query, use the client interface in the following order:
1. `oprf_request`
2. `build_query`
3. `extract_result`
"""

For a complete query, use the client interface in the following order:
1. `oprf_request`
2. `build_query`
3. `extract_result`
def __init__(self, params_json: str):
"""Initialize a client for labeled APSI.
Args:
params_json: The JSON string representation of APSI/SEAL parameters
Expand All @@ -40,6 +47,7 @@ def __init__(self, params_json: str):

def extract_result(self, query_response: bytes) -> Dict[str, str]:
"""Extract the resulting item, label pairs from the server's query response.
This is the final step when querying for items.
Returns:
Expand All @@ -53,13 +61,16 @@ def extract_result(self, query_response: bytes) -> Dict[str, str]:


class UnlabeledClient(_BaseClient):
def __init__(self, params_json: str):
"""A client for unlabeled asynchronous private set intersection (APSI).
"""A client for unlabeled asynchronous private set intersection (APSI).
For a complete query, use the client interface in the following order:
1. `oprf_request`
2. `build_query`
3. `extract_result`
For a complete query, use the client interface in the following order:
1. `oprf_request`
2. `build_query`
3. `extract_result`
"""

def __init__(self, params_json: str):
"""Initialize a client for unlabeled APSI.
Args:
params_json: The JSON string representation of APSI/SEAL parameters
Expand All @@ -68,6 +79,7 @@ def __init__(self, params_json: str):

def extract_result(self, query_response: bytes) -> List[str]:
"""Extract the matched items from the server's query response.
This is the final step when querying for items.
"""
matches = super()._extract_unlabeled_result_from_query_response(query_response)
Expand Down
62 changes: 34 additions & 28 deletions apsi/servers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""(Un-)labeled APSI server implementations."""

from typing import Iterable, Tuple

from _pyapsi import APSIServer as _Server
Expand All @@ -15,9 +17,7 @@ def _requires_db(self):
raise RuntimeError("Please initialize or load a database first.")

def save_db(self, db_file_path: str) -> None:
"""Save the database in unencrypted binary representation at the given file
path.
"""
"""Save the database in unencrypted binary representation at the given path."""
self._requires_db()
self._save_db(db_file_path)

Expand All @@ -27,27 +27,32 @@ def load_db(self, db_file_path: str) -> None:
self.db_initialized = True

def handle_oprf_request(self, oprf_request: bytes) -> bytes:
"""Handle an initial APSI Client OPRF request and return a compatible bytes
response that can be used by the client to create the main query.
"""Handle an initial APSI Client OPRF request.
The returned bytes response can be used by the client to create the main query.
"""
self._requires_db()
return self._handle_oprf_request(oprf_request)

def handle_query(self, query: bytes) -> bytes:
"""Handle an APSI Client query following up on an initial OPRF request and
return the encrypted query response in an APSI Client compatible byte string.
"""Handle an APSI Client query.
This step follows after an initial OPRF request and returns the encrypted query
response in an APSI Client compatible byte string.
"""
self._requires_db()
return self._handle_query(query)


class LabeledServer(_BaseServer):
def __init__(self):
"""A server for labeled asynchronous private set intersection (APSI).
"""A server for labeled asynchronous private set intersection (APSI).
For this server to do something meaningful, initialize an empty database with
`init_db` or load an existing one with `load_db`.
"""
For this server to do something meaningful, initialize an empty database with
`init_db` or load an existing one with `load_db`.
"""

def __init__(self):
"""Initialize a labeled APSI server."""
super().__init__()

def init_db(
Expand All @@ -72,16 +77,19 @@ def init_db(
self.db_initialized = True

def add_item(self, item: str, label: str) -> None:
"""Add an item with a label to the server's database so that the item can be
queried by a client to learn about the label. If one considers this a key value
store, the item is the key and the value is the label.
"""Add an item with a label to the server.
The item can then be queried by a client to learn about the label. If one
considers this a key value store, the item is the key and the value is the
label.
"""
self._requires_db()
self._add_item(item, label)

def add_items(self, items_with_label: Iterable[Tuple[str, str]]) -> None:
"""Add multiple pairs of item and corresponding label to the server's database
so that any of the items can be queried by a client to learn about the matching
"""Add multiple pairs of item and corresponding label to the server.
Any of the items can then be queried by a client to learn about the matching
label. If one considers this a key value store, the item is the key and the
value is the label.
"""
Expand All @@ -92,12 +100,14 @@ def add_items(self, items_with_label: Iterable[Tuple[str, str]]) -> None:


class UnlabeledServer(_BaseServer):
def __init__(self):
"""A server for unlabeled asynchronous private set intersection (APSI).
"""A server for unlabeled asynchronous private set intersection (APSI).
For this server to do something meaningful, initialize an empty database with
`init_db` or load an existing one with `load_db`.
"""
For this server to do something meaningful, initialize an empty database with
`init_db` or load an existing one with `load_db`.
"""

def __init__(self):
"""Initialize an unlabled APSI server."""
super().__init__()

def init_db(self, params_json: str, compressed: bool = False) -> None:
Expand All @@ -112,16 +122,12 @@ def init_db(self, params_json: str, compressed: bool = False) -> None:
self.db_initialized = True

def add_item(self, item: str) -> None:
"""Add an item to the server's database so that the item can be queried by a
client.
"""
"""Add item to the server so that it can be queried by a client."""
self._requires_db()
self._add_item(item, "")

def add_items(self, items: Iterable[str]) -> None:
"""Add multiple items to the server's database so that they can be queried by a
client.
"""
"""Add multiple items to the server so that a client can query them."""
self._requires_db()
# TODO: Expose batch add in C++ PyAPSI
for item in items:
Expand Down
7 changes: 5 additions & 2 deletions apsi/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Utility functions for multi threading and logging."""

from _pyapsi.utils import (
_set_thread_count,
_get_thread_count,
Expand All @@ -8,8 +10,9 @@


def set_thread_count(thread_count: int) -> None:
"""Set the global APSI thread count, which allows parallelization of some
operations to improve the runtime performance.
"""Set the global APSI thread count.
Allows parallelization of some operations to improve the runtime performance.
"""
if thread_count < 1 or not isinstance(thread_count, int):
raise ValueError(
Expand Down
44 changes: 43 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ pybind11 = "^2.9.2"
black = "^22.3.0"
isort = "^5.10.1"
pytest = "^7.1.2"
pydocstyle = "^6.1.1"
toml = "^0.10.2"

[tool.pydocstyle]
match-dir = "apsi"
convention = "google"

[build-system]
requires = [
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Setuptools integration of C++ APSI and Python package."""

import os
from glob import glob

Expand Down

0 comments on commit bd66ab4

Please sign in to comment.