Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronn committed Jun 3, 2017
1 parent 74dea49 commit 36d481e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 31 deletions.
27 changes: 9 additions & 18 deletions drfpasswordless/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,24 +107,21 @@ def alias_type(self):
raise NotImplementedError

def validate(self, attrs):
alias = attrs.get(self.alias_type)

if alias:
msg = _('There was a problem with your request.')

if self.alias_type:
# Get request.user
# Get their specified valid endpoint
# Validate

request = self.context.get("request")
request = self.context["request"]
if request and hasattr(request, "user"):
user = request.user
user = user.refresh_from_db()

if user:
if not user.is_active:
# If valid, return attrs so we can create a token in our logic controller
msg = _('User account is disabled.')
print(msg)
log.debug(msg)

else:
if hasattr(user, self.alias_type):
Expand All @@ -136,28 +133,20 @@ def validate(self, attrs):
return attrs
else:
msg = _('This user doesn\'t have an %s.' % self.alias_type)
print(msg)
log.debug(msg)
raise serializers.ValidationError(msg)
else:
msg = _('There was a problem with your request.')
print(msg)
log.debug(msg)
raise serializers.ValidationError(msg)
else:
msg = _('Missing %s.') % self.alias_type
print(msg)
log.debug(msg)
raise serializers.ValidationError(msg)


class EmailVerificationSerializer(AbstractBaseAliasAuthenticationSerializer):
class EmailVerificationSerializer(AbstractBaseAliasVerificationSerializer):
@property
def alias_type(self):
return 'email'


class MobileVerificationSerializer(AbstractBaseAliasAuthenticationSerializer):
class MobileVerificationSerializer(AbstractBaseAliasVerificationSerializer):
@property
def alias_type(self):
return 'mobile'
Expand Down Expand Up @@ -232,10 +221,12 @@ class CallbackTokenVerificationSerializer(AbstractBaseCallbackTokenSerializer):

def validate(self, attrs):
try:
print(self.context)
user_id = self.context.get("user_id")
callback_token = attrs.get('token', None)

token = CallbackToken.objects.get(key=callback_token, is_active=True)
user = User.objects.get(pk=self.context.get("user_id"))
user = User.objects.get(pk=user_id)

