Skip to content

Commit

Permalink
Refactoring Protocol document get(), finish unittest protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
claire-lex committed Nov 7, 2023
1 parent b5fd1f9 commit b661a53
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 27 deletions.
15 changes: 15 additions & 0 deletions srcs/db/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ def protocols(self):
"""Return protocols collection."""
return self.db[mongodb.protocols]

@property
def protocols_id(self):
"""Return the list of IDs in protocols collection."""
return [x[mongodb.id] for x in self.db[mongodb.protocols].find()]

@property
def protocols_count(self):
"""Return number of protocols in collection."""
Expand All @@ -169,6 +174,11 @@ def links(self):
"""Return links collection."""
return self.db[mongodb.links]

@property
def links_id(self):
"""Return the list of IDs in links collection."""
return [x[mongodb.id] for x in self.db[mongodb.links].find()]

@property
def links_count(self):
"""Return number of links in collection."""
Expand All @@ -184,6 +194,11 @@ def packets(self):
"""Return packets collection."""
return self.db[mongodb.packets]

@property
def packets_id(self):
"""Return the list of IDs in packets collection."""
return [x[mongodb.id] for x in self.db[mongodb.packets].find()]

@property
def packets_count(self):
"""Return number of packets in collection."""
Expand Down
54 changes: 33 additions & 21 deletions srcs/db/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
ERR_EXIPROTO = "Protocol '{0}' already exists."
ERR_UNKFIELD = "Protocol '{0}' has no field '{1}'."
ERR_EXIVALUE = "Field '{0}' already contains this value."
ERR_INVVALUE = "Field '{0}' does not accept Documents."
ERR_INVVALUE = "Field '{0}' does not accept links or packets."
ERR_INVLINK = "Field '{0}' only accepts valid links."
ERR_INVPACKET = "Field '{0}' only accepts valid packets."
ERR_MULTIMATCH = "Multiple match found, please choose between {0}."
ERR_BOOLVALUE = "This field only accept 'true' or 'false'"

Expand All @@ -36,6 +38,15 @@ def __init__(self, **kwargs):
setattr(self, k, v)
self.__fill()

def __store(self, field: str, value: object) -> None:
"""Save value to the Document in DB and in this class."""
if isinstance(value, Document):
raise DBException(ERR_INVVALUE.format(p.NAME(field)))
document = {"name": self.name}
newvalue = {field: value}
self._db.protocols.update_one(document, {"$set": newvalue})
setattr(self, field, value)

#--- Public --------------------------------------------------------------#

def get(self, field: str) -> tuple:
Expand All @@ -56,26 +67,27 @@ def get(self, field: str) -> tuple:
def set(self, field: str, value: object, replace: bool = False) -> None:
"""Update existing field in protocol."""
field, oldvalue = self.get(field)
# Different behavior if linklist
if p.TYPE(field) in (types.LINKLIST, types.LIST, types.PKTLIST):
if isinstance(value, Document): # Link or Packet
value = value._id
if not replace and oldvalue: # We append
oldvalue = [oldvalue] if not isinstance(oldvalue, list) else oldvalue
if value not in oldvalue:
value = [value] if not isinstance(value, list) else value
value = [x for x in oldvalue + value if x != '']
else:
raise DBException(ERR_EXIVALUE.format(p.NAME(field)))
else:
value = value if isinstance(value, list) else [value]
if isinstance(value, Document): # Link or Packet
raise DBException(ERR_INVVALUE.format(p.NAME(field)))
# Store
document = {"name": self.name}
newvalue = {field: value}
self._db.protocols.update_one(document, {"$set": newvalue})
setattr(self, field, value)
ftype = p.TYPE(field)
# We deal with the simplest case first
if not ftype or ftype == types.STR:
return self.__store(field, value)
# Is the value a link or a packet ?
if isinstance(value, Document):
value = value._id
if ftype == types.LINKLIST and value not in self._db.links_id:
raise DBException(ERR_INVLINK.format(p.NAME(field)))
elif ftype == types.PKTLIST and value not in self._db.packets_id:
raise DBException(ERR_INVPACKET.format(p.NAME(field)))
# All other fields are lists (LIST, LINKLIST, PKTLIST)
value = value if isinstance(value, list) else [value]
if replace:
value = value if isinstance(value, list) else [value]
elif not replace and oldvalue:
oldvalue = oldvalue if isinstance(oldvalue, list) else [oldvalue]
if set(oldvalue) & set(value):
raise DBException(ERR_EXIVALUE.format(p.NAME(field)))
value = [x for x in oldvalue + value if x != '']
self.__store(field, value)

