diff --git a/internal/gatewayapi/securitypolicy.go b/internal/gatewayapi/securitypolicy.go index 61ed8b64a00..4dca669f882 100644 --- a/internal/gatewayapi/securitypolicy.go +++ b/internal/gatewayapi/securitypolicy.go @@ -43,6 +43,7 @@ const ( defaultForwardAccessToken = false defaultRefreshToken = true defaultPassThroughAuthHeader = false + defaultOIDCHTTPTimeout = 5 * time.Second // nolint: gosec oidcHMACSecretName = "envoy-oidc-hmac" @@ -55,6 +56,10 @@ func (t *Translator) ProcessSecurityPolicies(securityPolicies []*egv1a1.Security resources *resource.Resources, xdsIR resource.XdsIRMap, ) []*egv1a1.SecurityPolicy { + // Cache is only reused during one translation across multiple routes and gateways. + // The failed fetches will be retried in the next translation when the provider resources are reconciled again. + t.oidcDiscoveryCache = newOIDCDiscoveryCache() + // SecurityPolicies are already sorted by the provider layer // First build a map out of the routes and gateways for faster lookup since users might have thousands of routes or more. @@ -1419,7 +1424,7 @@ func (t *Translator) buildOIDCProvider(policy *egv1a1.SecurityPolicy, resources // EG assumes that the issuer url uses the same protocol and CA as the token endpoint. // If we need to support different protocols or CAs, we need to add more fields to the OIDCProvider CRD. if provider.TokenEndpoint == nil || provider.AuthorizationEndpoint == nil { - discoveredConfig, err := fetchEndpointsFromIssuer(provider.Issuer, providerTLS) + discoveredConfig, err := t.fetchEndpointsFromIssuer(provider.Issuer, providerTLS) if err != nil { return nil, err } @@ -1507,7 +1512,25 @@ func (o *OpenIDConfig) validate() error { return nil } -func fetchEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfig) (*OpenIDConfig, error) { +func (t *Translator) fetchEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfig) (*OpenIDConfig, error) { + if config, cachedErr, ok := t.oidcDiscoveryCache.Get(issuerURL); ok { + if cachedErr != nil { + return nil, cachedErr + } + return config, nil + } + + config, err := discoverEndpointsFromIssuer(issuerURL, providerTLS) + if err != nil { + t.oidcDiscoveryCache.Set(issuerURL, nil, err) + return nil, err + } + + t.oidcDiscoveryCache.Set(issuerURL, config, nil) + return config, nil +} + +func discoverEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfig) (*OpenIDConfig, error) { var ( tlsConfig *tls.Config err error @@ -1519,7 +1542,7 @@ func fetchEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfi } } - client := &http.Client{} + client := &http.Client{Timeout: defaultOIDCHTTPTimeout} if tlsConfig != nil { client.Transport = &http.Transport{ TLSClientConfig: tlsConfig, @@ -1564,6 +1587,47 @@ func fetchEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfi return &config, nil } +// oidcDiscoveryCache is a cache for auto-discovered OIDC configurations from the issuer's well-known URL. +// The cache is only used within the current translation, so no need to lock it or expire entries. +type oidcDiscoveryCache struct { + entries map[string]cachedOIDCEntry +} + +type cachedOIDCEntry struct { + config *OpenIDConfig + err error +} + +func newOIDCDiscoveryCache() *oidcDiscoveryCache { + return &oidcDiscoveryCache{ + entries: make(map[string]cachedOIDCEntry), + } +} + +func (c *oidcDiscoveryCache) Get(issuer string) (*OpenIDConfig, error, bool) { + if c == nil { + return nil, nil, false + } + + entry, ok := c.entries[issuer] + if !ok { + return nil, nil, false + } + + return entry.config, entry.err, true +} + +func (c *oidcDiscoveryCache) Set(issuer string, cfg *OpenIDConfig, err error) { + if c == nil { + return + } + + c.entries[issuer] = cachedOIDCEntry{ + config: cfg, + err: err, + } +} + func retryable(code int) bool { return code >= 500 && (code != http.StatusNotImplemented && diff --git a/internal/gatewayapi/securitypolicy_test.go b/internal/gatewayapi/securitypolicy_test.go index 62fad681bb5..16067ef8a75 100644 --- a/internal/gatewayapi/securitypolicy_test.go +++ b/internal/gatewayapi/securitypolicy_test.go @@ -6,7 +6,11 @@ package gatewayapi import ( + "fmt" + "net/http" + "net/http/httptest" "regexp" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" @@ -780,6 +784,79 @@ func TestValidateCIDRs_ErrorOnBadCIDR(t *testing.T) { } } +func TestTranslatorFetchEndpointsFromIssuerCache(t *testing.T) { + var ( + callCount atomic.Int32 + server *httptest.Server + ) + + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + http.NotFound(w, r) + return + } + + callCount.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{"token_endpoint":%q,"authorization_endpoint":%q}`, server.URL+"/token", server.URL+"/authorize") + })) + defer server.Close() + + tr := &Translator{GatewayControllerName: "gateway.envoyproxy.io/gatewayclass-controller"} + tr.oidcDiscoveryCache = newOIDCDiscoveryCache() + + cfg, err := tr.fetchEndpointsFromIssuer(server.URL, nil) + require.NoError(t, err) + require.NotNil(t, cfg) + require.Equal(t, int32(1), callCount.Load()) + + cfgCached, err := tr.fetchEndpointsFromIssuer(server.URL, nil) + require.NoError(t, err) + require.NotNil(t, cfgCached) + require.Equal(t, int32(1), callCount.Load(), "second fetch should use cache") + + cfgAgain, err := tr.fetchEndpointsFromIssuer(server.URL, nil) + require.NoError(t, err) + require.NotNil(t, cfgAgain) + require.Equal(t, int32(1), callCount.Load(), "subsequent fetch should continue using cache") +} + +func TestTranslatorFetchEndpointsFromIssuerCacheError(t *testing.T) { + var ( + callCount atomic.Int32 + server *httptest.Server + ) + + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + http.NotFound(w, r) + return + } + + callCount.Add(1) + http.NotFound(w, r) + })) + defer server.Close() + + tr := &Translator{GatewayControllerName: "gateway.envoyproxy.io/gatewayclass-controller"} + tr.oidcDiscoveryCache = newOIDCDiscoveryCache() + + cfg, err := tr.fetchEndpointsFromIssuer(server.URL, nil) + require.Error(t, err) + require.Nil(t, cfg) + require.Equal(t, int32(1), callCount.Load()) + + cfgCached, err := tr.fetchEndpointsFromIssuer(server.URL, nil) + require.Error(t, err) + require.Nil(t, cfgCached) + require.Equal(t, int32(1), callCount.Load(), "second fetch should use cached error") + + cfgAfter, err := tr.fetchEndpointsFromIssuer(server.URL, nil) + require.Error(t, err) + require.Nil(t, cfgAfter) + require.Equal(t, int32(1), callCount.Load(), "subsequent fetch should continue using cached error") +} + // / tiny helper to build a minimal SecurityPolicy func sp(ns, name string) *egv1a1.SecurityPolicy { return &egv1a1.SecurityPolicy{ diff --git a/internal/gatewayapi/translator.go b/internal/gatewayapi/translator.go index 1a38a6ff242..358fa33be65 100644 --- a/internal/gatewayapi/translator.go +++ b/internal/gatewayapi/translator.go @@ -110,6 +110,9 @@ type Translator struct { // and reuses the specified value. ListenerPortShiftDisabled bool + // oidcDiscoveryCache is the cache for OIDC configurations discovered from issuer's well-known URL. + oidcDiscoveryCache *oidcDiscoveryCache + // Logger is the logger used by the translator. Logger logging.Logger }