From 06dfbcafec3704e4012e60f0eec0c6f81398d6d0 Mon Sep 17 00:00:00 2001 From: Sean Date: Fri, 25 Nov 2022 12:48:31 -0500 Subject: [PATCH] Allow creation of GivenKeys with a required algorithm To comply with RFC 8725 Section 3.1, we allow specifying an alg for GivenKeys. It behaves equivalently to the 'alg' parameter parsed from JWKS JSON. --- given.go | 19 ++++++++++++++++--- given_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++ jwks.go | 10 ++++++++++ jwks_test.go | 32 +++++++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 3 deletions(-) diff --git a/given.go b/given.go index 6498431..facf093 100644 --- a/given.go +++ b/given.go @@ -8,7 +8,8 @@ import ( // GivenKey represents a cryptographic key that resides in a JWKS. In conjuncture with Options. type GivenKey struct { - inter interface{} + inter interface{} + algorithm string } // NewGiven creates a JWKS from a map of given keys. @@ -16,7 +17,7 @@ func NewGiven(givenKeys map[string]GivenKey) (jwks *JWKS) { keys := make(map[string]parsedJWK) for kid, given := range givenKeys { - keys[kid] = parsedJWK{public: given.inter} + keys[kid] = parsedJWK{public: given.inter, algorithm: given.algorithm} } return &JWKS{ @@ -25,7 +26,7 @@ func NewGiven(givenKeys map[string]GivenKey) (jwks *JWKS) { } // NewGivenCustom creates a new GivenKey given an untyped variable. The key argument is expected to be a supported -// by the jwt package used. +// by the jwt package used. To specify a required algorithm use NewGivenCustomAlg. // // See the https://pkg.go.dev/github.com/golang-jwt/jwt/v4#RegisterSigningMethod function for registering an unsupported // signing method. @@ -35,6 +36,18 @@ func NewGivenCustom(key interface{}) (givenKey GivenKey) { } } +// NewGivenCustomAlg creates a new GivenKey given an untyped variable and an algorithm. The key argument is expected to +// be a type supported by the jwt package used. The alg argument will be validated against the alg header of tokens. +// +// See the https://pkg.go.dev/github.com/golang-jwt/jwt/v4#RegisterSigningMethod function for registering an unsupported +// signing method. +func NewGivenCustomAlg(key interface{}, alg string) (givenKey GivenKey) { + return GivenKey{ + inter: key, + algorithm: alg, + } +} + // NewGivenECDSA creates a new GivenKey given an ECDSA public key. func NewGivenECDSA(key *ecdsa.PublicKey) (givenKey GivenKey) { return GivenKey{ diff --git a/given_test.go b/given_test.go index f25d533..7f8273b 100644 --- a/given_test.go +++ b/given_test.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" + "errors" "fmt" "testing" @@ -45,6 +46,57 @@ func TestNewGivenCustom(t *testing.T) { signParseValidate(t, token, key, jwks) } +// TestNewGivenCustomAlg tests that a custom jwt.SigningMethod can be used to create a JWKS and a proper jwt.Keyfunc. +func TestNewGivenCustomAlg(t *testing.T) { + jwt.RegisterSigningMethod(method.CustomAlg, func() jwt.SigningMethod { + return method.EmptyCustom{} + }) + + const key = "test-key" + givenKeys := make(map[string]keyfunc.GivenKey) + givenKeys[testKID] = keyfunc.NewGivenCustomAlg(key, method.CustomAlg) + + jwks := keyfunc.NewGiven(givenKeys) + + token := jwt.New(method.EmptyCustom{}) + token.Header[algAttribute] = method.CustomAlg + token.Header[kidAttribute] = testKID + + signParseValidate(t, token, key, jwks) +} + +// TestNewGivenCustomAlg_NegativeCase tests that a custom jwt.SigningMethod can be used to create +// a JWKS and a proper jwt.Keyfunc and that a token with a non-matching algorithm will be rejected. +func TestNewGivenCustomAlg_NegativeCase(t *testing.T) { + jwt.RegisterSigningMethod(method.CustomAlg, func() jwt.SigningMethod { + return method.EmptyCustom{} + }) + + const key = jwt.UnsafeAllowNoneSignatureType // So golang-jwt isn't the one blocking this test + givenKeys := make(map[string]keyfunc.GivenKey) + givenKeys[testKID] = keyfunc.NewGivenCustomAlg(key, method.CustomAlg) + + jwks := keyfunc.NewGiven(givenKeys) + + token := jwt.New(method.EmptyCustom{}) + token.Header[algAttribute] = jwt.SigningMethodNone.Alg() + token.Header[kidAttribute] = testKID + + jwtB64, err := token.SignedString(key) + if err != nil { + t.Fatalf(logFmt, "Failed to sign the JWT.", err) + } + + parsed, err := jwt.NewParser().Parse(jwtB64, jwks.Keyfunc) + if !errors.Is(err, keyfunc.ErrJWKAlgMismatch) { + t.Fatalf("Failed to return ErrJWKAlgMismatch") + } + + if parsed.Valid { + t.Fatalf("The JWT was valid.") + } +} + // TestNewGivenKeyECDSA tests that a generated ECDSA key can be added to the JWKS and create a proper jwt.Keyfunc. func TestNewGivenKeyECDSA(t *testing.T) { givenKeys := make(map[string]keyfunc.GivenKey) diff --git a/jwks.go b/jwks.go index a13bb38..b5a0d37 100644 --- a/jwks.go +++ b/jwks.go @@ -160,6 +160,16 @@ func (j *JWKS) KIDs() (kids []string) { return kids } +// KeyAlg returns the algorithm (`alg`) for the key identified by Key ID (`kid`). +func (j *JWKS) KeyAlg(kid string) string { + j.mux.RLock() + defer j.mux.RUnlock() + if pubKey, ok := j.keys[kid]; ok { + return pubKey.algorithm + } + return "" +} + // Len returns the number of keys in the JWKS. func (j *JWKS) Len() int { j.mux.RLock() diff --git a/jwks_test.go b/jwks_test.go index a0b6b77..073ac2a 100644 --- a/jwks_test.go +++ b/jwks_test.go @@ -296,6 +296,38 @@ func TestJWKS_Len(t *testing.T) { } } +// TestJWKS_KeyAlg confirms the JWKS.Len returns the algorithm for keys by kid. +func TestJWKS_KeyAlg(t *testing.T) { + jwks, err := keyfunc.NewJSON([]byte(jwksJSON)) + if err != nil { + t.Fatalf(logFmt, "Failed to create a JWKS from JSON.", err) + } + + expectedAlgs := map[string]string{ + "zXew0UJ1h6Q4CCcd_9wxMzvcp5cEBifH0KWrCz2Kyxc": "PS256", + "ebJxnm9B3QDBljB5XJWEu72qx6BawDaMAhwz4aKPkQ0": "ES512", + "TVAAet63O3xy_KK6_bxVIu7Ra3_z1wlB543Fbwi5VaU": "ES384", + "arlUxX4hh56rNO-XdIPhDT7bqBMqcBwNQuP_TnZJNGs": "RS512", + "tW6ae7TomE6_2jooM-sf9N_6lWg7HNtaQXrDsElBzM4": "PS512", + "Lx1FmayP2YBtxaqS1SKJRJGiXRKnw2ov5WmYIMG-BLE": "PS384", + "gnmAfvmlsi3kKH3VlM1AJ85P2hekQ8ON_XvJqs3xPD8": "RS384", + "CGt0ZWS4Lc5faiKSdi0tU0fjCAdvGROQRGU9iR7tV0A": "ES256", + "C65q0EKQyhpd1m4fr7SKO2He_nAxgCtAdws64d2BLt8": "RS256", + "Q56A": "", + "hmac": "", + "WW91IGdldCBhIGdvbGQgc3RhciDwn4yfCg": "", + } + + for kid, expectedAlg := range expectedAlgs { + t.Run(kid, func(t *testing.T) { + actualAlg := jwks.KeyAlg(kid) + if actualAlg != expectedAlg { + t.Errorf("Unexpected alg for key %v.\n Expected: %v\n Actual: %v\n", kid, expectedAlg, actualAlg) + } + }) + } +} + // TestRateLimit performs a test to confirm the rate limiter works as expected. func TestRateLimit(t *testing.T) { tempDir, err := os.MkdirTemp("", "*")