Skip to content

Commit

Permalink
Implement support for custom_headers. (#94)
Browse files Browse the repository at this point in the history
* Implement support for custom_headers.

* Fix NoneType problem

* Fix validation
  • Loading branch information
dkulic committed Mar 13, 2023
1 parent 47d6428 commit ec4e806
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 36 deletions.
2 changes: 1 addition & 1 deletion didcomm/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Dict, Any, Union, List

JSON_OBJ = Dict[str, Any]
JSON_VALUE = Union[None, str, int, bool, float, JSON_OBJ, List[Any]]
JSON_VALUE = Union[type(None), str, int, bool, float, Dict, List]
JSON = str
JWK = JSON
JWT = JSON
Expand Down
61 changes: 37 additions & 24 deletions didcomm/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,26 @@
DIDCommValueError,
)

NoneType = type(None)
HeaderKeyType = str
HeaderValueType = JSON_VALUE
Header = Dict[HeaderKeyType, HeaderValueType]
Headers = Dict[HeaderKeyType, HeaderValueType]
T = TypeVar("T")
_MESSAGE_DEFAULT_FIELDS = {
"id",
"type",
_MESSAGE_RESERVED_FIELDS = {
"ack",
"attachments",
"body",
"frm",
"to",
"created_time",
"expires_time",
"frm",
"from",
"from_prior",
"id",
"please_ack",
"ack",
"thid",
"pthid",
"attachments",
"thid",
"to",
"type",
}


Expand Down Expand Up @@ -333,19 +335,18 @@ class GenericMessage(Generic[T]):
),
default=None,
)
custom_headers: Optional[List[Header]] = attr.ib(
custom_headers: Optional[Headers] = attr.ib(
validator=validator__optional(
validator__deep_iterable(
validator__deep_mapping(
key_validator=validator__and_(
validator__instance_of(HeaderKeyType),
validator__not_in_(_MESSAGE_DEFAULT_FIELDS),
),
value_validator=validator__instance_of(HeaderValueType),
mapping_validator=validator__instance_of(Dict),
validator__deep_mapping(
key_validator=validator__and_(
validator__instance_of(HeaderKeyType),
validator__not_in_(_MESSAGE_RESERVED_FIELDS),
),
iterable_validator=validator__instance_of(List),
)
value_validator=validator__instance_of(
(str, int, bool, float, Dict, List, NoneType)
),
mapping_validator=validator__instance_of(Dict),
),
),
default=None,
)
Expand All @@ -364,6 +365,11 @@ def _body_as_dict(self):
return self.body

def as_dict(self) -> dict:
try:
attr.validate(self)
except Exception as exc:
raise DIDCommValueError(str(exc)) from exc

d = attrs_to_dict(self)

d["body"] = self._body_as_dict()
Expand All @@ -379,10 +385,9 @@ def as_dict(self) -> dict:
if self.from_prior:
d["from_prior"] = self.from_prior.as_dict()

try:
attr.validate(self)
except Exception as exc:
raise DIDCommValueError(str(exc)) from exc
if self.custom_headers:
del d["custom_headers"]
d.update(self.custom_headers)

return d

Expand Down Expand Up @@ -413,6 +418,14 @@ def from_dict(cls, d: dict) -> Message:
raise MalformedMessageError(MalformedMessageCode.INVALID_PLAINTEXT)
del d["typ"]

custom_header_keys = d.keys() - _MESSAGE_RESERVED_FIELDS
if custom_header_keys:
custom_headers = {}
for key in custom_header_keys:
custom_headers[key] = d[key]
del d[key]
d["custom_headers"] = custom_headers

if "from" in d:
d["frm"] = d["from"]
del d["from"]
Expand Down
4 changes: 2 additions & 2 deletions didcomm/pack_encrypted.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from didcomm.did_doc.did_doc import DIDCommService
from didcomm.errors import DIDCommValueError
from didcomm.core.from_prior import pack_from_prior_in_place
from didcomm.message import Message, Header
from didcomm.message import Message, Headers
from didcomm.protocols.routing.forward import (
wrap_in_forward,
resolve_did_services_chain,
Expand Down Expand Up @@ -261,7 +261,7 @@ class PackEncryptedParameters:
in the packed message
"""

forward_headers: Optional[Header] = None
forward_headers: Optional[Headers] = None
forward_service_id: Optional[str] = None
forward_didcomm_id_generator: Optional[
DIDCommGeneratorType
Expand Down
4 changes: 2 additions & 2 deletions didcomm/protocols/routing/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from didcomm.common.resolvers import ResolversConfig
from didcomm.common.algorithms import AnonCryptAlg
from didcomm.message import GenericMessage, Header, Attachment, AttachmentDataJson
from didcomm.message import GenericMessage, Headers, Attachment, AttachmentDataJson
from didcomm.core.types import EncryptResult, DIDCommGeneratorType, DIDCOMM_ORG_DOMAIN
from didcomm.core.defaults import DEF_ENC_ALG_ANON
from didcomm.core.validators import (
Expand Down Expand Up @@ -209,7 +209,7 @@ async def wrap_in_forward(
to: DID_OR_DID_URL,
routing_keys: List[DID_OR_DID_URL],
enc_alg_anon: Optional[AnonCryptAlg] = DEF_ENC_ALG_ANON,
headers: Optional[Header] = None,
headers: Optional[Headers] = None,
didcomm_id_generator: Optional[DIDCommGeneratorType] = None,
) -> Optional[ForwardPackResult]:
"""
Expand Down
14 changes: 14 additions & 0 deletions tests/test_vectors/didcomm_messages/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,20 @@ def minimal_msg():
)


def custom_headers_msg():
msg = copy.deepcopy(TEST_MESSAGE)
msg.custom_headers = {
"my_string": "string value",
"my_int": 123,
"my_bool": False,
"my_float": 1.23,
"my_json": {"key": "value"},
"my_list": [1, 2, 3],
"my_none": None,
}
return msg


def attachment_base64_msg():
msg = copy.deepcopy(TEST_MESSAGE)
msg.attachments = [Attachment(id="23", data=AttachmentDataBase64(base64="qwerty"))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@
}
)

TEST_PLAINTEXT_DIDCOMM_MESSAGE_WITH_CUSTOM_HEADERS = json_dumps(
{
"id": "1234567890",
"thid": "1234567890",
"typ": "application/didcomm-plain+json",
"type": "http:https://example.com/protocols/lets_do_lunch/1.0/proposal",
"from": "did:example:alice",
"to": ["did:example:bob"],
"created_time": 1516269022,
"expires_time": 1516385931,
"body": {"messagespecificattribute": "and its value"},
"my_string": "string value",
"my_int": 123,
"my_bool": False,
"my_float": 1.23,
"my_json": {"key": "value"},
"my_list": [1, 2, 3],
"my_none": None,
}
)

TEST_PLAINTEXT_ATTACHMENT_BASE64 = json_dumps(
{
"id": "1234567890",
Expand Down
19 changes: 12 additions & 7 deletions tests/unit/pack_common/test_message_negative.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ async def test_no_required_param(resolvers_config_alice):
@pytest.mark.parametrize(
"custom_header",
[
[{"id": "abc"}],
[{"type": "abc"}],
[{"body": "abc"}],
[{"created_time": 1516269022}],
[{"created_time": "1516269022"}],
{"id": "abc"},
{"type": "abc"},
{"body": "abc"},
{"created_time": 1516269022},
{"created_time": "1516269022"},
],
)
async def test_custom_header_equals_to_default(custom_header):
async def test_custom_header_rejects_reserved_names(custom_header):
with pytest.raises(DIDCommValueError):
Message(
id="1234567890",
Expand Down Expand Up @@ -118,7 +118,12 @@ async def test_custom_header_equals_to_default(custom_header):
("from_prior", {}),
("attachments", {}),
("attachments", [{}]),
("custom_headers", {}),
("custom_headers", []),
("custom_headers", "some"),
( # reserved name cannot be used in custom headers
"custom_headers",
{"type": "something"},
),
],
)
async def test_message_invalid_types(new_fields):
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/pack_plaintext/test_pack_plaintext.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
attachment_multi_1_msg,
attachment_multi_2_msg,
ack_msg,
custom_headers_msg,
)
from tests.test_vectors.didcomm_messages.spec.spec_test_vectors_plaintext import (
TEST_PLAINTEXT_DIDCOMM_MESSAGE_SIMPLE,
Expand All @@ -24,6 +25,7 @@
TEST_PLAINTEXT_ATTACHMENT_MULTI_2,
TEST_PLAINTEXT_DIDCOMM_MESSAGE_MINIMAL,
TEST_PLAINTEXT_ACKS,
TEST_PLAINTEXT_DIDCOMM_MESSAGE_WITH_CUSTOM_HEADERS,
)


Expand All @@ -42,6 +44,15 @@ async def test_pack_simple_plaintext(resolvers_config_bob):
)


@pytest.mark.asyncio
async def test_pack_plaintext_with_custom_headers(resolvers_config_bob):
await check_pack_plaintext(
custom_headers_msg(),
TEST_PLAINTEXT_DIDCOMM_MESSAGE_WITH_CUSTOM_HEADERS,
resolvers_config_bob,
)


@pytest.mark.asyncio
async def test_pack_minimal_plaintext(resolvers_config_bob):
await check_pack_plaintext(
Expand Down

0 comments on commit ec4e806

Please sign in to comment.