Skip to content

Commit

Permalink
Merge pull request #70 from sean-rn/given-key-alg
Browse files Browse the repository at this point in the history
Allow creation of GivenKeys with a required algorithm
  • Loading branch information
MicahParks committed Nov 26, 2022
2 parents 338d24e + 06dfbca commit d327d9a
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 3 deletions.
19 changes: 16 additions & 3 deletions given.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ import (

// GivenKey represents a cryptographic key that resides in a JWKS. In conjuncture with Options.
type GivenKey struct {
inter interface{}
inter interface{}
algorithm string
}

// NewGiven creates a JWKS from a map of given keys.
func NewGiven(givenKeys map[string]GivenKey) (jwks *JWKS) {
keys := make(map[string]parsedJWK)

for kid, given := range givenKeys {
keys[kid] = parsedJWK{public: given.inter}
keys[kid] = parsedJWK{public: given.inter, algorithm: given.algorithm}
}

return &JWKS{
Expand All @@ -25,7 +26,7 @@ func NewGiven(givenKeys map[string]GivenKey) (jwks *JWKS) {
}

// NewGivenCustom creates a new GivenKey given an untyped variable. The key argument is expected to be a supported
// by the jwt package used.
// by the jwt package used. To specify a required algorithm use NewGivenCustomAlg.
//
// See the https://pkg.go.dev/github.com/golang-jwt/jwt/v4#RegisterSigningMethod function for registering an unsupported
// signing method.
Expand All @@ -35,6 +36,18 @@ func NewGivenCustom(key interface{}) (givenKey GivenKey) {
}
}

// NewGivenCustomAlg creates a new GivenKey given an untyped variable and an algorithm. The key argument is expected to
// be a type supported by the jwt package used. The alg argument will be validated against the alg header of tokens.
//
// See the https://pkg.go.dev/github.com/golang-jwt/jwt/v4#RegisterSigningMethod function for registering an unsupported
// signing method.
func NewGivenCustomAlg(key interface{}, alg string) (givenKey GivenKey) {
return GivenKey{
inter: key,
algorithm: alg,
}
}

// NewGivenECDSA creates a new GivenKey given an ECDSA public key.
func NewGivenECDSA(key *ecdsa.PublicKey) (givenKey GivenKey) {
return GivenKey{
Expand Down
52 changes: 52 additions & 0 deletions given_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"errors"
"fmt"
"testing"

Expand Down Expand Up @@ -45,6 +46,57 @@ func TestNewGivenCustom(t *testing.T) {
signParseValidate(t, token, key, jwks)
}

// TestNewGivenCustomAlg tests that a custom jwt.SigningMethod can be used to create a JWKS and a proper jwt.Keyfunc.
func TestNewGivenCustomAlg(t *testing.T) {
jwt.RegisterSigningMethod(method.CustomAlg, func() jwt.SigningMethod {
return method.EmptyCustom{}
})

const key = "test-key"
givenKeys := make(map[string]keyfunc.GivenKey)
givenKeys[testKID] = keyfunc.NewGivenCustomAlg(key, method.CustomAlg)

jwks := keyfunc.NewGiven(givenKeys)

token := jwt.New(method.EmptyCustom{})
token.Header[algAttribute] = method.CustomAlg
token.Header[kidAttribute] = testKID

signParseValidate(t, token, key, jwks)
}

// TestNewGivenCustomAlg_NegativeCase tests that a custom jwt.SigningMethod can be used to create
// a JWKS and a proper jwt.Keyfunc and that a token with a non-matching algorithm will be rejected.
func TestNewGivenCustomAlg_NegativeCase(t *testing.T) {
jwt.RegisterSigningMethod(method.CustomAlg, func() jwt.SigningMethod {
return method.EmptyCustom{}
})

const key = jwt.UnsafeAllowNoneSignatureType // So golang-jwt isn't the one blocking this test
givenKeys := make(map[string]keyfunc.GivenKey)
givenKeys[testKID] = keyfunc.NewGivenCustomAlg(key, method.CustomAlg)

jwks := keyfunc.NewGiven(givenKeys)

token := jwt.New(method.EmptyCustom{})
token.Header[algAttribute] = jwt.SigningMethodNone.Alg()
token.Header[kidAttribute] = testKID

jwtB64, err := token.SignedString(key)
if err != nil {
t.Fatalf(logFmt, "Failed to sign the JWT.", err)
}

parsed, err := jwt.NewParser().Parse(jwtB64, jwks.Keyfunc)
if !errors.Is(err, keyfunc.ErrJWKAlgMismatch) {
t.Fatalf("Failed to return ErrJWKAlgMismatch")
}

if parsed.Valid {
t.Fatalf("The JWT was valid.")
}
}

// TestNewGivenKeyECDSA tests that a generated ECDSA key can be added to the JWKS and create a proper jwt.Keyfunc.
func TestNewGivenKeyECDSA(t *testing.T) {
givenKeys := make(map[string]keyfunc.GivenKey)
Expand Down
10 changes: 10 additions & 0 deletions jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ func (j *JWKS) KIDs() (kids []string) {
return kids
}

// KeyAlg returns the algorithm (`alg`) for the key identified by Key ID (`kid`).
func (j *JWKS) KeyAlg(kid string) string {
j.mux.RLock()
defer j.mux.RUnlock()
if pubKey, ok := j.keys[kid]; ok {
return pubKey.algorithm
}
return ""
}

// Len returns the number of keys in the JWKS.
func (j *JWKS) Len() int {
j.mux.RLock()
Expand Down
32 changes: 32 additions & 0 deletions jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,38 @@ func TestJWKS_Len(t *testing.T) {
}
}

// TestJWKS_KeyAlg confirms the JWKS.Len returns the algorithm for keys by kid.
func TestJWKS_KeyAlg(t *testing.T) {
jwks, err := keyfunc.NewJSON([]byte(jwksJSON))
if err != nil {
t.Fatalf(logFmt, "Failed to create a JWKS from JSON.", err)
}

expectedAlgs := map[string]string{
"zXew0UJ1h6Q4CCcd_9wxMzvcp5cEBifH0KWrCz2Kyxc": "PS256",
"ebJxnm9B3QDBljB5XJWEu72qx6BawDaMAhwz4aKPkQ0": "ES512",
"TVAAet63O3xy_KK6_bxVIu7Ra3_z1wlB543Fbwi5VaU": "ES384",
"arlUxX4hh56rNO-XdIPhDT7bqBMqcBwNQuP_TnZJNGs": "RS512",
"tW6ae7TomE6_2jooM-sf9N_6lWg7HNtaQXrDsElBzM4": "PS512",
"Lx1FmayP2YBtxaqS1SKJRJGiXRKnw2ov5WmYIMG-BLE": "PS384",
"gnmAfvmlsi3kKH3VlM1AJ85P2hekQ8ON_XvJqs3xPD8": "RS384",
"CGt0ZWS4Lc5faiKSdi0tU0fjCAdvGROQRGU9iR7tV0A": "ES256",
"C65q0EKQyhpd1m4fr7SKO2He_nAxgCtAdws64d2BLt8": "RS256",
"Q56A": "",
"hmac": "",
"WW91IGdldCBhIGdvbGQgc3RhciDwn4yfCg": "",
}

for kid, expectedAlg := range expectedAlgs {
t.Run(kid, func(t *testing.T) {
actualAlg := jwks.KeyAlg(kid)
if actualAlg != expectedAlg {
t.Errorf("Unexpected alg for key %v.\n Expected: %v\n Actual: %v\n", kid, expectedAlg, actualAlg)
}
})
}
}

// TestRateLimit performs a test to confirm the rate limiter works as expected.
func TestRateLimit(t *testing.T) {
tempDir, err := os.MkdirTemp("", "*")
Expand Down

0 comments on commit d327d9a

Please sign in to comment.