From ec4e80605373a2dc0f0794df59c8aca863abe5b7 Mon Sep 17 00:00:00 2001 From: Darko Kulic <36031644+dkulic@users.noreply.github.com> Date: Mon, 13 Mar 2023 10:38:48 +0100 Subject: [PATCH] Implement support for custom_headers. (#94) * Implement support for custom_headers. * Fix NoneType problem * Fix validation --- didcomm/common/types.py | 2 +- didcomm/message.py | 61 +++++++++++-------- didcomm/pack_encrypted.py | 4 +- didcomm/protocols/routing/forward.py | 4 +- .../test_vectors/didcomm_messages/messages.py | 14 +++++ .../tests/test_vectors_plaintext_positive.py | 21 +++++++ .../unit/pack_common/test_message_negative.py | 19 +++--- .../pack_plaintext/test_pack_plaintext.py | 11 ++++ 8 files changed, 100 insertions(+), 36 deletions(-) diff --git a/didcomm/common/types.py b/didcomm/common/types.py index 1c1f53d..8431b0c 100644 --- a/didcomm/common/types.py +++ b/didcomm/common/types.py @@ -5,7 +5,7 @@ from typing import Dict, Any, Union, List JSON_OBJ = Dict[str, Any] -JSON_VALUE = Union[None, str, int, bool, float, JSON_OBJ, List[Any]] +JSON_VALUE = Union[type(None), str, int, bool, float, Dict, List] JSON = str JWK = JSON JWT = JSON diff --git a/didcomm/message.py b/didcomm/message.py index b432a4d..62779a2 100644 --- a/didcomm/message.py +++ b/didcomm/message.py @@ -32,24 +32,26 @@ DIDCommValueError, ) +NoneType = type(None) HeaderKeyType = str HeaderValueType = JSON_VALUE -Header = Dict[HeaderKeyType, HeaderValueType] +Headers = Dict[HeaderKeyType, HeaderValueType] T = TypeVar("T") -_MESSAGE_DEFAULT_FIELDS = { - "id", - "type", +_MESSAGE_RESERVED_FIELDS = { + "ack", + "attachments", "body", - "frm", - "to", "created_time", "expires_time", + "frm", + "from", "from_prior", + "id", "please_ack", - "ack", - "thid", "pthid", - "attachments", + "thid", + "to", + "type", } @@ -333,19 +335,18 @@ class GenericMessage(Generic[T]): ), default=None, ) - custom_headers: Optional[List[Header]] = attr.ib( + custom_headers: Optional[Headers] = attr.ib( validator=validator__optional( - validator__deep_iterable( - validator__deep_mapping( - key_validator=validator__and_( - validator__instance_of(HeaderKeyType), - validator__not_in_(_MESSAGE_DEFAULT_FIELDS), - ), - value_validator=validator__instance_of(HeaderValueType), - mapping_validator=validator__instance_of(Dict), + validator__deep_mapping( + key_validator=validator__and_( + validator__instance_of(HeaderKeyType), + validator__not_in_(_MESSAGE_RESERVED_FIELDS), ), - iterable_validator=validator__instance_of(List), - ) + value_validator=validator__instance_of( + (str, int, bool, float, Dict, List, NoneType) + ), + mapping_validator=validator__instance_of(Dict), + ), ), default=None, ) @@ -364,6 +365,11 @@ def _body_as_dict(self): return self.body def as_dict(self) -> dict: + try: + attr.validate(self) + except Exception as exc: + raise DIDCommValueError(str(exc)) from exc + d = attrs_to_dict(self) d["body"] = self._body_as_dict() @@ -379,10 +385,9 @@ def as_dict(self) -> dict: if self.from_prior: d["from_prior"] = self.from_prior.as_dict() - try: - attr.validate(self) - except Exception as exc: - raise DIDCommValueError(str(exc)) from exc + if self.custom_headers: + del d["custom_headers"] + d.update(self.custom_headers) return d @@ -413,6 +418,14 @@ def from_dict(cls, d: dict) -> Message: raise MalformedMessageError(MalformedMessageCode.INVALID_PLAINTEXT) del d["typ"] + custom_header_keys = d.keys() - _MESSAGE_RESERVED_FIELDS + if custom_header_keys: + custom_headers = {} + for key in custom_header_keys: + custom_headers[key] = d[key] + del d[key] + d["custom_headers"] = custom_headers + if "from" in d: d["frm"] = d["from"] del d["from"] diff --git a/didcomm/pack_encrypted.py b/didcomm/pack_encrypted.py index d1603d8..214df76 100644 --- a/didcomm/pack_encrypted.py +++ b/didcomm/pack_encrypted.py @@ -17,7 +17,7 @@ from didcomm.did_doc.did_doc import DIDCommService from didcomm.errors import DIDCommValueError from didcomm.core.from_prior import pack_from_prior_in_place -from didcomm.message import Message, Header +from didcomm.message import Message, Headers from didcomm.protocols.routing.forward import ( wrap_in_forward, resolve_did_services_chain, @@ -261,7 +261,7 @@ class PackEncryptedParameters: in the packed message """ - forward_headers: Optional[Header] = None + forward_headers: Optional[Headers] = None forward_service_id: Optional[str] = None forward_didcomm_id_generator: Optional[ DIDCommGeneratorType diff --git a/didcomm/protocols/routing/forward.py b/didcomm/protocols/routing/forward.py index 96c79be..7cb614f 100644 --- a/didcomm/protocols/routing/forward.py +++ b/didcomm/protocols/routing/forward.py @@ -24,7 +24,7 @@ ) from didcomm.common.resolvers import ResolversConfig from didcomm.common.algorithms import AnonCryptAlg -from didcomm.message import GenericMessage, Header, Attachment, AttachmentDataJson +from didcomm.message import GenericMessage, Headers, Attachment, AttachmentDataJson from didcomm.core.types import EncryptResult, DIDCommGeneratorType, DIDCOMM_ORG_DOMAIN from didcomm.core.defaults import DEF_ENC_ALG_ANON from didcomm.core.validators import ( @@ -209,7 +209,7 @@ async def wrap_in_forward( to: DID_OR_DID_URL, routing_keys: List[DID_OR_DID_URL], enc_alg_anon: Optional[AnonCryptAlg] = DEF_ENC_ALG_ANON, - headers: Optional[Header] = None, + headers: Optional[Headers] = None, didcomm_id_generator: Optional[DIDCommGeneratorType] = None, ) -> Optional[ForwardPackResult]: """ diff --git a/tests/test_vectors/didcomm_messages/messages.py b/tests/test_vectors/didcomm_messages/messages.py index b09f60c..ebd9828 100644 --- a/tests/test_vectors/didcomm_messages/messages.py +++ b/tests/test_vectors/didcomm_messages/messages.py @@ -143,6 +143,20 @@ def minimal_msg(): ) +def custom_headers_msg(): + msg = copy.deepcopy(TEST_MESSAGE) + msg.custom_headers = { + "my_string": "string value", + "my_int": 123, + "my_bool": False, + "my_float": 1.23, + "my_json": {"key": "value"}, + "my_list": [1, 2, 3], + "my_none": None, + } + return msg + + def attachment_base64_msg(): msg = copy.deepcopy(TEST_MESSAGE) msg.attachments = [Attachment(id="23", data=AttachmentDataBase64(base64="qwerty"))] diff --git a/tests/test_vectors/didcomm_messages/tests/test_vectors_plaintext_positive.py b/tests/test_vectors/didcomm_messages/tests/test_vectors_plaintext_positive.py index c4d6859..ad3a907 100644 --- a/tests/test_vectors/didcomm_messages/tests/test_vectors_plaintext_positive.py +++ b/tests/test_vectors/didcomm_messages/tests/test_vectors_plaintext_positive.py @@ -32,6 +32,27 @@ } ) +TEST_PLAINTEXT_DIDCOMM_MESSAGE_WITH_CUSTOM_HEADERS = json_dumps( + { + "id": "1234567890", + "thid": "1234567890", + "typ": "application/didcomm-plain+json", + "type": "http://example.com/protocols/lets_do_lunch/1.0/proposal", + "from": "did:example:alice", + "to": ["did:example:bob"], + "created_time": 1516269022, + "expires_time": 1516385931, + "body": {"messagespecificattribute": "and its value"}, + "my_string": "string value", + "my_int": 123, + "my_bool": False, + "my_float": 1.23, + "my_json": {"key": "value"}, + "my_list": [1, 2, 3], + "my_none": None, + } +) + TEST_PLAINTEXT_ATTACHMENT_BASE64 = json_dumps( { "id": "1234567890", diff --git a/tests/unit/pack_common/test_message_negative.py b/tests/unit/pack_common/test_message_negative.py index 96fcfa8..6b444eb 100644 --- a/tests/unit/pack_common/test_message_negative.py +++ b/tests/unit/pack_common/test_message_negative.py @@ -83,14 +83,14 @@ async def test_no_required_param(resolvers_config_alice): @pytest.mark.parametrize( "custom_header", [ - [{"id": "abc"}], - [{"type": "abc"}], - [{"body": "abc"}], - [{"created_time": 1516269022}], - [{"created_time": "1516269022"}], + {"id": "abc"}, + {"type": "abc"}, + {"body": "abc"}, + {"created_time": 1516269022}, + {"created_time": "1516269022"}, ], ) -async def test_custom_header_equals_to_default(custom_header): +async def test_custom_header_rejects_reserved_names(custom_header): with pytest.raises(DIDCommValueError): Message( id="1234567890", @@ -118,7 +118,12 @@ async def test_custom_header_equals_to_default(custom_header): ("from_prior", {}), ("attachments", {}), ("attachments", [{}]), - ("custom_headers", {}), + ("custom_headers", []), + ("custom_headers", "some"), + ( # reserved name cannot be used in custom headers + "custom_headers", + {"type": "something"}, + ), ], ) async def test_message_invalid_types(new_fields): diff --git a/tests/unit/pack_plaintext/test_pack_plaintext.py b/tests/unit/pack_plaintext/test_pack_plaintext.py index 270551e..6829b95 100644 --- a/tests/unit/pack_plaintext/test_pack_plaintext.py +++ b/tests/unit/pack_plaintext/test_pack_plaintext.py @@ -12,6 +12,7 @@ attachment_multi_1_msg, attachment_multi_2_msg, ack_msg, + custom_headers_msg, ) from tests.test_vectors.didcomm_messages.spec.spec_test_vectors_plaintext import ( TEST_PLAINTEXT_DIDCOMM_MESSAGE_SIMPLE, @@ -24,6 +25,7 @@ TEST_PLAINTEXT_ATTACHMENT_MULTI_2, TEST_PLAINTEXT_DIDCOMM_MESSAGE_MINIMAL, TEST_PLAINTEXT_ACKS, + TEST_PLAINTEXT_DIDCOMM_MESSAGE_WITH_CUSTOM_HEADERS, ) @@ -42,6 +44,15 @@ async def test_pack_simple_plaintext(resolvers_config_bob): ) +@pytest.mark.asyncio +async def test_pack_plaintext_with_custom_headers(resolvers_config_bob): + await check_pack_plaintext( + custom_headers_msg(), + TEST_PLAINTEXT_DIDCOMM_MESSAGE_WITH_CUSTOM_HEADERS, + resolvers_config_bob, + ) + + @pytest.mark.asyncio async def test_pack_minimal_plaintext(resolvers_config_bob): await check_pack_plaintext(