From 5fc0d81509be642d049f6418a92e43ff37a3e460 Mon Sep 17 00:00:00 2001 From: Michael Nuzzo Date: Tue, 25 Jul 2023 12:31:02 -0400 Subject: [PATCH] #284 - Added a check for the contributor having a source before continuing. Will now update unit tests. --- backend/auth/jwt.py | 24 ++++++++++++++++++++ backend/database/models/_assoc_tables.py | 25 ++++++++++++++++++--- backend/database/models/source.py | 28 ++++-------------------- backend/database/models/user.py | 3 +++ backend/routes/incidents.py | 4 ++-- backend/tests/conftest.py | 5 +++++ backend/tests/test_incidents.py | 3 ++- 7 files changed, 62 insertions(+), 30 deletions(-) diff --git a/backend/auth/jwt.py b/backend/auth/jwt.py index 5ee15c09..95e96314 100644 --- a/backend/auth/jwt.py +++ b/backend/auth/jwt.py @@ -20,6 +20,19 @@ def verify_roles_or_abort(min_role): return True +def verify_contributor_has_source_or_abort(): + verify_jwt_in_request() + jwt_decoded = get_jwt() + current_user = User.get(jwt_decoded["sub"]) + if ( + current_user is None + or not current_user.member_of + ): + abort(403) + return False + return True + + def blueprint_role_required(*roles): def decorator(): verify_roles_or_abort(roles) @@ -45,3 +58,14 @@ def decorator(*args, **kwargs): return decorator return wrapper + +def contributor_has_source(): + def wrapper(fn): + @wraps(fn) + def decorator(*args, **kwargs): + if verify_contributor_has_source_or_abort(): + return fn(*args, **kwargs) + + return decorator + + return wrapper diff --git a/backend/database/models/_assoc_tables.py b/backend/database/models/_assoc_tables.py index d9900c7a..6907c4d0 100644 --- a/backend/database/models/_assoc_tables.py +++ b/backend/database/models/_assoc_tables.py @@ -1,8 +1,28 @@ from .. import db from backend.database.models.officer import Rank -# from backend.database.models.source import MemberRole +from enum import Enum -""" source_user = db.Table( + +class MemberRole(Enum): + ADMIN = "Administrator" + PUBLISHER = "Publisher" + MEMBER = "Member" + SUBSCRIBER = "Subscriber" + + def get_value(self): + if self == MemberRole.ADMIN: + return 1 + elif self == MemberRole.PUBLISHER: + return 2 + elif self == MemberRole.MEMBER: + return 3 + elif self == MemberRole.SUBSCRIBER: + return 4 + else: + return 5 + + +source_user = db.Table( 'source_user', db.Column('source_id', db.String, db.ForeignKey('source.id'), primary_key=True), @@ -13,7 +33,6 @@ db.Column('is_active', db.Boolean), db.Column('is_admin', db.Boolean) ) - """ incident_agency = db.Table( 'incident_agency', diff --git a/backend/database/models/source.py b/backend/database/models/source.py index d73fc73f..34c828fe 100644 --- a/backend/database/models/source.py +++ b/backend/database/models/source.py @@ -1,25 +1,5 @@ from ..core import db, CrudMixin -# from backend.database.models._assoc_tables import source_user -from enum import Enum - - -class MemberRole(Enum): - ADMIN = "Administrator" - PUBLISHER = "Publisher" - MEMBER = "Member" - SUBSCRIBER = "Subscriber" - - def get_value(self): - if self == MemberRole.ADMIN: - return 1 - elif self == MemberRole.PUBLISHER: - return 2 - elif self == MemberRole.MEMBER: - return 3 - elif self == MemberRole.SUBSCRIBER: - return 4 - else: - return 5 +from backend.database.models._assoc_tables import source_user class Source(db.Model, CrudMixin): @@ -29,6 +9,6 @@ class Source(db.Model, CrudMixin): contact_email = db.Column(db.Text) reported_incidents = db.relationship( 'Incident', backref='source', lazy="select") - # members = db.relationship( - # 'User', backref='contributor_orgs', - # secondary=source_user, lazy="select") + members = db.relationship( + 'User', backref='contributor_orgs', + secondary=source_user, lazy="select") diff --git a/backend/database/models/user.py b/backend/database/models/user.py index ff63ea9f..1af19998 100644 --- a/backend/database/models/user.py +++ b/backend/database/models/user.py @@ -85,5 +85,8 @@ class User(db.Model, UserMixin, CrudMixin): phone_number = db.Column(db.Text) + member_of = db.relationship( + 'Source', backref='source', secondary='source_user', lazy="select") + def verify_password(self, pw): return bcrypt.checkpw(pw.encode("utf8"), self.password.encode("utf8")) diff --git a/backend/routes/incidents.py b/backend/routes/incidents.py index 1a44aee9..7e96b06f 100644 --- a/backend/routes/incidents.py +++ b/backend/routes/incidents.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Optional -from backend.auth.jwt import min_role_required +from backend.auth.jwt import min_role_required, contributor_has_source from backend.database.models.user import UserRole from flask import Blueprint, abort, current_app, request from flask_jwt_extended.view_decorators import jwt_required @@ -30,8 +30,8 @@ def get_incidents(incident_id: int): @bp.route("/create", methods=["POST"]) @jwt_required() -# TODO: Require CONTRIBUTOR role @min_role_required(UserRole.CONTRIBUTOR) +@contributor_has_source() @validate(json=CreateIncidentSchema) def create_incident(): """Create a single incident. diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 16923a8c..335fdae4 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -64,6 +64,7 @@ def example_user(db_session): db_session.commit() return user + @pytest.fixture def admin_user(db_session): user = User( @@ -78,6 +79,7 @@ def admin_user(db_session): return user + @pytest.fixture def contributor_user(db_session): user = User( @@ -92,6 +94,7 @@ def contributor_user(db_session): return user + @pytest.fixture def access_token(client, example_user): res = client.post( @@ -104,6 +107,7 @@ def access_token(client, example_user): assert res.status_code == 200 return res.json["access_token"] + @pytest.fixture def contributor_access_token(client, contributor_user): res = client.post( @@ -116,6 +120,7 @@ def contributor_access_token(client, contributor_user): assert res.status_code == 200 return res.json["access_token"] + @pytest.fixture def cli_runner(app): return app.test_cli_runner() diff --git a/backend/tests/test_incidents.py b/backend/tests/test_incidents.py index 04a42f4b..7e693d69 100644 --- a/backend/tests/test_incidents.py +++ b/backend/tests/test_incidents.py @@ -61,7 +61,8 @@ def example_incidents(db_session, client, contributor_access_token): res = client.post( "/api/v1/incidents/create", json=mock, - headers={"Authorization": "Bearer {0}".format(contributor_access_token)}, + headers={"Authorization": + "Bearer {0}".format(contributor_access_token)}, ) assert res.status_code == 200 created[name] = res.json