Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor caching to be optional, modular and overrideable. #2

Merged
merged 1 commit into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
refactor caching to be optional, modular and overrideable.
  • Loading branch information
nathanejohnson committed Aug 29, 2022
commit 151decd6a8a93df9f6f3be7784bea57a049d7158
102 changes: 102 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package intransport

import (
"crypto/x509"
"sync"
)

// LockedCachedCertRepresenter - awkwardly named interface for a cached and locked certificate entry.
type LockedCachedCertRepresenter interface {
// Cert is the getter function and will be called on a locked entry.
// A nil value is valid as a return, and signals we need to fetch the certificate.
Cert() *x509.Certificate

// SetCert is the certificate setter function and will be called on a locked entry.
SetCert(cert *x509.Certificate)

// Unlock will be called after fetching and / or setting the value. Once Unlock is called,
// no other calls will be made. For subsequent access, Cacher.LockedCachedCert will be called again.
Unlock()
}

// Cacher - interface for caching x509 entries.
type Cacher interface {
// LockedCachedCert will be called for a key, and this should
// return a locked entry.
LockedCachedCert(key string) LockedCachedCertRepresenter
}

type certCacheEntry struct {
sync.Mutex
cert *x509.Certificate
}

func (cce *certCacheEntry) Cert() *x509.Certificate {
return cce.cert
}

func (cce *certCacheEntry) SetCert(cert *x509.Certificate) {
cce.cert = cert
}

type certCache struct {
sync.Mutex
m map[string]*certCacheEntry
}

// NewMapCache - returns a Cacher implementation based on a go map and mutexes.
func NewMapCache() Cacher {
return &certCache {
m: make(map[string]*certCacheEntry),
}
}

func (cc *certCache) LockedCachedCert(key string) LockedCachedCertRepresenter {
cc.Lock()
cce, ok := cc.m[key]
if ok {
cc.Unlock()
cce.Lock()
return cce
}

// cache miss
cce = &certCacheEntry{}
cce.Lock()
cc.m[key] = cce
cc.Unlock()
return cce
}

type nopCache struct {}

type nopCacheEntry struct {}

// NewNopCache - this returns a nop cache, which can be used to disable caching of certificates.
func NewNopCache() Cacher {
return nopCache{}
}

func (nc nopCache) LockedCachedCert(_ string) LockedCachedCertRepresenter {
return nopCacheEntry{}
}

func (nce nopCacheEntry) LockEntry() {
return
}

func (nce nopCacheEntry) Unlock() {
return
}

func (nce nopCacheEntry) UnlockCacher() {
return
}

func (nce nopCacheEntry) Cert() *x509.Certificate {
return nil
}

func (nce nopCacheEntry) SetCert(_ *x509.Certificate) {
return
}
122 changes: 52 additions & 70 deletions intransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,11 @@ import (
"net"
"net/http"
"strings"
"sync"
"time"

"golang.org/x/crypto/ocsp"
)

type certCacheEntry struct {
sync.RWMutex
cert *x509.Certificate
}

type certCache struct {
sync.Mutex
m map[string]*certCacheEntry
c *http.Client
}

// StatusRequestExtension - status_request
const StatusRequestExtension = 5

Expand All @@ -48,6 +36,11 @@ var (
// OCSP staple was not found.
ErrOCSPNotStapled = errors.New("certificate was marked with OCSP must-staple and no staple could be verified")

// ErrNoCertificates - this is returned in the unlikely event that no
// peer certificates are provided whatsoever. This should never be
// seen.
ErrNoCertificates = errors.New("no certificates supplied")

// MustStapleValue is the value in the MustStaple extension.
// DER encoding of []int{5}.
// https://tools.ietf.org/html/rfc6066#section-1.1
Expand All @@ -57,31 +50,6 @@ var (
// Must staple oid is id-pe-tlsfeature as defined here
// https://tools.ietf.org/html/rfc7633#section-6
MustStapleOID = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 24}

cc = &certCache{
m: make(map[string]*certCacheEntry),
// client used for fetching intermediates.
c: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 3 * time.Second,
KeepAlive: 0,
}).DialContext,

// Since we cache responses, all http activity should be
// one-and-done.
DisableKeepAlives: true,

// This shouldn't be needed, since I don't believe
// the server url locations are ever TLS enabled?
TLSHandshakeTimeout: 3 * time.Second,

// This also shouldn't be needed, but doesn't hurt anything
ExpectContinueTimeout: 1 * time.Second,
},
},
}
)

// peerCertViewer - this is a method type that is plugged into a tls.Config.VerifyPeerCertificate,
Expand Down Expand Up @@ -113,6 +81,12 @@ func NewInTransport(tlsc *tls.Config) http.RoundTripper {
// tlsc.VerifyConnection is specified, it will be called with the same semantics as before,
// but after we validate stapled ocsp.
func NewInTransportFromTransport(t *http.Transport, dialer *net.Dialer, tlsc *tls.Config) http.RoundTripper {
return NewInTransportFromTransportWithCache(t, dialer, tlsc, nil)
}

// NewInTransportFromTransportWithCache - Same as NewInTransportFromTransport, with the option of specifying
// a cache implementation for fetched intermediates. If nil, the default cacher will use the map cache implementation.
func NewInTransportFromTransportWithCache(t *http.Transport, dialer *net.Dialer, tlsc *tls.Config, cache Cacher) http.RoundTripper {
if dialer == nil {
dialer = &net.Dialer{
Timeout: 30 * time.Second,
Expand Down Expand Up @@ -140,10 +114,34 @@ func NewInTransportFromTransport(t *http.Transport, dialer *net.Dialer, tlsc *tl
} else {
tlsc = tlsc.Clone()
}
if cache == nil {
cache = NewMapCache()
}
it := &inTranspoort{
Transport: t,
NextVerifyPeerCertificate: tlsc.VerifyPeerCertificate,
Dialer: dialer,
certFetcher: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 3 * time.Second,
KeepAlive: 0,
}).DialContext,

// Since we cache responses, all http activity should be
// one-and-done.
DisableKeepAlives: true,

// This shouldn't be needed, since I don't believe
// the server url locations are ever TLS enabled?
TLSHandshakeTimeout: 3 * time.Second,

// This also shouldn't be needed, but doesn't hurt anything
ExpectContinueTimeout: 1 * time.Second,
},
},
cache: cache,
}