if token.user == user:
# Check that the token.user is the request.user
Expand Down
7 changes: 4 additions & 3 deletions drfpasswordless/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from rest_framework import parsers, renderers, status
from rest_framework.authtoken.models import Token
from rest_framework.response import Response
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
from .settings import api_settings
from .serializers import (EmailAuthSerializer,
Expand Down Expand Up @@ -65,10 +66,8 @@ def post(self, request, *args, **kwargs):
else:
status_code = status.HTTP_400_BAD_REQUEST
response_detail = self.failure_response
log.debug("FAIL")
return Response({'detail': response_detail}, status=status_code)
else:
log.debug(serializer.error_messages)
return Response(serializer.error_messages, status=status.HTTP_400_BAD_REQUEST)


Expand Down Expand Up @@ -101,6 +100,7 @@ class ObtainMobileCallbackToken(AbstractBaseObtainCallbackToken):


class ObtainEmailVerificationCallbackToken(AbstractBaseObtainCallbackToken):
permission_classes = (IsAuthenticated,)
serializer_class = EmailVerificationSerializer
send_action = send_email_with_callback_token
success_response = "A verification token has been sent to your email."
Expand All @@ -117,6 +117,7 @@ class ObtainEmailVerificationCallbackToken(AbstractBaseObtainCallbackToken):


class ObtainMobileVerificationCallbackToken(AbstractBaseObtainCallbackToken):
permission_classes = (IsAuthenticated,)
serializer_class = MobileVerificationSerializer
send_action = send_sms_with_callback_token
success_response = "We texted you a verification code."
Expand Down Expand Up @@ -180,7 +181,7 @@ class VerifyAliasFromCallbackToken(APIView):
serializer_class = CallbackTokenVerificationSerializer

def post(self, request, *args, **kwargs):
serializer = self.serializer_class(data=request.data, context={'user_id', self.request.user.id})
serializer = self.serializer_class(data=request.data, context={'user_id': self.request.user.id})
if serializer.is_valid(raise_exception=True):

return Response({'detail': 'Alias verified.'}, status=status.HTTP_200_OK)
Expand Down
23 changes: 13 additions & 10 deletions tests/test_verification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from rest_framework import status
from rest_framework.authtoken.models import Token
from rest_framework.test import APITestCase

from django.contrib.auth import get_user_model
from drfpasswordless.settings import api_settings, DEFAULTS
from drfpasswordless.utils import CallbackToken
Expand All @@ -25,6 +24,7 @@ def setUp(self):

def test_email_unverified_to_verified_and_back(self):
email = '[email protected]'
email2 = '[email protected]'
data = {'email': email}

# create a new user
Expand All @@ -49,17 +49,18 @@ def test_email_unverified_to_verified_and_back(self):
self.assertEqual(getattr(user, self.email_verified_field_name), True)

# Change email, should result in flag changing to false
setattr(user, self.email_field_name, '[email protected]')
setattr(user, self.email_field_name, email2)
user.save()
user.refresh_from_db()
self.assertEqual(getattr(user, self.email_verified_field_name), False)

# Verify
callback_response = self.client.post(self.verify_url)
self.assertEqual(callback_response.status_code, status.HTTP_200_OK)
self.client.force_login(user)
verify_response = self.client.post(self.verify_url)
self.assertEqual(verify_response.status_code, status.HTTP_200_OK)

# Refresh User
user = User.objects.get(**{self.email_field_name: email})
user = User.objects.get(**{self.email_field_name: email2})
self.assertNotEqual(user, None)
self.assertNotEqual(getattr(user, self.email_field_name), None)
self.assertEqual(getattr(user, self.email_verified_field_name), False)
Expand All @@ -70,7 +71,7 @@ def test_email_unverified_to_verified_and_back(self):
self.assertEqual(verify_callback_response.status_code, status.HTTP_200_OK)

# Refresh User
user = User.objects.get(**{self.email_field_name: email})
user = User.objects.get(**{self.email_field_name: email2})
self.assertNotEqual(user, None)
self.assertNotEqual(getattr(user, self.email_field_name), None)
self.assertEqual(getattr(user, self.email_verified_field_name), True)
Expand Down Expand Up @@ -98,6 +99,7 @@ def setUp(self):

def test_mobile_unverified_to_verified_and_back(self):
mobile = '+15551234567'
mobile2 = '+15557654321'
data = {'mobile': mobile}

# create a new user
Expand Down Expand Up @@ -128,11 +130,12 @@ def test_mobile_unverified_to_verified_and_back(self):
self.assertEqual(getattr(user, self.mobile_verified_field_name), False)

# Verify
callback_response = self.client.post(self.verify_url)
self.assertEqual(callback_response.status_code, status.HTTP_200_OK)
self.client.force_login(user)
verify_response = self.client.post(self.verify_url)
self.assertEqual(verify_response.status_code, status.HTTP_200_OK)

# Refresh User
user = User.objects.get(**{self.mobile_field_name: mobile})
user = User.objects.get(**{self.mobile_field_name: mobile2})
self.assertNotEqual(user, None)
self.assertNotEqual(getattr(user, self.mobile_field_name), None)
self.assertEqual(getattr(user, self.mobile_verified_field_name), False)
Expand All @@ -143,7 +146,7 @@ def test_mobile_unverified_to_verified_and_back(self):
self.assertEqual(verify_callback_response.status_code, status.HTTP_200_OK)

# Refresh User
user = User.objects.get(**{self.mobile_field_name: mobile})
user = User.objects.get(**{self.mobile_field_name: mobile2})
self.assertNotEqual(user, None)
self.assertNotEqual(getattr(user, self.mobile_field_name), None)
self.assertEqual(getattr(user, self.mobile_verified_field_name), True)
Expand Down

0 comments on commit 36d481e

Please sign in to comment.