Skip to content

Commit

Permalink
Adds support for at_hash verification
Browse files Browse the repository at this point in the history
  • Loading branch information
bjmc committed Jul 18, 2016
1 parent 048377d commit 0acb22d
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 9 deletions.
14 changes: 13 additions & 1 deletion jose/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import hashlib

class ALGORITHMS(object):
NONE = 'none'
Expand All @@ -19,3 +19,15 @@ class ALGORITHMS(object):
SUPPORTED = HMAC + RSA + EC

ALL = SUPPORTED + (NONE, )

HASHES = {
HS256: hashlib.sha256,
HS384: hashlib.sha384,
HS512: hashlib.sha512,
RS256: hashlib.sha256,
RS384: hashlib.sha384,
RS512: hashlib.sha512,
ES256: hashlib.sha256,
ES384: hashlib.sha384,
ES512: hashlib.sha512,
}
69 changes: 61 additions & 8 deletions jose/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from .exceptions import JWTClaimsError
from .exceptions import JWTError
from .exceptions import ExpiredSignatureError
from .utils import timedelta_total_seconds
from .constants import ALGORITHMS
from .utils import timedelta_total_seconds, calculate_at_hash


def encode(claims, key, algorithm=None, headers=None):
def encode(claims, key, algorithm=ALGORITHMS.HS256, headers=None, access_token=None):
"""Encodes a claims set and returns a JWT string.
JWTs are JWS signed objects with a few reserved claims.
Expand All @@ -30,6 +31,9 @@ def encode(claims, key, algorithm=None, headers=None):
headers (dict, optional): A set of headers that will be added to
the default headers. Any headers that are added as additional
headers will override the default headers.
access_token (str, optional): If present, the 'at_hash' claim will
be calculated and added to the claims present in the 'claims'
parameter.
Returns:
str: The string representation of the header, claims, and signature.
Expand All @@ -50,13 +54,15 @@ def encode(claims, key, algorithm=None, headers=None):
if isinstance(claims.get(time_claim), datetime):
claims[time_claim] = timegm(claims[time_claim].utctimetuple())

if algorithm:
return jws.sign(claims, key, headers=headers, algorithm=algorithm)
if access_token:
claims['at_hash'] = calculate_at_hash(access_token,
ALGORITHMS.HASHES[algorithm])

return jws.sign(claims, key, headers=headers)
return jws.sign(claims, key, headers=headers, algorithm=algorithm)


def decode(token, key, algorithms=None, options=None, audience=None, issuer=None, subject=None):
def decode(token, key, algorithms=None, options=None, audience=None,
issuer=None, subject=None, access_token=None):
"""Verifies a JWT string's signature and validates reserved claims.
Args:
Expand All @@ -72,6 +78,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
subject (str): The subject of the token. If the "sub" claim is
included in the claim set, then the subject must be included and must equal
the provided claim.
access_token (str): An access token returned alongside the id_token during
the authorization grant flow. If the "at_hash" claim is included in the
claim set, then the access_token must be included, and it must match
the "at_hash" claim.
options (dict): A dictionary of options for skipping validation steps.
defaults = {
Expand Down Expand Up @@ -109,6 +119,7 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
'verify_iss': True,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'leeway': 0,
}

Expand All @@ -122,6 +133,9 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
except JWSError as e:
raise JWTError(e)

# Needed for at_hash verification
algorithm = jws.get_unverified_header(token)['alg']

try:
claims = json.loads(payload.decode('utf-8'))
except ValueError as e:
Expand All @@ -130,7 +144,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
if not isinstance(claims, Mapping):
raise JWTError('Invalid payload string: must be a json object')

_validate_claims(claims, audience=audience, issuer=issuer, subject=subject, options=defaults)
_validate_claims(claims, audience=audience, issuer=issuer,
subject=subject, algorithm=algorithm,
access_token=access_token,
options=defaults)

return claims

Expand Down Expand Up @@ -384,7 +401,40 @@ def _validate_jti(claims):
raise JWTClaimsError('JWT ID must be a string.')


def _validate_claims(claims, audience=None, issuer=None, subject=None, options=None):
def _validate_at_hash(claims, access_token, algorithm):
"""
Validates that the 'at_hash' parameter included in the claims matches
with the access_token returned alongside the id token as part of
the authorization_code flow.
Args:
claims (dict): The claims dictionary to validate.
access_token (str): The access token returned by the OpenID Provider.
algorithm (str): The algorithm used to sign the JWT, as specified by
the token headers.
"""
if 'at_hash' not in claims and not access_token:
return
elif 'at_hash' in claims and not access_token:
msg = 'No access_token provided to compare against at_hash claim.'
raise JWTClaimsError(msg)
elif access_token and 'at_hash' not in claims:
msg = 'at_hash claim missing from token.'
raise JWTClaimsError(msg)

try:
expected_hash = calculate_at_hash(access_token,
ALGORITHMS.HASHES[algorithm])
except TypeError:
msg = 'Unable to calculate at_hash to verify against token claims.'
raise JWTClaimsError(msg)

if claims['at_hash'] != expected_hash:
raise JWTClaimsError('at_hash claim does not match access_token.')


def _validate_claims(claims, audience=None, issuer=None, subject=None,
algorithm=None, access_token=None, options=None):

leeway = options.get('leeway', 0)

Expand Down Expand Up @@ -414,3 +464,6 @@ def _validate_claims(claims, audience=None, issuer=None, subject=None, options=N

if options.get('verify_jti'):
_validate_jti(claims)

if options.get('verify_at_hash'):
_validate_at_hash(claims, access_token, algorithm)
23 changes: 23 additions & 0 deletions jose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
import base64


def calculate_at_hash(access_token, hash_alg):
"""Helper method for calculating an access token
hash, as described in https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
Its value is the base64url encoding of the left-most half of the hash of the octets
of the ASCII representation of the access_token value, where the hash algorithm
used is the hash algorithm used in the alg Header Parameter of the ID Token's JOSE
Header. For instance, if the alg is RS256, hash the access_token value with SHA-256,
then take the left-most 128 bits and base64url encode them. The at_hash value is a
case sensitive string.
Args:
access_token (str): An access token string.
hash_alg (callable): A callable returning a hash object, e.g. hashlib.sha256
"""
hash_digest = hash_alg(access_token.encode('utf-8')).digest()
cut_at = int(len(hash_digest) / 2)
truncated = hash_digest[:cut_at]
at_hash = base64url_encode(truncated)
return at_hash.decode('utf-8')


def base64url_decode(input):
"""Helper method to base64url_decode a string.
Expand Down

0 comments on commit 0acb22d

Please sign in to comment.