it.TLS = tlsc
Expand Down Expand Up @@ -179,6 +177,7 @@ func NewInTransportFromTransport(t *http.Transport, dialer *net.Dialer, tlsc *tl
return d.DialContext(ctx, network, addr)
}
return it

}

// inTranspoort - this implements an http.RoundTripper and handles the fetching
Expand All @@ -201,6 +200,10 @@ type inTranspoort struct {
TLSHandshakeTimeout time.Duration

Dialer *net.Dialer

certFetcher *http.Client

cache Cacher
}

func (it *inTranspoort) validateOCSP(serverName string, connState *tls.ConnectionState) error {
Expand Down Expand Up @@ -268,7 +271,7 @@ func (it *inTranspoort) validateOCSP(serverName string, connState *tls.Connectio
return fmt.Errorf("invalid ocsp validation: %s", ocsp.ResponseStatus(ocspResp.Status).String())
}

if !ocspResp.NextUpdate.Before(time.Now()) {
if ocspResp.NextUpdate.After(time.Now()) {
// for now, don't fail on an expired staple unless must staple is specified.
// maybe revisit this
validatedStaple = true
Expand All @@ -294,11 +297,6 @@ func parseHost(host string) (string, error) {
return host, nil
}

// NoCertificatesErr - this is returned in the unlikely event that no
// peer certificates are provided whatsoever. This should never be
// seen.
var NoCertificatesErr = errors.New("no certificates supplied")

// verifyPeerCertificate - The difference between this
// and the default TLS verification is that missing intermediates will be
// fetched until either a valid path to a trusted root is found or no further
Expand All @@ -308,7 +306,7 @@ var NoCertificatesErr = errors.New("no certificates supplied")
// method returns an error, it will stop the connection.
func (it *inTranspoort) verifyPeerCertificate(serverName string, rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return NoCertificatesErr
return ErrNoCertificates
}

PeerCertificates := make([]*x509.Certificate, 0, len(rawCerts))
Expand Down Expand Up @@ -396,7 +394,7 @@ func (it *inTranspoort) buildMissingChain(cert *x509.Certificate) ([]*x509.Certi
break
}
var err error
tmpCert, err = fetchIssuingCert(tmpCert)
tmpCert, err = it.fetchIssuingCert(tmpCert)

if err != nil {
return nil, err
Expand All @@ -410,7 +408,7 @@ func (it *inTranspoort) buildMissingChain(cert *x509.Certificate) ([]*x509.Certi
}

// This grabs the issuing cert from the issuing certificate extension.
func fetchIssuingCert(cert *x509.Certificate) (*x509.Certificate, error) {
func (it *inTranspoort) fetchIssuingCert(cert *x509.Certificate) (*x509.Certificate, error) {
// this attempts to do two things:
// 1) avoid stampede problem - minimizes fetches of a cert on cache miss
// 2) avoid long locks on the outer map.
Expand All @@ -426,38 +424,22 @@ func fetchIssuingCert(cert *x509.Certificate) (*x509.Certificate, error) {
} else {
mapKey = cert.Issuer.CommonName
}
cc.Lock()
cce, ok := cc.m[mapKey]
if ok {
cc.Unlock()
cce.Lock()
crt := cce.cert
// crt may be nil if fetch failed on a prior attempt.
// also re-fetch if cert is expired.
if crt != nil && crt.NotAfter.After(time.Now()) {
cce.Unlock()
return crt, nil
}
} else {
cce = new(certCacheEntry)
cce.Lock()
cc.m[mapKey] = cce
cc.Unlock()

cce := it.cache.LockedCachedCert(mapKey)
var crt *x509.Certificate
if crt = cce.Cert(); crt != nil && crt.NotAfter.After(time.Now()) {
cce.Unlock()
return crt, nil
}

// Once we're here, cce is locked, cc is unlocked
// Now we attempt to fetch the issuing cert.
// defer is nowhere near as slow as the code below
defer cce.Unlock()

// I've yet to see more than one IssuingCertificateURL,
// but just in case...
var err error
var fetchedCert *x509.Certificate
for _, url := range cert.IssuingCertificateURL {
for _, urlString := range cert.IssuingCertificateURL {
var resp *http.Response
resp, err = cc.c.Get(url)
resp, err = it.certFetcher.Get(urlString)
if err != nil {
continue
}
Expand All @@ -472,7 +454,7 @@ func fetchIssuingCert(cert *x509.Certificate) (*x509.Certificate, error) {
if err != nil {
continue
}
cce.cert = fetchedCert
cce.SetCert(fetchedCert)
break
}
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion intransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ func TestMissingIntermediates(t *testing.T) {
return nil
}
}
trans := NewInTransportFromTransport(tr, d, tlsc)
trans := NewInTransportFromTransportWithCache(tr, d, tlsc, NewMapCache())

c := &http.Client{Transport: trans}

Expand Down