From aaba8f3000b76b41733567367b9000348a839c17 Mon Sep 17 00:00:00 2001 From: Filip Skokan Date: Thu, 11 Nov 2021 21:47:46 +0100 Subject: [PATCH] fix: createRemoteJWKSet handles all JWS syntaxes --- src/jwks/remote.ts | 29 ++++++++++++--------- test-browser/jwks.js | 4 +-- test-cloudflare-workers/cloudflare.test.mjs | 4 +-- test-deno/jwks.test.ts | 9 ++++--- test/jwks/remote.test.mjs | 20 +++++++------- 5 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/jwks/remote.ts b/src/jwks/remote.ts index 17e4a034fc..2efa597d4b 100644 --- a/src/jwks/remote.ts +++ b/src/jwks/remote.ts @@ -103,23 +103,28 @@ class RemoteJWKSet { return Date.now() < this._cooldownStarted + this._cooldownDuration } - async getKey(protectedHeader: JWSHeaderParameters): Promise { + async getKey(protectedHeader: JWSHeaderParameters, token: FlattenedJWSInput): Promise { + const joseHeader = { + ...protectedHeader, + ...token.header, + } + if (!this._jwks) { await this.reload() } const candidates = this._jwks!.keys.filter((jwk) => { // filter keys based on the mapping of signature algorithms to Key Type - let candidate = jwk.kty === getKtyFromAlg(protectedHeader.alg) + let candidate = jwk.kty === getKtyFromAlg(joseHeader.alg) // filter keys based on the JWK Key ID in the header - if (candidate && typeof protectedHeader.kid === 'string') { - candidate = protectedHeader.kid === jwk.kid + if (candidate && typeof joseHeader.kid === 'string') { + candidate = joseHeader.kid === jwk.kid } // filter keys based on the key's declared Algorithm if (candidate && typeof jwk.alg === 'string') { - candidate = protectedHeader.alg === jwk.alg + candidate = joseHeader.alg === jwk.alg } // filter keys based on the key's declared Public Key Use @@ -133,13 +138,13 @@ class RemoteJWKSet { } // filter out non-applicable OKP Sub Types - if (candidate && protectedHeader.alg === 'EdDSA') { + if (candidate && joseHeader.alg === 'EdDSA') { candidate = jwk.crv === 'Ed25519' || jwk.crv === 'Ed448' } // filter out non-applicable EC curves if (candidate) { - switch (protectedHeader.alg) { + switch (joseHeader.alg) { case 'ES256': candidate = jwk.crv === 'P-256' break @@ -164,7 +169,7 @@ class RemoteJWKSet { if (length === 0) { if (this.coolingDown() === false) { await this.reload() - return this.getKey(protectedHeader) + return this.getKey(joseHeader, token) } throw new JWKSNoMatchingKey() } else if (length !== 1) { @@ -172,17 +177,17 @@ class RemoteJWKSet { } const cached = this._cached.get(jwk) || this._cached.set(jwk, {}).get(jwk)! - if (cached[protectedHeader.alg!] === undefined) { - const keyObject = await importJWK({ ...jwk, ext: true }, protectedHeader.alg!) + if (cached[joseHeader.alg!] === undefined) { + const keyObject = await importJWK({ ...jwk, ext: true }, joseHeader.alg!) if (keyObject instanceof Uint8Array || keyObject.type !== 'public') { throw new JWKSInvalid('JSON Web Key Set members must be public keys') } - cached[protectedHeader.alg!] = keyObject + cached[joseHeader.alg!] = keyObject } - return cached[protectedHeader.alg!] + return cached[joseHeader.alg!] } async reload() { diff --git a/test-browser/jwks.js b/test-browser/jwks.js index 3806d58ef8..fe5f89c362 100644 --- a/test-browser/jwks.js +++ b/test-browser/jwks.js @@ -7,11 +7,11 @@ QUnit.test('fetches the JWKSet', async (assert) => { const { alg, kid } = response.keys[0] const jwks = createRemoteJWKSet(new URL(jwksUri)) await assert.rejects( - jwks({ alg: 'RS256' }), + jwks({ alg: 'RS256' }, {}), 'multiple matching keys found in the JSON Web Key Set', ) await assert.rejects( - jwks({ kid: 'foo', alg: 'RS256' }), + jwks({ kid: 'foo', alg: 'RS256' }, {}), 'no applicable key found in the JSON Web Key Set', ) assert.ok(await jwks({ alg, kid })) diff --git a/test-cloudflare-workers/cloudflare.test.mjs b/test-cloudflare-workers/cloudflare.test.mjs index e723f520b5..e414479cae 100644 --- a/test-cloudflare-workers/cloudflare.test.mjs +++ b/test-cloudflare-workers/cloudflare.test.mjs @@ -346,13 +346,13 @@ test('createRemoteJWKSet', macro, async () => { const response = await fetch(jwksUri).then((r) => r.json()) const { alg, kid } = response.keys[0] const jwks = jose.createRemoteJWKSet(new URL(jwksUri)) - await jwks({ alg, kid }) + await jwks({ alg, kid }, {}) }) test('remote jwk set timeout', macro, async () => { const jwksUri = 'https://www.googleapis.com/oauth2/v3/certs' const jwks = jose.createRemoteJWKSet(new URL(jwksUri), { timeoutDuration: 0 }) - await jwks({ alg: 'RS256' }).then( + await jwks({ alg: 'RS256' }, {}).then( () => { throw new Error('should fail') }, diff --git a/test-deno/jwks.test.ts b/test-deno/jwks.test.ts index a35c160e19..80ec6299f0 100644 --- a/test-deno/jwks.test.ts +++ b/test-deno/jwks.test.ts @@ -1,6 +1,7 @@ import { assertThrowsAsync } from 'https://deno.land/std@0.109.0/testing/asserts.ts' import { createRemoteJWKSet, errors } from '../dist/deno/index.ts' +import type { FlattenedJWSInput } from '../dist/deno/index.ts' const jwksUri = 'https://www.googleapis.com/oauth2/v3/certs' @@ -9,23 +10,23 @@ Deno.test('fetches the JWKSet', async () => { const { alg, kid } = response.keys[0] const jwks = createRemoteJWKSet(new URL(jwksUri)) await assertThrowsAsync( - () => jwks({ alg: 'RS256' }, null), + () => jwks({ alg: 'RS256' }, {}), errors.JWKSMultipleMatchingKeys, 'multiple matching keys found in the JSON Web Key Set', ) await assertThrowsAsync( - () => jwks({ kid: 'foo', alg: 'RS256' }, null), + () => jwks({ kid: 'foo', alg: 'RS256' }, {}), errors.JWKSNoMatchingKey, 'no applicable key found in the JSON Web Key Set', ) - await jwks({ alg, kid }, null) + await jwks({ alg, kid }, {}) }) Deno.test('timeout', async () => { const server = Deno.listen({ port: 3000 }) const jwks = createRemoteJWKSet(new URL('http://localhost:3000'), { timeoutDuration: 0 }) await assertThrowsAsync( - () => jwks({ alg: 'RS256' }, null), + () => jwks({ alg: 'RS256' }, {}), errors.JWKSTimeout, 'request timed out', ).finally(async () => { diff --git a/test/jwks/remote.test.mjs b/test/jwks/remote.test.mjs index 708481beb5..123fc314de 100644 --- a/test/jwks/remote.test.mjs +++ b/test/jwks/remote.test.mjs @@ -209,19 +209,19 @@ test.serial('throws on invalid JWKSet', async (t) => { const url = new URL('https://as.example.com/jwks') const JWKS = createRemoteJWKSet(url) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ERR_JWKS_INVALID', message: 'JSON Web Key Set malformed', }) scope.get('/jwks').once().reply(200, {}) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ERR_JWKS_INVALID', message: 'JSON Web Key Set malformed', }) scope.get('/jwks').once().reply(200, { keys: null }) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ERR_JWKS_INVALID', message: 'JSON Web Key Set malformed', }) @@ -230,19 +230,19 @@ test.serial('throws on invalid JWKSet', async (t) => { .get('/jwks') .once() .reply(200, { keys: [null] }) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ERR_JWKS_INVALID', message: 'JSON Web Key Set malformed', }) scope.get('/jwks').once().reply(404) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ERR_JOSE_GENERIC', message: 'Expected 200 OK from the JSON Web Key Set HTTP response', }) scope.get('/jwks').once().reply(200, '{') - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ERR_JOSE_GENERIC', message: 'Failed to parse the JSON Web Key Set HTTP response as JSON', }) @@ -252,7 +252,7 @@ test('handles ENOTFOUND', async (t) => { nock.enableNetConnect() const url = new URL('https://op.example.com/jwks') const JWKS = createRemoteJWKSet(url) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ENOTFOUND', }) }) @@ -261,7 +261,7 @@ test('handles ECONNREFUSED', async (t) => { nock.enableNetConnect() const url = new URL('http://localhost:3001/jwks') const JWKS = createRemoteJWKSet(url) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ECONNREFUSED', }) }) @@ -273,7 +273,7 @@ test('handles ECONNRESET', async (t) => { socket.destroy() }) const JWKS = createRemoteJWKSet(url) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ECONNRESET', }) }) @@ -285,7 +285,7 @@ test('handles a timeout', async (t) => { const JWKS = createRemoteJWKSet(url, { timeoutDuration: 500, }) - await t.throwsAsync(JWKS({ alg: 'RS256' }), { + await t.throwsAsync(JWKS({ alg: 'RS256' }, {}), { code: 'ERR_JWKS_TIMEOUT', }) })