def add(self, field: str, value: object) -> None:
"""Add a new field to protocol."""
Expand Down
34 changes: 29 additions & 5 deletions srcs/tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ def test_0303_getprotocols_all(self):
"""All protocols can be returned as JSON."""
self.assertEqual(self.protocols.all[0]["name"], PROTOCOLS[0])
def test_0304_getprotocols_all(self):
"""All protocols can be returned as JSON."""
"""All protocols can be returned as objects."""
self.assertTrue(isinstance(self.protocols.all_as_objects[0], Protocol))
self.assertEqual(self.protocols.all_as_objects[0].name, PROTOCOLS[0])
def test_0305_getprotocols_list(self):
"""All protocols can be returned as JSON."""
"""All protocols can be returned as list."""
self.assertEqual(self.protocols.list[0], PROTOCOLS[0])
def test_0306_getprotocols_count(self):
"""All protocols can be returned as JSON."""
"""The number of protocols can be returned."""
self.assertEqual(self.protocols.count, 1)
def test_0307_getprotocols(self):
"""We can get a protocol by its name."""
Expand Down Expand Up @@ -218,7 +218,6 @@ class Test04DBProtocolDocument(DBTest):
@classmethod
def setUpClass(self):
super().setUpClass()
# We need to init protocols because a packet is attached to one.
populate(self.protocols)
def test_0401_getprotocol_field(self):
"""We can get a value from its field."""
Expand Down Expand Up @@ -338,4 +337,29 @@ def test_0420_setprotocol_nolinkpkt(self):
protocol = self.protocols.get(PROTOCOLS[1])
link = self.links.get(LINKS[0])
with self.assertRaises(DBException):
protocol.set("nmap", link)
protocol.set("discovery", link)

class Test05DBLinksCollection(DBTest):
"""Test class to get and set data in links' collection."""
def test_0501_addlinks(self):
"""A new link can be added."""
self.links.add(Link(**TEST_COLL_LINKS[0]))
link = self.links.get(TEST_COLL_LINKS[0]["name"])
self.assertEqual(link.name, LINKS[0])
def test_0502_addlinks_exists(self):
"""We can't add a link that already exists."""
with self.assertRaises(DBException):
self.links.add(Link(**TEST_COLL_LINKS[0]))
# """All links can be returned as JSON."""
# """All links can be returned as objects."""
# """All links can be returned as list."""
# """The number of links can be returned."""
# """We can get a link by its name."""
# """We can get a link by its name."""
# """We can't get a link that does not exists."""
# """A search can return several links."""
# """Method has() tells us if the link exists or not."""
# """An complete link passes check()."""
# """We can delete an existing link."""
# """We cannot delete a link that does not exist."""
# """Deleting a link erases all its references in protocols."""
2 changes: 1 addition & 1 deletion srcs/ui/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def __cmd_add(self, protocol: str = None) -> bool:
try:
self.protocols.add(Protocol(name=protocol))
except DBException as dbe:
raise UIError(dbe) from None
ERROR(str(dbe), will_exit=True)
self.__cmd_read(protocol)
return True

Expand Down

0 comments on commit b661a53

Please sign in to comment.