diff --git a/README.md b/README.md index 2fb1a21..4a95ebe 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ jwksURL := os.Getenv("JWKS_URL") // Confirm the environment variable is not empty. if jwksURL == "" { -log.Fatalln("JWKS_URL environment variable must be populated.") + log.Fatalln("JWKS_URL environment variable must be populated.") } ``` @@ -81,7 +81,7 @@ Via HTTP: // Create the JWKS from the resource at the given URL. jwks, err := keyfunc.Get(jwksURL, keyfunc.Options{}) // See recommended options in the examples directory. if err != nil { -log.Fatalf("Failed to get the JWKS from the given URL.\nError: %s", err) + log.Fatalf("Failed to get the JWKS from the given URL.\nError: %s", err) } ``` Via JSON: @@ -92,7 +92,7 @@ var jwksJSON = json.RawMessage(`{"keys":[{"kid":"zXew0UJ1h6Q4CCcd_9wxMzvcp5cEBif // Create the JWKS from the resource at the given URL. jwks, err := keyfunc.NewJSON(jwksJSON) if err != nil { -log.Fatalf("Failed to create JWKS from JSON.\nError: %s", err) + log.Fatalf("Failed to create JWKS from JSON.\nError: %s", err) } ``` Via a given key: @@ -103,7 +103,7 @@ uniqueKeyID := "myKeyID" // Create the JWKS from the HMAC key. jwks := keyfunc.NewGiven(map[string]keyfunc.GivenKey{ -uniqueKeyID: keyfunc.NewGivenHMAC(key), + uniqueKeyID: keyfunc.NewGivenHMAC(key), }) ``` @@ -117,7 +117,7 @@ features mentioned at the bottom of this `README.md`. // Parse the JWT. token, err := jwt.Parse(jwtB64, jwks.Keyfunc) if err != nil { -return nil, fmt.Errorf("failed to parse token: %w", err) + return nil, fmt.Errorf("failed to parse token: %w", err) } ``` @@ -180,6 +180,11 @@ base64url the same as RFC 7515 Section 2: However, this package will remove trailing padding on base64url encoded keys to account for improper implementations of JWKS. +This package will check the `alg` in each JWK. If present, it will confirm the same `alg` is in a given JWT's header +before returning the key for signature verification. If the `alg`s do not match, `keyfunc.ErrJWKAlgMismatch` will +prevent the key being used for signature verification. If the `alg` is not present in the JWK, this check will not +occur. + ## References This project was built and tested using various RFCs and services. The services are listed below: * [Keycloak](https://www.keycloak.org/) diff --git a/alg_test.go b/alg_test.go new file mode 100644 index 0000000..9fafae1 --- /dev/null +++ b/alg_test.go @@ -0,0 +1,25 @@ +package keyfunc_test + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/golang-jwt/jwt/v4" + + "github.com/MicahParks/keyfunc" +) + +func TestAlgMismatch(t *testing.T) { + const jwtB64 = "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCIsImtpZCI6IkM2NXEwRUtReWhwZDFtNGZyN1NLTzJIZV9uQXhnQ3RBZHdzNjRkMkJMdDgifQ.eyJleHAiOjE2MTU0MDcwMjYsImlhdCI6MTYxNTQwNjk2NiwianRpIjoiMzg1NjE4ODItOTA5MS00ODY3LTkzYmYtMmE3YmU4NTc3YmZiIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwL2F1dGgvcmVhbG1zL21hc3RlciIsImF1ZCI6ImFjY291bnQiLCJzdWIiOiJhZDEyOGRmMS0xMTQwLTRlNGMtYjA5Ny1hY2RjZTcwNWJkOWIiLCJ0eXAiOiJCZWFyZXIiLCJhenAiOiJ0b2tlbmRlbG1lIiwiYWNyIjoiMSIsInJlYWxtX2FjY2VzcyI6eyJyb2xlcyI6WyJvZmZsaW5lX2FjY2VzcyIsInVtYV9hdXRob3JpemF0aW9uIl19LCJyZXNvdXJjZV9hY2Nlc3MiOnsiYWNjb3VudCI6eyJyb2xlcyI6WyJtYW5hZ2UtYWNjb3VudCIsIm1hbmFnZS1hY2NvdW50LWxpbmtzIiwidmlldy1wcm9maWxlIl19fSwic2NvcGUiOiJlbWFpbCBwcm9maWxlIiwiY2xpZW50SG9zdCI6IjE3Mi4yMC4wLjEiLCJjbGllbnRJZCI6InRva2VuZGVsbWUiLCJlbWFpbF92ZXJpZmllZCI6ZmFsc2UsInByZWZlcnJlZF91c2VybmFtZSI6InNlcnZpY2UtYWNjb3VudC10b2tlbmRlbG1lIiwiY2xpZW50QWRkcmVzcyI6IjE3Mi4yMC4wLjEifQ.Cmgz3aC_b_kpOmGM-_nRisgQul0d9Jg7BpMLe5F_fdryRhwhW5fQBZtz6FipQ0Tc4jggI6L3Dx1jS2kn823aWCR0x-OAFCawIXnwgAKuM1m2NL7Y6LKC07nytdB_qU4GknAl3jEG-tZIJBHQwYP-K6QKmAT9CdF1ZPbc9u8RgRCPN8UziYcOpvStiG829BO7cTzCt7tp5dJhem8_CnRWBKzelP1fs_z4fAQtW2sgyhX9SUYb5WON-4zrn4i01FlYUwZV-AC83zP6BuHIiy3XpAuTiTp2BjZ-1nzCLWBRpIm_lOObFeo-3AQqWPxzLVAmTFQMKReUF9T8ehL2Osr1XQ" + + jwks, err := keyfunc.NewJSON(json.RawMessage(jwksJSON)) + if err != nil { + t.Fatalf("Failed to create JWKS from JSON: %v", err) + } + + _, err = jwt.Parse(jwtB64, jwks.Keyfunc) + if !errors.Is(err, keyfunc.ErrJWKAlgMismatch) { + t.Fatalf("Expected ErrJWKAlgMismatch, got %v", err) + } +} diff --git a/ecdsa.go b/ecdsa.go index 1105623..ca0566d 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -3,6 +3,7 @@ package keyfunc import ( "crypto/ecdsa" "crypto/elliptic" + "errors" "fmt" "math/big" ) @@ -21,6 +22,11 @@ const ( p521 = "P-521" ) +var ( + // ErrECDSACurve indicates an error with the ECDSA curve. + ErrECDSACurve = errors.New("invalid ECDSA curve") +) + // ECDSA parses a jsonWebKey and turns it into an ECDSA public key. func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) { if j.X == "" || j.Y == "" || j.Curve == "" { @@ -48,6 +54,8 @@ func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) { publicKey.Curve = elliptic.P384() case p521: publicKey.Curve = elliptic.P521() + default: + return nil, fmt.Errorf("%w: unknown curve: %s", ErrECDSACurve, j.Curve) } // Turn the X coordinate into *big.Int. diff --git a/ecdsa_test.go b/ecdsa_test.go new file mode 100644 index 0000000..c71cacb --- /dev/null +++ b/ecdsa_test.go @@ -0,0 +1,30 @@ +package keyfunc_test + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/golang-jwt/jwt/v4" + + "github.com/MicahParks/keyfunc" +) + +func TestBadCurve(t *testing.T) { + const ( + badJWKS = `{"keys":[{"kty":"EC","crv":"BAD","x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4","y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM","use":"sig","kid":"1"}]}` + someJWT = `eyJhbGciOiJFUzI1NiIsImtpZCI6IjEiLCJ0eXAiOiJKV1QifQ.e30.Q1EeyWUv6XEA0gMLwTFoNhx7Hq1MbVwjI2k9FZPSa-myKW1wYn1X6rHtRyuV-2MEzvimCskFD-afL7UzvdWBQg` + ) + + jwks, err := keyfunc.NewJSON(json.RawMessage(badJWKS)) + if err != nil { + t.Fatalf("Failed to create JWKS from JSON: %v", err) + } + + // The number of parsed keys should be 0. + + _, err = jwt.Parse(someJWT, jwks.Keyfunc) + if !errors.Is(err, keyfunc.ErrKIDNotFound) { + t.Fatalf("Expected ErrKIDNotFound, got %v", err) + } +} diff --git a/examples/custom/main.go b/examples/custom/main.go index 19173e2..1f09865 100644 --- a/examples/custom/main.go +++ b/examples/custom/main.go @@ -15,7 +15,7 @@ func main() { const exampleKID = "exampleKeyID" // Register the custom signing method. - jwt.RegisterSigningMethod(method.CustomAlg, func() jwt.SigningMethod { + jwt.RegisterSigningMethod(method.CustomAlgHeader, func() jwt.SigningMethod { return method.EmptyCustom{} }) @@ -29,7 +29,9 @@ func main() { // Create the JWKS from the given signing method's key. jwks := keyfunc.NewGiven(map[string]keyfunc.GivenKey{ - exampleKID: keyfunc.NewGivenCustom(key), + exampleKID: keyfunc.NewGivenCustomWithOptions(key, keyfunc.GivenKeyOptions{ + Algorithm: method.CustomAlgHeader, + }), }) // Parse the token. diff --git a/examples/custom/method/method.go b/examples/custom/method/method.go index 4d4ddd9..7e03ab5 100644 --- a/examples/custom/method/method.go +++ b/examples/custom/method/method.go @@ -1,7 +1,7 @@ package method -// CustomAlg is the `alg` JSON attribute's value for the example custom jwt.SigningMethod. -const CustomAlg = "customalg" +// CustomAlgHeader is the `alg` JSON attribute's value for the example custom jwt.SigningMethod. +const CustomAlgHeader = "customalg" // EmptyCustom implements the jwt.SigningMethod interface. It will not sign or verify anything. type EmptyCustom struct{} @@ -13,10 +13,10 @@ func (e EmptyCustom) Verify(_, _ string, _ interface{}) error { // Sign helps implement the jwt.SigningMethod interface. It does not sign anything. func (e EmptyCustom) Sign(_ string, _ interface{}) (string, error) { - return CustomAlg, nil + return CustomAlgHeader, nil } // Alg helps implement the jwt.SigningMethod. It returns the `alg` JSON attribute for JWTs signed with this method. func (e EmptyCustom) Alg() string { - return CustomAlg + return CustomAlgHeader } diff --git a/examples/given/main.go b/examples/given/main.go index 33db993..2787b0c 100644 --- a/examples/given/main.go +++ b/examples/given/main.go @@ -23,7 +23,9 @@ func main() { hmacSecret := []byte("example secret") const givenKID = "givenKID" givenKeys := map[string]keyfunc.GivenKey{ - givenKID: keyfunc.NewGivenHMAC(hmacSecret), + givenKID: keyfunc.NewGivenHMACCustomWithOptions(hmacSecret, keyfunc.GivenKeyOptions{ + Algorithm: jwt.SigningMethodHS256.Alg(), + }), } // Create the keyfunc options. Use an error handler that logs. Refresh the JWKS when a JWT signed by an unknown KID diff --git a/examples/hmac/main.go b/examples/hmac/main.go index e3387b4..79eba13 100644 --- a/examples/hmac/main.go +++ b/examples/hmac/main.go @@ -23,7 +23,9 @@ func main() { // Create the JWKS from the HMAC key. jwks := keyfunc.NewGiven(map[string]keyfunc.GivenKey{ - exampleKID: keyfunc.NewGivenHMAC(key), + exampleKID: keyfunc.NewGivenHMACCustomWithOptions(key, keyfunc.GivenKeyOptions{ + Algorithm: jwt.SigningMethodHS512.Alg(), + }), }) // Parse the token. diff --git a/given.go b/given.go index 6498431..68c8abd 100644 --- a/given.go +++ b/given.go @@ -4,11 +4,26 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "encoding/json" ) // GivenKey represents a cryptographic key that resides in a JWKS. In conjuncture with Options. type GivenKey struct { - inter interface{} + algorithm string + inter interface{} +} + +// GivenKeyOptions represents the configuration options for a GivenKey. +type GivenKeyOptions struct { + // Algorithm is the given key's signing algorithm. Its value will be compared to unverified tokens' "alg" header. + // + // See RFC 8725 Section 3.1 for details. + // https://www.rfc-editor.org/rfc/rfc8725#section-3.1 + // + // For a list of possible values, please see: + // https://www.rfc-editor.org/rfc/rfc7518#section-3.1 + // https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms + Algorithm string } // NewGiven creates a JWKS from a map of given keys. @@ -16,7 +31,10 @@ func NewGiven(givenKeys map[string]GivenKey) (jwks *JWKS) { keys := make(map[string]parsedJWK) for kid, given := range givenKeys { - keys[kid] = parsedJWK{public: given.inter} + keys[kid] = parsedJWK{ + algorithm: given.algorithm, + public: given.inter, + } } return &JWKS{ @@ -29,36 +47,123 @@ func NewGiven(givenKeys map[string]GivenKey) (jwks *JWKS) { // // See the https://pkg.go.dev/github.com/golang-jwt/jwt/v4#RegisterSigningMethod function for registering an unsupported // signing method. +// +// Deprecated: This function does not allow the user to specify the JWT's signing algorithm. Use +// NewGivenCustomWithOptions instead. func NewGivenCustom(key interface{}) (givenKey GivenKey) { return GivenKey{ inter: key, } } +// NewGivenCustomWithOptions creates a new GivenKey given an untyped variable. The key argument is expected to be a type +// supported by the jwt package used. +// +// Consider the options carefully as each field may have a security implication. +// +// See the https://pkg.go.dev/github.com/golang-jwt/jwt/v4#RegisterSigningMethod function for registering an unsupported +// signing method. +func NewGivenCustomWithOptions(key interface{}, options GivenKeyOptions) (givenKey GivenKey) { + return GivenKey{ + algorithm: options.Algorithm, + inter: key, + } +} + // NewGivenECDSA creates a new GivenKey given an ECDSA public key. +// +// Deprecated: This function does not allow the user to specify the JWT's signing algorithm. Use +// NewGivenECDSACustomWithOptions instead. func NewGivenECDSA(key *ecdsa.PublicKey) (givenKey GivenKey) { return GivenKey{ inter: key, } } +// NewGivenECDSACustomWithOptions creates a new GivenKey given an ECDSA public key. +// +// Consider the options carefully as each field may have a security implication. +func NewGivenECDSACustomWithOptions(key *ecdsa.PublicKey, options GivenKeyOptions) (givenKey GivenKey) { + return GivenKey{ + algorithm: options.Algorithm, + inter: key, + } +} + // NewGivenEdDSA creates a new GivenKey given an EdDSA public key. +// +// Deprecated: This function does not allow the user to specify the JWT's signing algorithm. Use +// NewGivenEdDSACustomWithOptions instead. func NewGivenEdDSA(key ed25519.PublicKey) (givenKey GivenKey) { return GivenKey{ inter: key, } } +// NewGivenEdDSACustomWithOptions creates a new GivenKey given an EdDSA public key. +// +// Consider the options carefully as each field may have a security implication. +func NewGivenEdDSACustomWithOptions(key ed25519.PublicKey, options GivenKeyOptions) (givenKey GivenKey) { + return GivenKey{ + algorithm: options.Algorithm, + inter: key, + } +} + // NewGivenHMAC creates a new GivenKey given an HMAC key in a byte slice. +// +// Deprecated: This function does not allow the user to specify the JWT's signing algorithm. Use +// NewGivenHMACCustomWithOptions instead. func NewGivenHMAC(key []byte) (givenKey GivenKey) { return GivenKey{ inter: key, } } +// NewGivenHMACCustomWithOptions creates a new GivenKey given an HMAC key in a byte slice. +// +// Consider the options carefully as each field may have a security implication. +func NewGivenHMACCustomWithOptions(key []byte, options GivenKeyOptions) (givenKey GivenKey) { + return GivenKey{ + algorithm: options.Algorithm, + inter: key, + } +} + // NewGivenRSA creates a new GivenKey given an RSA public key. +// +// Deprecated: This function does not allow the user to specify the JWT's signing algorithm. Use +// NewGivenRSACustomWithOptions instead. func NewGivenRSA(key *rsa.PublicKey) (givenKey GivenKey) { return GivenKey{ inter: key, } } + +// NewGivenRSACustomWithOptions creates a new GivenKey given an RSA public key. +// +// Consider the options carefully as each field may have a security implication. +func NewGivenRSACustomWithOptions(key *rsa.PublicKey, options GivenKeyOptions) (givenKey GivenKey) { + return GivenKey{ + algorithm: options.Algorithm, + inter: key, + } +} + +// NewGivenKeysFromJSON parses a raw JSON message into a map of key IDs (`kid`) to GivenKeys. The returned map is +// suitable for passing to `NewGiven()` or as `Options.GivenKeys` to `Get()` +func NewGivenKeysFromJSON(jwksBytes json.RawMessage) (map[string]GivenKey, error) { + // Parse by making a temporary JWKS instance. No need to lock its map since it doesn't escape this function. + j, err := NewJSON(jwksBytes) + if err != nil { + return nil, err + } + keys := make(map[string]GivenKey, len(j.keys)) + for kid, cryptoKey := range j.keys { + keys[kid] = GivenKey{ + algorithm: cryptoKey.algorithm, + inter: cryptoKey.public, + } + } + return keys, nil +} diff --git a/given_test.go b/given_test.go index f25d533..91e4e88 100644 --- a/given_test.go +++ b/given_test.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" + "errors" "fmt" "testing" @@ -29,7 +30,7 @@ const ( // TestNewGivenCustom tests that a custom jwt.SigningMethod can be used to create a JWKS and a proper jwt.Keyfunc. func TestNewGivenCustom(t *testing.T) { - jwt.RegisterSigningMethod(method.CustomAlg, func() jwt.SigningMethod { + jwt.RegisterSigningMethod(method.CustomAlgHeader, func() jwt.SigningMethod { return method.EmptyCustom{} }) @@ -39,12 +40,67 @@ func TestNewGivenCustom(t *testing.T) { jwks := keyfunc.NewGiven(givenKeys) token := jwt.New(method.EmptyCustom{}) - token.Header[algAttribute] = method.CustomAlg + token.Header[algAttribute] = method.CustomAlgHeader token.Header[kidAttribute] = testKID signParseValidate(t, token, key, jwks) } +// TestNewGivenCustomAlg tests that a custom jwt.SigningMethod can be used to create a JWKS and a proper jwt.Keyfunc. +func TestNewGivenCustomAlg(t *testing.T) { + jwt.RegisterSigningMethod(method.CustomAlgHeader, func() jwt.SigningMethod { + return method.EmptyCustom{} + }) + + const key = "test-key" + givenKeys := make(map[string]keyfunc.GivenKey) + givenKeys[testKID] = keyfunc.NewGivenCustomWithOptions(key, keyfunc.GivenKeyOptions{ + Algorithm: method.CustomAlgHeader, + }) + + jwks := keyfunc.NewGiven(givenKeys) + + token := jwt.New(method.EmptyCustom{}) + token.Header[algAttribute] = method.CustomAlgHeader + token.Header[kidAttribute] = testKID + + signParseValidate(t, token, key, jwks) +} + +// TestNewGivenCustomAlg_NegativeCase tests that a custom jwt.SigningMethod can be used to create +// a JWKS and a proper jwt.Keyfunc and that a token with a non-matching algorithm will be rejected. +func TestNewGivenCustomAlg_NegativeCase(t *testing.T) { + jwt.RegisterSigningMethod(method.CustomAlgHeader, func() jwt.SigningMethod { + return method.EmptyCustom{} + }) + + const key = jwt.UnsafeAllowNoneSignatureType // Allow the "none" JWT "alg" header value for golang-jwt. + givenKeys := make(map[string]keyfunc.GivenKey) + givenKeys[testKID] = keyfunc.NewGivenCustomWithOptions(key, keyfunc.GivenKeyOptions{ + Algorithm: method.CustomAlgHeader, + }) + + jwks := keyfunc.NewGiven(givenKeys) + + token := jwt.New(method.EmptyCustom{}) + token.Header[algAttribute] = jwt.SigningMethodNone.Alg() + token.Header[kidAttribute] = testKID + + jwtB64, err := token.SignedString(key) + if err != nil { + t.Fatalf(logFmt, "Failed to sign the JWT.", err) + } + + parsed, err := jwt.NewParser().Parse(jwtB64, jwks.Keyfunc) + if !errors.Is(err, keyfunc.ErrJWKAlgMismatch) { + t.Fatalf("Failed to return ErrJWKAlgMismatch: %v.", err) + } + + if parsed.Valid { + t.Fatalf("The JWT was valid.") + } +} + // TestNewGivenKeyECDSA tests that a generated ECDSA key can be added to the JWKS and create a proper jwt.Keyfunc. func TestNewGivenKeyECDSA(t *testing.T) { givenKeys := make(map[string]keyfunc.GivenKey) @@ -109,10 +165,50 @@ func TestNewGivenKeyRSA(t *testing.T) { signParseValidate(t, token, key, jwks) } +// TestNewGivenKeysFromJSON tests that parsing GivenKeys from JSON can be used to create a JWKS and a proper jwt.Keyfunc. +func TestNewGivenKeysFromJSON(t *testing.T) { + // Construct a JWKS JSON containing a JWK for which we know the private key and thus can sign a token. + key := []byte("test-hmac-secret") + const testJSON = `{ + "keys": [ + { + "kid": "testkid", + "kty": "oct", + "alg": "HS256", + "use": "sig", + "k": "dGVzdC1obWFjLXNlY3JldA" + } + ] + }` + + givenKeys, err := keyfunc.NewGivenKeysFromJSON([]byte(testJSON)) + if err != nil { + t.Fatalf(logFmt, "Failed to parse given keys from JSON.", err) + } + + jwks := keyfunc.NewGiven(givenKeys) + + token := jwt.New(jwt.SigningMethodHS256) + token.Header[kidAttribute] = testKID + + signParseValidate(t, token, key, jwks) +} + +// TestNewGivenKeysFromJSON_BadParse makes sure bad JSON returns an error. +func TestNewGivenKeysFromJSON_BadParse(t *testing.T) { + const testJSON = "{not the best syntax" + _, err := keyfunc.NewGivenKeysFromJSON([]byte(testJSON)) + if err == nil { + t.Fatalf("Expected a JSON parse error") + } +} + // addCustom adds a new key wto the given keys map. The new key is using a test jwt.SigningMethod. func addCustom(givenKeys map[string]keyfunc.GivenKey, kid string) (key string) { key = "" - givenKeys[kid] = keyfunc.NewGivenCustom(key) + givenKeys[kid] = keyfunc.NewGivenCustomWithOptions(key, keyfunc.GivenKeyOptions{ + Algorithm: method.CustomAlgHeader, + }) return key } @@ -123,7 +219,9 @@ func addECDSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key *ecdsa.Pri return nil, fmt.Errorf("failed to create ECDSA key: %w", err) } - givenKeys[kid] = keyfunc.NewGivenECDSA(&key.PublicKey) + givenKeys[kid] = keyfunc.NewGivenECDSACustomWithOptions(&key.PublicKey, keyfunc.GivenKeyOptions{ + Algorithm: jwt.SigningMethodES256.Alg(), + }) return key, nil } @@ -135,7 +233,9 @@ func addEdDSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key ed25519.Pr return nil, fmt.Errorf("failed to create ECDSA key: %w", err) } - givenKeys[kid] = keyfunc.NewGivenEdDSA(pub) + givenKeys[kid] = keyfunc.NewGivenEdDSACustomWithOptions(pub, keyfunc.GivenKeyOptions{ + Algorithm: jwt.SigningMethodEdDSA.Alg(), + }) return key, nil } @@ -148,7 +248,9 @@ func addHMAC(givenKeys map[string]keyfunc.GivenKey, kid string) (secret []byte, return nil, fmt.Errorf("failed to create HMAC secret: %w", err) } - givenKeys[kid] = keyfunc.NewGivenHMAC(secret) + givenKeys[kid] = keyfunc.NewGivenHMACCustomWithOptions(secret, keyfunc.GivenKeyOptions{ + Algorithm: jwt.SigningMethodHS256.Alg(), + }) return secret, nil } @@ -160,7 +262,9 @@ func addRSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key *rsa.Private return nil, fmt.Errorf("failed to create RSA key: %w", err) } - givenKeys[kid] = keyfunc.NewGivenRSA(&key.PublicKey) + givenKeys[kid] = keyfunc.NewGivenRSACustomWithOptions(&key.PublicKey, keyfunc.GivenKeyOptions{ + Algorithm: jwt.SigningMethodRS256.Alg(), + }) return key, nil } diff --git a/jwks.go b/jwks.go index 9a429bf..a13bb38 100644 --- a/jwks.go +++ b/jwks.go @@ -11,6 +11,10 @@ import ( ) var ( + // ErrJWKAlgMismatch indicates that the given JWK was found, but its "alg" parameter's value did not match that of + // the JWT. + ErrJWKAlgMismatch = errors.New(`the given JWK was found, but its "alg" parameter's value did not match the expected algorithm`) + // ErrJWKUseWhitelist indicates that the given JWK was found, but its "use" parameter's value was not whitelisted. ErrJWKUseWhitelist = errors.New(`the given JWK was found, but its "use" parameter's value was not whitelisted`) @@ -39,21 +43,23 @@ type JWKUse string // jsonWebKey represents a JSON Web Key inside a JWKS. type jsonWebKey struct { - Curve string `json:"crv"` - Exponent string `json:"e"` - K string `json:"k"` - ID string `json:"kid"` - Modulus string `json:"n"` - Type string `json:"kty"` - Use string `json:"use"` - X string `json:"x"` - Y string `json:"y"` + Algorithm string `json:"alg"` + Curve string `json:"crv"` + Exponent string `json:"e"` + K string `json:"k"` + ID string `json:"kid"` + Modulus string `json:"n"` + Type string `json:"kty"` + Use string `json:"use"` + X string `json:"x"` + Y string `json:"y"` } // parsedJWK represents a JSON Web Key parsed with fields as the correct Go types. type parsedJWK struct { - use JWKUse - public interface{} + algorithm string + public interface{} + use JWKUse } // JWKS represents a JSON Web Key Set (JWK Set). @@ -124,8 +130,9 @@ func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) { } jwks.keys[key.ID] = parsedJWK{ - use: JWKUse(key.Use), - public: keyInter, + algorithm: key.Algorithm, + use: JWKUse(key.Use), + public: keyInter, } } @@ -181,7 +188,7 @@ func (j *JWKS) ReadOnlyKeys() map[string]interface{} { } // getKey gets the jsonWebKey from the given KID from the JWKS. It may refresh the JWKS if configured to. -func (j *JWKS) getKey(kid string) (jsonKey interface{}, err error) { +func (j *JWKS) getKey(alg, kid string) (jsonKey interface{}, err error) { j.mux.RLock() pubKey, ok := j.keys[kid] j.mux.RUnlock() @@ -221,5 +228,9 @@ func (j *JWKS) getKey(kid string) (jsonKey interface{}, err error) { } } + if pubKey.algorithm != "" && pubKey.algorithm != alg { + return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrJWKAlgMismatch, pubKey.algorithm, alg) + } + return pubKey.public, nil } diff --git a/keyfunc.go b/keyfunc.go index f74f2fa..967c5a9 100644 --- a/keyfunc.go +++ b/keyfunc.go @@ -25,7 +25,14 @@ func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) { return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID) } - return j.getKey(kid) + alg, ok := token.Header["alg"].(string) + if !ok { + // For test coverage purposes, this should be impossible to reach because the JWT package rejects a token + // without an alg parameter in the header before calling jwt.Keyfunc. + return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch) + } + + return j.getKey(alg, kid) } // base64urlTrailingPadding removes trailing padding before decoding a string from base64url. Some non-RFC compliant