Skip to content

Commit

Permalink
Merge pull request mpdavis#30 from bjmc/at_hash
Browse files Browse the repository at this point in the history
Adds support for at_hash verification
  • Loading branch information
Michael Davis committed Jul 22, 2016
2 parents b7d3871 + 95fb84a commit cc402f8
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 9 deletions.
14 changes: 13 additions & 1 deletion jose/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import hashlib

class ALGORITHMS(object):
NONE = 'none'
Expand All @@ -19,3 +19,15 @@ class ALGORITHMS(object):
SUPPORTED = HMAC + RSA + EC

ALL = SUPPORTED + (NONE, )

HASHES = {
HS256: hashlib.sha256,
HS384: hashlib.sha384,
HS512: hashlib.sha512,
RS256: hashlib.sha256,
RS384: hashlib.sha384,
RS512: hashlib.sha512,
ES256: hashlib.sha256,
ES384: hashlib.sha384,
ES512: hashlib.sha512,
}
69 changes: 61 additions & 8 deletions jose/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from .exceptions import JWTClaimsError
from .exceptions import JWTError
from .exceptions import ExpiredSignatureError
from .utils import timedelta_total_seconds
from .constants import ALGORITHMS
from .utils import timedelta_total_seconds, calculate_at_hash


def encode(claims, key, algorithm=None, headers=None):
def encode(claims, key, algorithm=ALGORITHMS.HS256, headers=None, access_token=None):
"""Encodes a claims set and returns a JWT string.
JWTs are JWS signed objects with a few reserved claims.
Expand All @@ -30,6 +31,9 @@ def encode(claims, key, algorithm=None, headers=None):
headers (dict, optional): A set of headers that will be added to
the default headers. Any headers that are added as additional
headers will override the default headers.
access_token (str, optional): If present, the 'at_hash' claim will
be calculated and added to the claims present in the 'claims'
parameter.
Returns:
str: The string representation of the header, claims, and signature.
Expand All @@ -50,13 +54,15 @@ def encode(claims, key, algorithm=None, headers=None):
if isinstance(claims.get(time_claim), datetime):
claims[time_claim] = timegm(claims[time_claim].utctimetuple())

if algorithm:
return jws.sign(claims, key, headers=headers, algorithm=algorithm)
if access_token:
claims['at_hash'] = calculate_at_hash(access_token,
ALGORITHMS.HASHES[algorithm])

return jws.sign(claims, key, headers=headers)
return jws.sign(claims, key, headers=headers, algorithm=algorithm)


def decode(token, key, algorithms=None, options=None, audience=None, issuer=None, subject=None):
def decode(token, key, algorithms=None, options=None, audience=None,
issuer=None, subject=None, access_token=None):
"""Verifies a JWT string's signature and validates reserved claims.
Args:
Expand All @@ -72,6 +78,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
subject (str): The subject of the token. If the "sub" claim is
included in the claim set, then the subject must be included and must equal
the provided claim.
access_token (str): An access token returned alongside the id_token during
the authorization grant flow. If the "at_hash" claim is included in the
claim set, then the access_token must be included, and it must match
the "at_hash" claim.
options (dict): A dictionary of options for skipping validation steps.
defaults = {
Expand Down Expand Up @@ -109,6 +119,7 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
'verify_iss': True,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'leeway': 0,
}

Expand All @@ -122,6 +133,9 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
except JWSError as e:
raise JWTError(e)

# Needed for at_hash verification
algorithm = jws.get_unverified_header(token)['alg']

try:
claims = json.loads(payload.decode('utf-8'))
except ValueError as e:
Expand All @@ -130,7 +144,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
if not isinstance(claims, Mapping):
raise JWTError('Invalid payload string: must be a json object')

_validate_claims(claims, audience=audience, issuer=issuer, subject=subject, options=defaults)
_validate_claims(claims, audience=audience, issuer=issuer,
subject=subject, algorithm=algorithm,
access_token=access_token,
options=defaults)

