diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index d7e4f0fa05..9bc7cff7c5 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -19,6 +19,7 @@ package main import ( "context" "flag" + "io/ioutil" "net/http" "os" "os/signal" @@ -30,6 +31,7 @@ import ( "github.com/cortexlabs/cortex/pkg/lib/logging" "github.com/cortexlabs/cortex/pkg/lib/telemetry" "github.com/cortexlabs/cortex/pkg/proxy" + "github.com/cortexlabs/cortex/pkg/proxy/probe" "github.com/cortexlabs/cortex/pkg/types/clusterconfig" "github.com/cortexlabs/cortex/pkg/types/userconfig" "go.uber.org/zap" @@ -40,45 +42,24 @@ const ( _requestSampleInterval = 1 * time.Second ) -var ( - proxyLogger = logging.GetLogger() -) - -func Exit(err error, wrapStrs ...string) { - for _, str := range wrapStrs { - err = errors.Wrap(err, str) - } - - if err != nil && !errors.IsNoTelemetry(err) { - telemetry.Error(err) - } - - if err != nil && !errors.IsNoPrint(err) { - proxyLogger.Error(err) - } - - telemetry.Close() - - os.Exit(1) -} - func main() { var ( port int - metricsPort int + adminPort int userContainerPort int maxConcurrency int maxQueueLength int + probeDefPath string clusterConfigPath string ) - flag.IntVar(&port, "port", 8888, "port where the proxy is served") - flag.IntVar(&metricsPort, "metrics-port", 15000, "metrics port for prometheus") - flag.IntVar(&userContainerPort, "user-port", 8080, "port where the proxy redirects to the traffic to") + flag.IntVar(&port, "port", 8000, "port where the proxy server will be exposed") + flag.IntVar(&adminPort, "admin-port", 15000, "port where the admin server (for metrics and probes) will be exposed") + flag.IntVar(&userContainerPort, "user-port", 8080, "port where the proxy will redirect to the traffic to") flag.IntVar(&maxConcurrency, "max-concurrency", 0, "max concurrency allowed for user container") flag.IntVar(&maxQueueLength, "max-queue-length", 0, "max request queue length for user container") flag.StringVar(&clusterConfigPath, "cluster-config", "", "cluster config path") - + flag.StringVar(&probeDefPath, "probe", "", "path to the desired probe json definition") flag.Parse() log := logging.GetLogger() @@ -88,26 +69,26 @@ func main() { switch { case maxConcurrency == 0: - log.Fatal("-max-concurrency flag is required") + log.Fatal("--max-concurrency flag is required") case maxQueueLength == 0: - log.Fatal("-max-queue-length flag is required") + log.Fatal("--max-queue-length flag is required") case clusterConfigPath == "": - log.Fatal("-cluster-config flag is required") + log.Fatal("--cluster-config flag is required") } clusterConfig, err := clusterconfig.NewForFile(clusterConfigPath) if err != nil { - Exit(err) + exit(log, err) } awsClient, err := aws.NewForRegion(clusterConfig.Region) if err != nil { - Exit(err) + exit(log, err) } _, userID, err := awsClient.CheckCredentials() if err != nil { - Exit(err) + exit(log, err) } err = telemetry.Init(telemetry.Config{ @@ -122,7 +103,7 @@ func main() { BackoffMode: telemetry.BackoffDuplicateMessages, }) if err != nil { - Exit(err) + exit(log, err) } target := "http://127.0.0.1:" + strconv.Itoa(userContainerPort) @@ -139,6 +120,23 @@ func main() { promStats := proxy.NewPrometheusStatsReporter() + var readinessProbe *probe.Probe + if probeDefPath != "" { + jsonProbe, err := ioutil.ReadFile(probeDefPath) + if err != nil { + log.Fatal(err) + } + + probeDef, err := probe.DecodeJSON(string(jsonProbe)) + if err != nil { + log.Fatal(err) + } + + readinessProbe = probe.NewProbe(probeDef, log) + } else { + readinessProbe = probe.NewDefaultProbe(target, log) + } + go func() { reportTicker := time.NewTicker(_reportInterval) defer reportTicker.Stop() @@ -161,14 +159,18 @@ func main() { } }() + adminHandler := http.NewServeMux() + adminHandler.Handle("/metrics", promStats) + adminHandler.Handle("/healthz", probe.Handler(readinessProbe)) + servers := map[string]*http.Server{ "proxy": { - Addr: ":" + strconv.Itoa(port), + Addr: ":" + strconv.Itoa(userContainerPort), Handler: proxy.Handler(breaker, httpProxy), }, - "metrics": { - Addr: ":" + strconv.Itoa(metricsPort), - Handler: promStats, + "admin": { + Addr: ":" + strconv.Itoa(adminPort), + Handler: adminHandler, }, } @@ -184,8 +186,8 @@ func main() { signal.Notify(sigint, os.Interrupt) select { - case err := <-errCh: - Exit(errors.Wrap(err, "failed to start proxy server")) + case err = <-errCh: + exit(log, errors.Wrap(err, "failed to start proxy server")) case <-sigint: // We received an interrupt signal, shut down. log.Info("Received TERM signal, handling a graceful shutdown...") @@ -202,3 +204,20 @@ func main() { telemetry.Close() } } + +func exit(log *zap.SugaredLogger, err error, wrapStrs ...string) { + for _, str := range wrapStrs { + err = errors.Wrap(err, str) + } + + if err != nil && !errors.IsNoTelemetry(err) { + telemetry.Error(err) + } + + if err != nil && !errors.IsNoPrint(err) { + log.Error(err) + } + + telemetry.Close() + os.Exit(1) +} diff --git a/pkg/proxy/consts.go b/pkg/proxy/consts.go index 95c415378c..67bb86f7fc 100644 --- a/pkg/proxy/consts.go +++ b/pkg/proxy/consts.go @@ -17,14 +17,11 @@ limitations under the License. package proxy const ( - _userAgentKey = "User-Agent" + // UserAgentKey is the user agent header key + UserAgentKey = "User-Agent" + // KubeProbeUserAgentPrefix is the user agent header prefix used in k8s probes // Since K8s 1.8, prober requests have // User-Agent = "kube-probe/{major-version}.{minor-version}". - _kubeProbeUserAgentPrefix = "kube-probe/" - - // KubeletProbeHeaderName is the header name to augment the probes, because - // Istio with mTLS rewrites probes, but their probes pass a different - // user-agent. - _kubeletProbeHeaderName = "K-Kubelet-Probe" + KubeProbeUserAgentPrefix = "kube-probe/" ) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index ed59616f99..39ba5f0b6f 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -43,6 +43,5 @@ func Handler(breaker *Breaker, next http.Handler) http.HandlerFunc { } func isKubeletProbe(r *http.Request) bool { - return strings.HasPrefix(r.Header.Get(_userAgentKey), _kubeProbeUserAgentPrefix) || - r.Header.Get(_kubeletProbeHeaderName) != "" + return strings.HasPrefix(r.Header.Get(UserAgentKey), KubeProbeUserAgentPrefix) } diff --git a/pkg/proxy/probe/encoding.go b/pkg/proxy/probe/encoding.go new file mode 100644 index 0000000000..38363d1dc4 --- /dev/null +++ b/pkg/proxy/probe/encoding.go @@ -0,0 +1,46 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package probe + +import ( + "encoding/json" + "errors" + + kcore "k8s.io/api/core/v1" +) + +// DecodeJSON takes a json serialised *kcore.Probe and returns a Probe or an error. +func DecodeJSON(jsonProbe string) (*kcore.Probe, error) { + pb := &kcore.Probe{} + if err := json.Unmarshal([]byte(jsonProbe), pb); err != nil { + return nil, err + } + return pb, nil +} + +// EncodeJSON takes *kcore.Probe object and returns marshalled Probe JSON string and an error. +func EncodeJSON(pb *kcore.Probe) (string, error) { + if pb == nil { + return "", errors.New("cannot encode nil probe") + } + + probeJSON, err := json.Marshal(pb) + if err != nil { + return "", err + } + return string(probeJSON), nil +} diff --git a/pkg/proxy/probe/encoding_test.go b/pkg/proxy/probe/encoding_test.go new file mode 100644 index 0000000000..e48aa62165 --- /dev/null +++ b/pkg/proxy/probe/encoding_test.go @@ -0,0 +1,90 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package probe_test + +import ( + "encoding/json" + "testing" + + "github.com/cortexlabs/cortex/pkg/proxy/probe" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + kcore "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/intstr" +) + +func TestDecodeProbeSuccess(t *testing.T) { + t.Parallel() + + expectedProbe := &kcore.Probe{ + PeriodSeconds: 1, + TimeoutSeconds: 2, + SuccessThreshold: 1, + FailureThreshold: 1, + Handler: kcore.Handler{ + TCPSocket: &kcore.TCPSocketAction{ + Host: "127.0.0.1", + Port: intstr.FromString("8080"), + }, + }, + } + probeBytes, err := json.Marshal(expectedProbe) + require.NoError(t, err) + + gotProbe, err := probe.DecodeJSON(string(probeBytes)) + require.NoError(t, err) + + require.Equal(t, expectedProbe, gotProbe) +} + +func TestDecodeProbeFailure(t *testing.T) { + t.Parallel() + + probeBytes, err := json.Marshal("blah") + require.NoError(t, err) + + _, err = probe.DecodeJSON(string(probeBytes)) + require.Error(t, err) +} + +func TestEncodeProbe(t *testing.T) { + t.Parallel() + + pb := &kcore.Probe{ + SuccessThreshold: 1, + Handler: kcore.Handler{ + TCPSocket: &kcore.TCPSocketAction{ + Host: "127.0.0.1", + Port: intstr.FromString("8080"), + }, + }, + } + + jsonProbe, err := probe.EncodeJSON(pb) + require.NoError(t, err) + + wantProbe := `{"tcpSocket":{"port":"8080","host":"127.0.0.1"},"successThreshold":1}` + require.Equal(t, wantProbe, jsonProbe) +} + +func TestEncodeNilProbe(t *testing.T) { + t.Parallel() + + jsonProbe, err := probe.EncodeJSON(nil) + assert.Error(t, err) + assert.Empty(t, jsonProbe) +} diff --git a/pkg/proxy/probe/handler.go b/pkg/proxy/probe/handler.go new file mode 100644 index 0000000000..55c89b5c58 --- /dev/null +++ b/pkg/proxy/probe/handler.go @@ -0,0 +1,33 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package probe + +import "net/http" + +func Handler(pb *Probe) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + healthy := pb.ProbeContainer() + if !healthy { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("unhealthy")) + return + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("healthy")) + } +} diff --git a/pkg/proxy/probe/handler_test.go b/pkg/proxy/probe/handler_test.go new file mode 100644 index 0000000000..da2db383f9 --- /dev/null +++ b/pkg/proxy/probe/handler_test.go @@ -0,0 +1,117 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package probe_test + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/cortexlabs/cortex/pkg/proxy" + "github.com/cortexlabs/cortex/pkg/proxy/probe" + "github.com/stretchr/testify/require" + kcore "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/intstr" +) + +func TestHandlerFailure(t *testing.T) { + t.Parallel() + log := newLogger(t) + + pb := probe.NewDefaultProbe("http://127.0.0.1:12345", log) + handler := probe.Handler(pb) + + r := httptest.NewRequest(http.MethodGet, "http://fake.cortex.dev/healthz", nil) + w := httptest.NewRecorder() + + handler(w, r) + + require.Equal(t, http.StatusInternalServerError, w.Code) + require.Equal(t, "unhealthy", w.Body.String()) +} + +func TestHandlerSuccessTCP(t *testing.T) { + t.Parallel() + log := newLogger(t) + + var userHandler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + server := httptest.NewServer(userHandler) + + pb := probe.NewDefaultProbe(server.URL, log) + handler := probe.Handler(pb) + + r := httptest.NewRequest(http.MethodGet, "http://fake.cortex.dev/healthz", nil) + w := httptest.NewRecorder() + + handler(w, r) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, "healthy", w.Body.String()) +} + +func TestHandlerSuccessHTTP(t *testing.T) { + t.Parallel() + log := newLogger(t) + + headers := []kcore.HTTPHeader{ + { + Name: "X-Cortex-Blah", + Value: "Blah", + }, + } + + var userHandler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + require.Contains(t, r.Header.Get(proxy.UserAgentKey), proxy.KubeProbeUserAgentPrefix) + for _, header := range headers { + require.Equal(t, header.Value, r.Header.Get(header.Name)) + } + + w.WriteHeader(http.StatusOK) + } + server := httptest.NewServer(userHandler) + targetURL, err := url.Parse(server.URL) + require.NoError(t, err) + + pb := probe.NewProbe( + &kcore.Probe{ + Handler: kcore.Handler{ + HTTPGet: &kcore.HTTPGetAction{ + Path: "/", + Port: intstr.FromString(targetURL.Port()), + Host: targetURL.Hostname(), + HTTPHeaders: headers, + }, + }, + TimeoutSeconds: 3, + PeriodSeconds: 1, + SuccessThreshold: 1, + FailureThreshold: 3, + }, log, + ) + handler := probe.Handler(pb) + + r := httptest.NewRequest(http.MethodGet, "http://fake.cortex.dev/healthz", nil) + w := httptest.NewRecorder() + + handler(w, r) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, "healthy", w.Body.String()) +} diff --git a/pkg/proxy/probe/probe.go b/pkg/proxy/probe/probe.go new file mode 100644 index 0000000000..3eeee7a9f9 --- /dev/null +++ b/pkg/proxy/probe/probe.go @@ -0,0 +1,142 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package probe + +import ( + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "time" + + s "github.com/cortexlabs/cortex/pkg/lib/strings" + "github.com/cortexlabs/cortex/pkg/proxy" + "go.uber.org/zap" + kcore "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/intstr" +) + +const ( + _defaultTimeoutSeconds = 1 +) + +type Probe struct { + *kcore.Probe + logger *zap.SugaredLogger +} + +func NewProbe(probe *kcore.Probe, logger *zap.SugaredLogger) *Probe { + return &Probe{ + Probe: probe, + logger: logger, + } +} + +func NewDefaultProbe(target string, logger *zap.SugaredLogger) *Probe { + targetURL, err := url.Parse(target) + if err != nil { + panic(fmt.Sprintf("failed to parse target URL: %v", err)) + } + + return &Probe{ + Probe: &kcore.Probe{ + Handler: kcore.Handler{ + TCPSocket: &kcore.TCPSocketAction{ + Port: intstr.FromString(targetURL.Port()), + Host: targetURL.Hostname(), + }, + }, + TimeoutSeconds: _defaultTimeoutSeconds, + }, + logger: logger, + } +} + +func (p *Probe) ProbeContainer() bool { + var err error + + switch { + case p.HTTPGet != nil: + err = p.httpProbe() + case p.TCPSocket != nil: + err = p.tcpProbe() + case p.Exec != nil: + // Should never be reachable. + p.logger.Error("exec probe not supported") + return false + default: + p.logger.Warn("no probe found") + return false + } + + if err != nil { + p.logger.Warn(err) + return false + } + return true +} + +func (p *Probe) httpProbe() error { + targetURL := s.EnsurePrefix( + net.JoinHostPort(p.HTTPGet.Host, p.HTTPGet.Port.String())+s.EnsurePrefix(p.HTTPGet.Path, "/"), + "http://", + ) + + httpClient := &http.Client{} + req, err := http.NewRequest(http.MethodGet, targetURL, nil) + if err != nil { + return err + } + + req.Header.Add(proxy.UserAgentKey, proxy.KubeProbeUserAgentPrefix) + + for _, header := range p.HTTPGet.HTTPHeaders { + req.Header.Add(header.Name, header.Value) + } + + res, err := httpClient.Do(req) + if err != nil { + return err + } + + defer func() { + // Ensure body is both read _and_ closed so it can be reused for keep-alive. + // No point handling errors, connection just won't be reused. + _, _ = io.Copy(ioutil.Discard, res.Body) + _ = res.Body.Close() + }() + + // response status code between 200-399 indicates success + if !(res.StatusCode >= 200 && res.StatusCode < 400) { + return fmt.Errorf("HTTP probe did not respond Ready, got status code: %d", res.StatusCode) + } + + return nil +} + +func (p *Probe) tcpProbe() error { + timeout := time.Duration(p.TimeoutSeconds) * time.Second + address := net.JoinHostPort(p.TCPSocket.Host, p.TCPSocket.Port.String()) + conn, err := net.DialTimeout("tcp", address, timeout) + if err != nil { + return err + } + _ = conn.Close() + return nil +} diff --git a/pkg/proxy/probe/probe_test.go b/pkg/proxy/probe/probe_test.go new file mode 100644 index 0000000000..b0518396f3 --- /dev/null +++ b/pkg/proxy/probe/probe_test.go @@ -0,0 +1,113 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package probe_test + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/cortexlabs/cortex/pkg/proxy/probe" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + kcore "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/intstr" +) + +func newLogger(t *testing.T) *zap.SugaredLogger { + t.Helper() + + config := zap.NewDevelopmentConfig() + config.Level = zap.NewAtomicLevelAt(zap.FatalLevel) + logger, err := config.Build() + require.NoError(t, err) + + log := logger.Sugar() + + return log +} + +func TestDefaultProbeSuccess(t *testing.T) { + t.Parallel() + log := newLogger(t) + + var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + server := httptest.NewServer(handler) + pb := probe.NewDefaultProbe(server.URL, log) + + require.True(t, pb.ProbeContainer()) +} + +func TestDefaultProbeFailure(t *testing.T) { + t.Parallel() + log := newLogger(t) + + target := "http://127.0.0.1:12345" + pb := probe.NewDefaultProbe(target, log) + + require.False(t, pb.ProbeContainer()) +} + +func TestProbeHTTPFailure(t *testing.T) { + t.Parallel() + log := newLogger(t) + + pb := probe.NewProbe( + &kcore.Probe{ + Handler: kcore.Handler{ + HTTPGet: &kcore.HTTPGetAction{ + Path: "/healthz", + Port: intstr.FromString("12345"), + Host: "127.0.0.1", + }, + }, + TimeoutSeconds: 3, + }, log, + ) + + require.False(t, pb.ProbeContainer()) +} + +func TestProbeHTTPSuccess(t *testing.T) { + t.Parallel() + log := newLogger(t) + + var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + server := httptest.NewServer(handler) + targetURL, err := url.Parse(server.URL) + require.NoError(t, err) + + pb := probe.NewProbe( + &kcore.Probe{ + Handler: kcore.Handler{ + HTTPGet: &kcore.HTTPGetAction{ + Path: "/healthz", + Port: intstr.FromString(targetURL.Port()), + Host: targetURL.Hostname(), + }, + }, + TimeoutSeconds: 3, + }, log, + ) + + require.True(t, pb.ProbeContainer()) +}