diff --git a/jose/jwt.py b/jose/jwt.py index 1da109ca..156063bf 100644 --- a/jose/jwt.py +++ b/jose/jwt.py @@ -56,7 +56,7 @@ def encode(claims, key, algorithm=None, headers=None): return jws.sign(claims, key, headers=headers) -def decode(token, key, algorithms=None, options=None, audience=None, issuer=None): +def decode(token, key, algorithms=None, options=None, audience=None, issuer=None, subject=None): """Verifies a JWT string's signature and validates reserved claims. Args: @@ -69,6 +69,9 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None issuer (str): The issuer of the token. If the "iss" claim is included in the claim set, then the issuer must be included and must equal the provided claim. + subject (str): The subject of the token. If the "sub" claim is + included in the claim set, then the subject must be included and must equal + the provided claim. options (dict): A dictionary of options for skipping validation steps. defaults = { @@ -127,7 +130,7 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None if not isinstance(claims, Mapping): raise JWTError('Invalid payload string: must be a json object') - _validate_claims(claims, audience=audience, issuer=issuer, options=defaults) + _validate_claims(claims, audience=audience, issuer=issuer, subject=subject, options=defaults) return claims @@ -333,7 +336,7 @@ def _validate_iss(claims, issuer=None): raise JWTClaimsError('Invalid issuer') -def _validate_sub(claims): +def _validate_sub(claims, subject=None): """Validates that the 'sub' claim is valid. The "sub" (subject) claim identifies the principal that is the @@ -346,6 +349,7 @@ def _validate_sub(claims): Args: claims (dict): The claims dictionary to validate. + subject (str): The subject of the token. """ if 'sub' not in claims: @@ -354,6 +358,9 @@ def _validate_sub(claims): if not isinstance(claims['sub'], string_types): raise JWTClaimsError('Subject must be a string.') + if subject is not None: + if claims.get('sub') != subject: + raise JWTClaimsError('Invalid subject') def _validate_jti(claims): """Validates that the 'jti' claim is valid. @@ -377,7 +384,7 @@ def _validate_jti(claims): raise JWTClaimsError('JWT ID must be a string.') -def _validate_claims(claims, audience=None, issuer=None, options=None): +def _validate_claims(claims, audience=None, issuer=None, subject=None, options=None): leeway = options.get('leeway', 0) @@ -403,7 +410,7 @@ def _validate_claims(claims, audience=None, issuer=None, options=None): _validate_iss(claims, issuer=issuer) if options.get('verify_sub'): - _validate_sub(claims) + _validate_sub(claims, subject=subject) if options.get('verify_jti'): _validate_jti(claims) diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 5cd7a86f..8b4a8f67 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -380,6 +380,29 @@ def test_sub_invalid(self, key): with pytest.raises(JWTError): jwt.decode(token, key) + def test_sub_correct(self, key): + + sub = 'subject' + + claims = { + 'sub': sub + } + + token = jwt.encode(claims, key) + jwt.decode(token, key, subject=sub) + + def test_sub_incorrect(self, key): + + sub = 'subject' + + claims = { + 'sub': sub + } + + token = jwt.encode(claims, key) + with pytest.raises(JWTError): + jwt.decode(token, key, subject='another') + def test_jti_string(self, key): jti = 'JWT ID'