From 655ef0c3a8de9ddf763a2a4e843ea8dd32a52680 Mon Sep 17 00:00:00 2001 From: Micah Parks Date: Thu, 3 Nov 2022 13:01:16 -0400 Subject: [PATCH] Add custom error for invalid ECDSA curve and allow test to panic --- ecdsa.go | 8 +++++++- ecdsa_test.go | 18 +++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/ecdsa.go b/ecdsa.go index bc0d989..ca0566d 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -3,6 +3,7 @@ package keyfunc import ( "crypto/ecdsa" "crypto/elliptic" + "errors" "fmt" "math/big" ) @@ -21,6 +22,11 @@ const ( p521 = "P-521" ) +var ( + // ErrECDSACurve indicates an error with the ECDSA curve. + ErrECDSACurve = errors.New("invalid ECDSA curve") +) + // ECDSA parses a jsonWebKey and turns it into an ECDSA public key. func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) { if j.X == "" || j.Y == "" || j.Curve == "" { @@ -49,7 +55,7 @@ func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) { case p521: publicKey.Curve = elliptic.P521() default: - return nil, fmt.Errorf("unknown curve: %s", j.Curve) + return nil, fmt.Errorf("%w: unknown curve: %s", ErrECDSACurve, j.Curve) } // Turn the X coordinate into *big.Int. diff --git a/ecdsa_test.go b/ecdsa_test.go index 20c4346..c71cacb 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -1,10 +1,13 @@ -package keyfunc +package keyfunc_test import ( "encoding/json" + "errors" "testing" "github.com/golang-jwt/jwt/v4" + + "github.com/MicahParks/keyfunc" ) func TestBadCurve(t *testing.T) { @@ -13,18 +16,15 @@ func TestBadCurve(t *testing.T) { someJWT = `eyJhbGciOiJFUzI1NiIsImtpZCI6IjEiLCJ0eXAiOiJKV1QifQ.e30.Q1EeyWUv6XEA0gMLwTFoNhx7Hq1MbVwjI2k9FZPSa-myKW1wYn1X6rHtRyuV-2MEzvimCskFD-afL7UzvdWBQg` ) - jwks, err := NewJSON(json.RawMessage(badJWKS)) + jwks, err := keyfunc.NewJSON(json.RawMessage(badJWKS)) if err != nil { t.Fatalf("Failed to create JWKS from JSON: %v", err) } - defer func() { - if r := recover(); r != nil { - t.Fatalf("panic") - } - }() + // The number of parsed keys should be 0. - if _, err = jwt.Parse(someJWT, jwks.Keyfunc); err == nil { - t.Fatal("No error for bad curve") + _, err = jwt.Parse(someJWT, jwks.Keyfunc) + if !errors.Is(err, keyfunc.ErrKIDNotFound) { + t.Fatalf("Expected ErrKIDNotFound, got %v", err) } }