return claims

Expand Down Expand Up @@ -384,7 +401,40 @@ def _validate_jti(claims):
raise JWTClaimsError('JWT ID must be a string.')


def _validate_claims(claims, audience=None, issuer=None, subject=None, options=None):
def _validate_at_hash(claims, access_token, algorithm):
"""
Validates that the 'at_hash' parameter included in the claims matches
with the access_token returned alongside the id token as part of
the authorization_code flow.
Args:
claims (dict): The claims dictionary to validate.
access_token (str): The access token returned by the OpenID Provider.
algorithm (str): The algorithm used to sign the JWT, as specified by
the token headers.
"""
if 'at_hash' not in claims and not access_token:
return
elif 'at_hash' in claims and not access_token:
msg = 'No access_token provided to compare against at_hash claim.'
raise JWTClaimsError(msg)
elif access_token and 'at_hash' not in claims:
msg = 'at_hash claim missing from token.'
raise JWTClaimsError(msg)

try:
expected_hash = calculate_at_hash(access_token,
ALGORITHMS.HASHES[algorithm])
except (TypeError, ValueError):
msg = 'Unable to calculate at_hash to verify against token claims.'
raise JWTClaimsError(msg)

if claims['at_hash'] != expected_hash:
raise JWTClaimsError('at_hash claim does not match access_token.')


def _validate_claims(claims, audience=None, issuer=None, subject=None,
algorithm=None, access_token=None, options=None):

leeway = options.get('leeway', 0)

Expand Down Expand Up @@ -414,3 +464,6 @@ def _validate_claims(claims, audience=None, issuer=None, subject=None, options=N

if options.get('verify_jti'):
_validate_jti(claims)

if options.get('verify_at_hash'):
_validate_at_hash(claims, access_token, algorithm)
23 changes: 23 additions & 0 deletions jose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
import base64


def calculate_at_hash(access_token, hash_alg):
"""Helper method for calculating an access token
hash, as described in http:https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
Its value is the base64url encoding of the left-most half of the hash of the octets
of the ASCII representation of the access_token value, where the hash algorithm
used is the hash algorithm used in the alg Header Parameter of the ID Token's JOSE
Header. For instance, if the alg is RS256, hash the access_token value with SHA-256,
then take the left-most 128 bits and base64url encode them. The at_hash value is a
case sensitive string.
Args:
access_token (str): An access token string.
hash_alg (callable): A callable returning a hash object, e.g. hashlib.sha256
"""
hash_digest = hash_alg(access_token.encode('utf-8')).digest()
cut_at = int(len(hash_digest) / 2)
truncated = hash_digest[:cut_at]
at_hash = base64url_encode(truncated)
return at_hash.decode('utf-8')


def base64url_decode(input):
"""Helper method to base64url_decode a string.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,32 @@ def test_jti_invalid(self, key):
with pytest.raises(JWTError):
jwt.decode(token, key)

def test_at_hash(self, claims, key):
access_token = '<ACCESS_TOKEN>'
token = jwt.encode(claims, key, access_token=access_token)
payload = jwt.decode(token, key, access_token=access_token)
assert 'at_hash' in payload

def test_at_hash_invalid(self, claims, key):
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
with pytest.raises(JWTError):
jwt.decode(token, key, access_token='<OTHER_TOKEN>')

def test_at_hash_missing_access_token(self, claims, key):
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
with pytest.raises(JWTError):
jwt.decode(token, key)

def test_at_hash_missing_claim(self, claims, key):
token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key, access_token='<ACCESS_TOKEN>')

def test_at_hash_unable_to_calculate(self, claims, key):
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
with pytest.raises(JWTError):
jwt.decode(token, key, access_token='\xe2')

def test_unverified_claims_string(self):
token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.aW52YWxpZCBjbGFpbQ.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck'
with pytest.raises(JWTError):
Expand Down

0 comments on commit cc402f8

Please sign in to comment.