Skip to content

Commit

Permalink
Use context from ClientHello during GetCertificate (#249)
Browse files Browse the repository at this point in the history
* Use context from ClientHello during GetCertificate

(see #247)

* Avoid recursive ops during on-demand issuance
  • Loading branch information
mholt authored Aug 17, 2023
1 parent 5bca6d1 commit e822453
Showing 1 changed file with 32 additions and 33 deletions.
65 changes: 32 additions & 33 deletions handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ import (
// GetCertificate will run in a new context, use GetCertificateWithContext to provide
// a context.
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
ctx := context.TODO() // TODO: get a proper context? from somewhere...
return cfg.GetCertificateWithContext(ctx, clientHello)
return cfg.GetCertificateWithContext(clientHello.Context(), clientHello)
}

func (cfg *Config) GetCertificateWithContext(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expand Down Expand Up @@ -276,15 +275,15 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client
name := cfg.getNameFromClientHello(hello)

// By this point, we need to load or obtain a certificate. If a swarm of requests comes in for the same
// domain, avoid pounding manager or storage thousands of times simultaneously. We do a similar sync
// domain, avoid pounding manager or storage thousands of times simultaneously. We use a similar sync
// strategy for obtaining certificate during handshake.
certLoadWaitChansMu.Lock()
wait, ok := certLoadWaitChans[name]
if ok {
// another goroutine is already loading the cert; just wait and we'll get it from the in-memory cache
certLoadWaitChansMu.Unlock()

timeout := time.NewTimer(2 * time.Minute) // TODO: have Caddy use the context param to establish a timeout
timeout := time.NewTimer(2 * time.Minute)
select {
case <-timeout.C:
return Certificate{}, fmt.Errorf("timed out waiting to load certificate for %s", name)
Expand Down Expand Up @@ -480,6 +479,9 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli
// wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitChansMu.Unlock()

log.Debug("new certificate is needed, but is already being obtained; waiting for that issuance to complete",
zap.String("subject", name))

// TODO: see if we can get a proper context in here, for true cancellation
timeout := time.NewTimer(2 * time.Minute)
select {
Expand All @@ -489,7 +491,9 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli
timeout.Stop()
}

return cfg.loadCertFromStorage(ctx, log, hello)
// it should now be loaded in the cache, ready to go; if not,
// the goroutine in charge of that probably had an error
return cfg.getCertDuringHandshake(ctx, hello, false)
}

// looks like it's up to us to do all the work and obtain the cert.
Expand All @@ -507,28 +511,28 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli

log.Info("obtaining new certificate", zap.String("server_name", name))

// TODO: we are only adding a timeout because we don't know if the context passed in is actually cancelable...
// set a timeout so we don't inadvertently hold a client handshake open too long
// (timeout duration is based on https://caddy.community/t/zerossl-dns-challenge-failing-often-route53-plugin/13822/24?u=matt)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 180*time.Second)
defer cancel()

// Obtain the certificate
// obtain the certificate (this puts it in storage) and if successful,
// load it from storage so we and any other waiting goroutine can use it
var cert Certificate
err := cfg.ObtainCertAsync(ctx, name)
if err == nil {
// load from storage while others wait to make the op as atomic as possible
cert, err = cfg.loadCertFromStorage(ctx, log, hello)
if err != nil {
log.Error("loading newly-obtained certificate from storage", zap.String("server_name", name), zap.Error(err))
}
}

// immediately unblock anyone waiting for it; doing this in
// a defer would risk deadlock because of the recursive call
// to getCertDuringHandshake below when we return!
// immediately unblock anyone waiting for it
unblockWaiters()

if err != nil {
// shucks; failed to solve challenge on-demand
return Certificate{}, err
}

// success; certificate was just placed on disk, so
// we need only restart serving the certificate
return cfg.loadCertFromStorage(ctx, log, hello)
return cert, err
}

// handshakeMaintenance performs a check on cert for expiration and OCSP validity.
Expand Down Expand Up @@ -611,7 +615,7 @@ func (cfg *Config) handshakeMaintenance(ctx context.Context, hello *tls.ClientHe
//
// This function is safe for use by multiple concurrent goroutines.
func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.ClientHelloInfo, currentCert Certificate) (Certificate, error) {
log := cfg.Logger.Named("on_demand")
log := logWithRemote(cfg.Logger.Named("on_demand"), hello)

name := cfg.getNameFromClientHello(hello)
timeLeft := time.Until(expiresAt(currentCert.Leaf))
Expand Down Expand Up @@ -651,7 +655,9 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
timeout.Stop()
}

return cfg.loadCertFromStorage(ctx, log, hello)
// it should now be loaded in the cache, ready to go; if not,
// the goroutine in charge of that probably had an error
return cfg.getCertDuringHandshake(ctx, hello, false)
}

// looks like it's up to us to do all the work and renew the cert
Expand Down Expand Up @@ -703,16 +709,8 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
} else {
err = cfg.RenewCertAsync(ctx, name, false)
if err == nil {
// even though the recursive nature of the dynamic cert loading
// would just call this function anyway, we do it here to
// make the replacement as atomic as possible.
newCert, err = cfg.CacheManagedCertificate(ctx, name)
if err != nil {
log.Error("loading renewed certificate", zap.String("server_name", name), zap.Error(err))
} else {
// replace the old certificate with the new one
cfg.certCache.replaceCertificate(currentCert, newCert)
}
// load from storage while in lock to make the replacement as atomic as possible
newCert, err = cfg.reloadManagedCertificate(ctx, currentCert)
}
}

Expand All @@ -722,11 +720,10 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
unblockWaiters()

if err != nil {
log.Error("renewing and reloading certificate", zap.Error(err))
return newCert, err
log.Error("renewing and reloading certificate", zap.String("server_name", name), zap.Error(err))
}

return cfg.loadCertFromStorage(ctx, log, hello)
return newCert, err
}

// if the certificate hasn't expired, we can serve what we have and renew in the background
Expand Down Expand Up @@ -872,6 +869,8 @@ var (
obtainCertWaitChans = make(map[string]chan struct{})
obtainCertWaitChansMu sync.Mutex
)

// TODO: this lockset should probably be per-cache
var (
certLoadWaitChans = make(map[string]chan struct{})
certLoadWaitChansMu sync.Mutex
Expand Down

0 comments on commit e822453

Please sign in to comment.