Skip to content

Commit

Permalink
Merge pull request #67 from MicahParks/alg_check_edits
Browse files Browse the repository at this point in the history
Add custom error for invalid ECDSA curve and allow test to panic
  • Loading branch information
MicahParks committed Nov 3, 2022
2 parents d64cede + 655ef0c commit 338d24e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
8 changes: 7 additions & 1 deletion ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package keyfunc
import (
"crypto/ecdsa"
"crypto/elliptic"
"errors"
"fmt"
"math/big"
)
Expand All @@ -21,6 +22,11 @@ const (
p521 = "P-521"
)

var (
// ErrECDSACurve indicates an error with the ECDSA curve.
ErrECDSACurve = errors.New("invalid ECDSA curve")
)

// ECDSA parses a jsonWebKey and turns it into an ECDSA public key.
func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) {
if j.X == "" || j.Y == "" || j.Curve == "" {
Expand Down Expand Up @@ -49,7 +55,7 @@ func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) {
case p521:
publicKey.Curve = elliptic.P521()
default:
return nil, fmt.Errorf("unknown curve: %s", j.Curve)
return nil, fmt.Errorf("%w: unknown curve: %s", ErrECDSACurve, j.Curve)
}

// Turn the X coordinate into *big.Int.
Expand Down
18 changes: 9 additions & 9 deletions ecdsa_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package keyfunc
package keyfunc_test

import (
"encoding/json"
"errors"
"testing"

"github.com/golang-jwt/jwt/v4"

"github.com/MicahParks/keyfunc"
)

func TestBadCurve(t *testing.T) {
Expand All @@ -13,18 +16,15 @@ func TestBadCurve(t *testing.T) {
someJWT = `eyJhbGciOiJFUzI1NiIsImtpZCI6IjEiLCJ0eXAiOiJKV1QifQ.e30.Q1EeyWUv6XEA0gMLwTFoNhx7Hq1MbVwjI2k9FZPSa-myKW1wYn1X6rHtRyuV-2MEzvimCskFD-afL7UzvdWBQg`
)

jwks, err := NewJSON(json.RawMessage(badJWKS))
jwks, err := keyfunc.NewJSON(json.RawMessage(badJWKS))
if err != nil {
t.Fatalf("Failed to create JWKS from JSON: %v", err)
}

defer func() {
if r := recover(); r != nil {
t.Fatalf("panic")
}
}()
// The number of parsed keys should be 0.

if _, err = jwt.Parse(someJWT, jwks.Keyfunc); err == nil {
t.Fatal("No error for bad curve")
_, err = jwt.Parse(someJWT, jwks.Keyfunc)
if !errors.Is(err, keyfunc.ErrKIDNotFound) {
t.Fatalf("Expected ErrKIDNotFound, got %v", err)
}
}

0 comments on commit 338d24e

Please sign in to comment.