Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions internal/gatewayapi/securitypolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
defaultForwardAccessToken = false
defaultRefreshToken = true
defaultPassThroughAuthHeader = false
defaultOIDCHTTPTimeout = 5 * time.Second

// nolint: gosec
oidcHMACSecretName = "envoy-oidc-hmac"
Expand All @@ -55,6 +56,10 @@ func (t *Translator) ProcessSecurityPolicies(securityPolicies []*egv1a1.Security
resources *resource.Resources,
xdsIR resource.XdsIRMap,
) []*egv1a1.SecurityPolicy {
// Cache is only reused during one translation across multiple routes and gateways.
// The failed fetches will be retried in the next translation when the provider resources are reconciled again.
t.oidcDiscoveryCache = newOIDCDiscoveryCache()

// SecurityPolicies are already sorted by the provider layer

// First build a map out of the routes and gateways for faster lookup since users might have thousands of routes or more.
Expand Down Expand Up @@ -1419,7 +1424,7 @@ func (t *Translator) buildOIDCProvider(policy *egv1a1.SecurityPolicy, resources
// EG assumes that the issuer url uses the same protocol and CA as the token endpoint.
// If we need to support different protocols or CAs, we need to add more fields to the OIDCProvider CRD.
if provider.TokenEndpoint == nil || provider.AuthorizationEndpoint == nil {
discoveredConfig, err := fetchEndpointsFromIssuer(provider.Issuer, providerTLS)
discoveredConfig, err := t.fetchEndpointsFromIssuer(provider.Issuer, providerTLS)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1507,7 +1512,25 @@ func (o *OpenIDConfig) validate() error {
return nil
}

func fetchEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfig) (*OpenIDConfig, error) {
func (t *Translator) fetchEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfig) (*OpenIDConfig, error) {
if config, cachedErr, ok := t.oidcDiscoveryCache.Get(issuerURL); ok {
if cachedErr != nil {
return nil, cachedErr
}
return config, nil
}

config, err := discoverEndpointsFromIssuer(issuerURL, providerTLS)
if err != nil {
t.oidcDiscoveryCache.Set(issuerURL, nil, err)
return nil, err
}

t.oidcDiscoveryCache.Set(issuerURL, config, nil)
return config, nil
}

func discoverEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfig) (*OpenIDConfig, error) {
var (
tlsConfig *tls.Config
err error
Expand All @@ -1519,7 +1542,7 @@ func fetchEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfi
}
}

client := &http.Client{}
client := &http.Client{Timeout: defaultOIDCHTTPTimeout}
if tlsConfig != nil {
client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
Expand Down Expand Up @@ -1564,6 +1587,47 @@ func fetchEndpointsFromIssuer(issuerURL string, providerTLS *ir.TLSUpstreamConfi
return &config, nil
}

// oidcDiscoveryCache is a cache for auto-discovered OIDC configurations from the issuer's well-known URL.
// The cache is only used within the current translation, so no need to lock it or expire entries.
type oidcDiscoveryCache struct {
entries map[string]cachedOIDCEntry
}

type cachedOIDCEntry struct {
config *OpenIDConfig
err error
}

func newOIDCDiscoveryCache() *oidcDiscoveryCache {
return &oidcDiscoveryCache{
entries: make(map[string]cachedOIDCEntry),
}
}

func (c *oidcDiscoveryCache) Get(issuer string) (*OpenIDConfig, error, bool) {
if c == nil {
return nil, nil, false
}

entry, ok := c.entries[issuer]
if !ok {
return nil, nil, false
}

return entry.config, entry.err, true
}

func (c *oidcDiscoveryCache) Set(issuer string, cfg *OpenIDConfig, err error) {
if c == nil {
return
}

c.entries[issuer] = cachedOIDCEntry{
config: cfg,
err: err,
}
}

func retryable(code int) bool {
return code >= 500 &&
(code != http.StatusNotImplemented &&
Expand Down
77 changes: 77 additions & 0 deletions internal/gatewayapi/securitypolicy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
package gatewayapi

import (
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -780,6 +784,79 @@ func TestValidateCIDRs_ErrorOnBadCIDR(t *testing.T) {
}
}

func TestTranslatorFetchEndpointsFromIssuerCache(t *testing.T) {
var (
callCount atomic.Int32
server *httptest.Server
)

server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}

callCount.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = fmt.Fprintf(w, `{"token_endpoint":%q,"authorization_endpoint":%q}`, server.URL+"/token", server.URL+"/authorize")
}))
defer server.Close()

tr := &Translator{GatewayControllerName: "gateway.envoyproxy.io/gatewayclass-controller"}
tr.oidcDiscoveryCache = newOIDCDiscoveryCache()

cfg, err := tr.fetchEndpointsFromIssuer(server.URL, nil)
require.NoError(t, err)
require.NotNil(t, cfg)
require.Equal(t, int32(1), callCount.Load())

cfgCached, err := tr.fetchEndpointsFromIssuer(server.URL, nil)
require.NoError(t, err)
require.NotNil(t, cfgCached)
require.Equal(t, int32(1), callCount.Load(), "second fetch should use cache")

cfgAgain, err := tr.fetchEndpointsFromIssuer(server.URL, nil)
require.NoError(t, err)
require.NotNil(t, cfgAgain)
require.Equal(t, int32(1), callCount.Load(), "subsequent fetch should continue using cache")
}

func TestTranslatorFetchEndpointsFromIssuerCacheError(t *testing.T) {
var (
callCount atomic.Int32
server *httptest.Server
)

server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}

callCount.Add(1)
http.NotFound(w, r)
}))
defer server.Close()

tr := &Translator{GatewayControllerName: "gateway.envoyproxy.io/gatewayclass-controller"}
tr.oidcDiscoveryCache = newOIDCDiscoveryCache()

cfg, err := tr.fetchEndpointsFromIssuer(server.URL, nil)
require.Error(t, err)
require.Nil(t, cfg)
require.Equal(t, int32(1), callCount.Load())

cfgCached, err := tr.fetchEndpointsFromIssuer(server.URL, nil)
require.Error(t, err)
require.Nil(t, cfgCached)
require.Equal(t, int32(1), callCount.Load(), "second fetch should use cached error")

cfgAfter, err := tr.fetchEndpointsFromIssuer(server.URL, nil)
require.Error(t, err)
require.Nil(t, cfgAfter)
require.Equal(t, int32(1), callCount.Load(), "subsequent fetch should continue using cached error")
}

// / tiny helper to build a minimal SecurityPolicy
func sp(ns, name string) *egv1a1.SecurityPolicy {
return &egv1a1.SecurityPolicy{
Expand Down
3 changes: 3 additions & 0 deletions internal/gatewayapi/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ type Translator struct {
// and reuses the specified value.
ListenerPortShiftDisabled bool

// oidcDiscoveryCache is the cache for OIDC configurations discovered from issuer's well-known URL.
oidcDiscoveryCache *oidcDiscoveryCache

// Logger is the logger used by the translator.
Logger logging.Logger
}
Expand Down
Loading