Skip to content

Commit

Permalink
Rewrite Attachments (and its subclasses) validation using 'attrs' val…
Browse files Browse the repository at this point in the history
…idators

Signed-off-by: Sacha Kozma <[email protected]>
  • Loading branch information
yvgny committed Mar 1, 2023
1 parent 820c591 commit e4393fe
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 127 deletions.
170 changes: 92 additions & 78 deletions didcomm/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional, List, Union, Dict, TypeVar, Generic, Callable

import attr
import attrs

from didcomm.common.types import (
JSON_VALUE,
Expand Down Expand Up @@ -34,7 +35,7 @@
HeaderValueType = JSON_VALUE
Header = Dict[HeaderKeyType, HeaderValueType]
T = TypeVar("T")
MESSAGE_DEFAULT_FIELDS = {
_MESSAGE_DEFAULT_FIELDS = {
"id",
"type",
"body",
Expand All @@ -51,78 +52,18 @@
}


@attr.s(auto_attribs=True)
class Attachment:
"""Plaintext attachment"""

data: Union[AttachmentDataLinks, AttachmentDataBase64, AttachmentDataJson]
id: Optional[Union[str, Callable]] = attr.ib(
converter=converter__id, validator=validator__instance_of(str), default=None
)
description: Optional[str] = None
filename: Optional[str] = None
media_type: Optional[str] = None
format: Optional[str] = None
lastmod_time: Optional[int] = None
byte_count: Optional[int] = None

def as_dict(self) -> dict:
self._validate()
d = attrs_to_dict(self)
d["data"] = self.data.as_dict()
return d

@staticmethod
def from_dict(d: dict) -> Attachment:
if not isinstance(d, Dict):
raise MalformedMessageError(MalformedMessageCode.INVALID_PLAINTEXT)

if "data" in d:
if not isinstance(d["data"], Dict):
raise MalformedMessageError(MalformedMessageCode.INVALID_PLAINTEXT)
if "links" in d["data"]:
d["data"] = AttachmentDataLinks.from_dict(d["data"])
elif "base64" in d["data"]:
d["data"] = AttachmentDataBase64.from_dict(d["data"])
elif "json" in d["data"]:
d["data"] = AttachmentDataJson.from_dict(d["data"])

try:
msg = Attachment(**d)
msg._validate()
except Exception:
raise MalformedMessageError(MalformedMessageCode.INVALID_PLAINTEXT)

return msg

def _validate(self):
if (
not isinstance(self.id, str)
or not isinstance(
self.data,
(AttachmentDataLinks, AttachmentDataBase64, AttachmentDataJson),
)
or self.description is not None
and not isinstance(self.description, str)
or self.filename is not None
and not isinstance(self.filename, str)
or self.media_type is not None
and not isinstance(self.media_type, str)
or self.format is not None
and not isinstance(self.format, str)
or self.lastmod_time is not None
and not isinstance(self.lastmod_time, int)
or self.byte_count is not None
and not isinstance(self.byte_count, int)
):
raise DIDCommValueError(f"Attachment structure is invalid: {self}")


@dataclass
@attrs.define(auto_attribs=True)
class AttachmentDataLinks:
links: List[str]
hash: str
jws: Optional[JSON_OBJ] = None
links: List[str] = attr.ib(
validator=validator__deep_iterable(
validator__instance_of(str), validator__instance_of(List)
)
)
hash: str = attr.ib(validator=validator__instance_of(str))
jws: Optional[JSON_OBJ] = attr.ib(
validator=validator__optional(validator__instance_of(Dict)), default=None
)

def as_dict(self) -> dict:
self._validate()
Expand Down Expand Up @@ -152,10 +93,15 @@ def _validate(self):


@dataclass
@attrs.define(auto_attribs=True)
class AttachmentDataBase64:
base64: str
hash: Optional[str] = None
jws: Optional[JSON_OBJ] = None
base64: str = attr.ib(validator=validator__instance_of(str))
hash: Optional[str] = attr.ib(
validator=validator__optional(validator__instance_of(str)), default=None
)
jws: Optional[JSON_OBJ] = attr.ib(
validator=validator__optional(validator__instance_of(Dict)), default=None
)

def as_dict(self) -> dict:
self._validate()
Expand Down Expand Up @@ -185,10 +131,17 @@ def _validate(self):


@dataclass
@attrs.define(auto_attribs=True)
class AttachmentDataJson:
json: JSON_VALUE
hash: Optional[str] = None
jws: Optional[JSON_OBJ] = None
json: JSON_VALUE = attr.ib(
validator=validator__instance_of((str, int, bool, float, Dict, List))
)
hash: Optional[str] = attr.ib(
validator=validator__optional(validator__instance_of(str)), default=None
)
jws: Optional[JSON_OBJ] = attr.ib(
validator=validator__optional(validator__instance_of(Dict)), default=None
)

def as_dict(self) -> dict:
self._validate()
Expand All @@ -215,6 +168,67 @@ def _validate(self):
raise DIDCommValueError(f"AttachmentDataJson structure is invalid: {self}")


@attrs.define(auto_attribs=True)
class Attachment:
"""Plaintext attachment"""

data: Union[
AttachmentDataLinks, AttachmentDataBase64, AttachmentDataJson
] = attr.ib(
validator=validator__instance_of(
Union[AttachmentDataLinks, AttachmentDataBase64, AttachmentDataJson]
),
)
id: Optional[Union[str, Callable]] = attr.ib(
converter=converter__id, validator=validator__instance_of(str), default=None
)
description: Optional[str] = attr.ib(
validator=validator__optional(validator__instance_of(str)), default=None
)
filename: Optional[str] = attr.ib(
validator=validator__optional(validator__instance_of(str)), default=None
)
media_type: Optional[str] = attr.ib(
validator=validator__optional(validator__instance_of(str)), default=None
)
format: Optional[str] = attr.ib(
validator=validator__optional(validator__instance_of(str)), default=None
)
lastmod_time: Optional[int] = attr.ib(
validator=validator__optional(validator__instance_of(int)), default=None
)
byte_count: Optional[int] = attr.ib(
validator=validator__optional(validator__instance_of(int)), default=None
)

def as_dict(self) -> dict:
d = attrs_to_dict(self)
d["data"] = self.data.as_dict()
return d

@staticmethod
def from_dict(d: dict) -> Attachment:
if not isinstance(d, Dict):
raise MalformedMessageError(MalformedMessageCode.INVALID_PLAINTEXT)

if "data" in d:
if not isinstance(d["data"], Dict):
raise MalformedMessageError(MalformedMessageCode.INVALID_PLAINTEXT)
if "links" in d["data"]:
d["data"] = AttachmentDataLinks.from_dict(d["data"])
elif "base64" in d["data"]:
d["data"] = AttachmentDataBase64.from_dict(d["data"])
elif "json" in d["data"]:
d["data"] = AttachmentDataJson.from_dict(d["data"])

try:
msg = Attachment(**d)
except Exception:
raise MalformedMessageError(MalformedMessageCode.INVALID_PLAINTEXT)

return msg


@dataclass(frozen=True)
class FromPrior:
iss: DID
Expand Down Expand Up @@ -342,7 +356,7 @@ class GenericMessage(Generic[T]):
validator__deep_mapping(
key_validator=validator__and_(
validator__instance_of(HeaderKeyType),
validator__not_in_(MESSAGE_DEFAULT_FIELDS),
validator__not_in_(_MESSAGE_DEFAULT_FIELDS),
),
value_validator=validator__instance_of(HeaderValueType),
mapping_validator=validator__instance_of(Dict),
Expand Down
1 change: 1 addition & 0 deletions didcomm/protocols/routing/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ async def wrap_in_forward(
# wrap forward msgs in reversed order so the message to final
# recipient 'to' will be the innermost one
for _to, _next in zip(routing_keys[::-1], (routing_keys[1:] + [to])[::-1]):
print(packed_msg, type(packed_msg))
fwd_attach = Attachment(data=AttachmentDataJson(packed_msg))

fwd_msg = ForwardMessage(
Expand Down
10 changes: 0 additions & 10 deletions tests/unit/didcomm/protocols/routing/test_forward_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,6 @@ def test_forward_message__type_good(msg_ver, fwd_msg):
),
id="bad_att_type",
),
pytest.param(
gen_fwd_msg_dict(
update={
DIDCommFields.ATTACHMENTS: [
Attachment(data=AttachmentDataJson({"somemsg"}))
]
}
),
id="bad_att_value_type",
),
],
)
def test_forward_message_from_dict__bad_msg(msg):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def encrypted_msg2() -> dict:
@pytest.fixture
def encrypt_result1(did: DID, encrypted_msg1: dict) -> EncryptResult:
return EncryptResult(
msg=encrypt_result1, to_kids=[did], to_keys=["not", "important", "for", "now"]
msg=encrypted_msg1, to_kids=[did], to_keys=["not", "important", "for", "now"]
)


Expand Down
67 changes: 29 additions & 38 deletions tests/unit/pack_common/test_message_negative.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses

import attr
import attrs
import pytest

from didcomm.errors import DIDCommValueError
Expand Down Expand Up @@ -38,8 +39,7 @@ async def check_invalid_pack_msg(msg: Message, resolvers_config_alice):
def update_attachment_field(attcmnt, field_name, value):
msg = copy.deepcopy(TEST_MESSAGE)
attcmnt = copy.deepcopy(attcmnt)
# to hack / workaround Attachment frozen setting
object.__setattr__(attcmnt, field_name, value)
attrs.evolve(attcmnt, **{field_name: value})
msg.attachments = [attcmnt]
return msg

Expand Down Expand Up @@ -158,8 +158,8 @@ async def test_message_invalid_body_type(msg, resolvers_config_alice):
async def test_message_invalid_attachemnt_fields(
attachment, new_fields, resolvers_config_alice
):
msg = update_attachment_field(attachment, *new_fields)
await check_invalid_pack_msg(msg, resolvers_config_alice)
with pytest.raises(DIDCommValueError):
update_attachment_field(attachment, *new_fields)


@pytest.mark.asyncio
Expand Down Expand Up @@ -187,55 +187,46 @@ async def test_message_invalid_from_prior_fields(


@pytest.mark.asyncio
@pytest.mark.parametrize("attachment", [TEST_ATTACHMENT, TEST_ATTACHMENT_MINIMAL])
@pytest.mark.parametrize(
"links",
"links, hashh, jws",
[
AttachmentDataLinks(links={}, hash="abc"),
AttachmentDataLinks(links=[123], hash="abc"),
AttachmentDataLinks(links=[123], hash="abc"),
AttachmentDataLinks(links=["123"], hash=123),
AttachmentDataLinks(links=["123"], hash="123", jws=123),
AttachmentDataLinks(links=["123"], hash="123", jws="123"),
AttachmentDataLinks(links=["123"], hash="123", jws=[]),
({}, "abc", None),
([123], "abc", None),
([123], "abc", None),
(["123"], 123, None),
(["123"], "123", 123),
(["123"], "123", "123"),
(["123"], "123", []),
],
)
async def test_message_invalid_attachment_data_links(
attachment, links, resolvers_config_alice
):
msg = update_attachment_field(attachment, "data", links)
await check_invalid_pack_msg(msg, resolvers_config_alice)
async def test_message_invalid_attachment_data_links(links, hashh, jws):
with pytest.raises(DIDCommValueError):
AttachmentDataLinks(links, hashh, jws)


@pytest.mark.asyncio
@pytest.mark.parametrize("attachment", [TEST_ATTACHMENT, TEST_ATTACHMENT_MINIMAL])
@pytest.mark.parametrize(
"base64",
"base64, hashh, jws",
[
AttachmentDataBase64(base64=123),
AttachmentDataBase64(base64="123", hash=123),
AttachmentDataBase64(base64="123", hash="123", jws="{}"),
(123, None, None),
("123", 123, None),
("123", "123", "{}"),
],
)
async def test_message_invalid_attachment_data_base64(
attachment, base64, resolvers_config_alice
):
msg = update_attachment_field(attachment, "data", base64)
await check_invalid_pack_msg(msg, resolvers_config_alice)
async def test_message_invalid_attachment_data_base64(base64, hashh, jws):
with pytest.raises(DIDCommValueError):
AttachmentDataBase64(base64, hashh, jws)


@pytest.mark.asyncio
@pytest.mark.parametrize("attachment", [TEST_ATTACHMENT, TEST_ATTACHMENT_MINIMAL])
@pytest.mark.parametrize(
"json_data",
"json_data, hashh, jws",
[
AttachmentDataJson(json=AttachmentDataJson(json={})),
AttachmentDataJson(json={}, hash=123),
AttachmentDataJson(json={}, hash="123", jws="{}"),
(AttachmentDataJson(json={}), None, None),
({}, 123, None),
({}, "123", "{}"),
],
)
async def test_message_invalid_attachment_data_json(
attachment, json_data, resolvers_config_alice
):
msg = update_attachment_field(attachment, "data", json_data)
await check_invalid_pack_msg(msg, resolvers_config_alice)
async def test_message_invalid_attachment_data_json(json_data, hashh, jws):
with pytest.raises(DIDCommValueError):
AttachmentDataJson(json_data, hashh, jws)

0 comments on commit e4393fe

Please sign in to comment.