diff --git a/.circleci/config.yml b/.circleci/config.yml index 93c51130ac..b2a3748456 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -146,15 +146,27 @@ jobs: echo 'export AWS_SECRET_ACCESS_KEY=${NIGHTLY_AWS_SECRET_ACCESS_KEY}' >> $BASH_ENV - run: name: Generate Cluster Config + # using a variety of node groups to test the multi-instance-type cluster functionality command: | cat \<< EOF > ./cluster.yaml - cluster_name: cortex provider: aws + cluster_name: cortex region: us-east-1 - instance_type: g4dn.xlarge - min_instances: 1 - max_instances: 2 bucket: cortex-dev-nightly + node_groups: + - name: spot + instance_type: t3.medium + min_instances: 0 + max_instances: 1 + spot: true + - name: cpu + instance_type: c5.xlarge + min_instances: 1 + max_instances: 2 + - name: gpu + instance_type: g4dn.xlarge + min_instances: 1 + max_instances: 2 EOF - run-e2e-tests: provider: aws @@ -174,16 +186,28 @@ jobs: echo 'export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/google_service_account.json' >> $BASH_ENV - run: name: Generate Cluster Config + # using a variety of node pools to test the multi-instance-type cluster functionality command: | cat \<< EOF > ./cluster.yaml + provider: gcp cluster_name: cortex project: cortexlabs-dev zone: us-east1-c - provider: gcp - instance_type: n1-standard-2 - accelerator_type: nvidia-tesla-t4 - min_instances: 1 - max_instances: 2 + node_pools: + - name: preemptible + instance_type: n1-standard-2 + min_instances: 0 + max_instances: 1 + preemptible: true + - name: cpu + instance_type: n1-standard-2 + min_instances: 1 + max_instances: 2 + - name: gpu + instance_type: n1-standard-2 + accelerator_type: nvidia-tesla-t4 + min_instances: 1 + max_instances: 2 EOF - run-e2e-tests: provider: gcp diff --git a/cli/cmd/cluster.go b/cli/cmd/cluster.go index 3230265a2e..cab21e17f7 100644 --- a/cli/cmd/cluster.go +++ b/cli/cmd/cluster.go @@ -302,7 +302,7 @@ var _clusterConfigureCmd = &cobra.Command{ exit.Error(err) } - accessConfig, err := 
getClusterAccessConfigWithCache() + accessConfig, err := getNewClusterAccessConfig(clusterConfigFile) if err != nil { exit.Error(err) } @@ -317,7 +317,7 @@ var _clusterConfigureCmd = &cobra.Command{ exit.Error(err) } - err = clusterstate.AssertClusterStatus(accessConfig.ClusterName, accessConfig.Region, clusterState.Status, clusterstate.StatusCreateComplete) + err = clusterstate.AssertClusterStatus(accessConfig.ClusterName, accessConfig.Region, clusterState.Status, clusterstate.StatusCreateComplete, clusterstate.StatusUpdateComplete, clusterstate.StatusUpdateRollbackComplete) if err != nil { exit.Error(err) } @@ -527,7 +527,7 @@ var _clusterExportCmd = &cobra.Command{ exit.Error(err) } - err = clusterstate.AssertClusterStatus(accessConfig.ClusterName, accessConfig.Region, clusterState.Status, clusterstate.StatusCreateComplete) + err = clusterstate.AssertClusterStatus(accessConfig.ClusterName, accessConfig.Region, clusterState.Status, clusterstate.StatusCreateComplete, clusterstate.StatusUpdateComplete, clusterstate.StatusUpdateRollbackComplete) if err != nil { exit.Error(err) } @@ -668,7 +668,7 @@ func printInfoClusterState(awsClient *aws.Client, accessConfig *clusterconfig.Ac fmt.Println() } - err = clusterstate.AssertClusterStatus(accessConfig.ClusterName, accessConfig.Region, clusterState.Status, clusterstate.StatusCreateComplete) + err = clusterstate.AssertClusterStatus(accessConfig.ClusterName, accessConfig.Region, clusterState.Status, clusterstate.StatusCreateComplete, clusterstate.StatusUpdateComplete, clusterstate.StatusUpdateRollbackComplete) if err != nil { return err } @@ -679,6 +679,12 @@ func printInfoClusterState(awsClient *aws.Client, accessConfig *clusterconfig.Ac func printInfoOperatorResponse(clusterConfig clusterconfig.Config, operatorEndpoint string) error { fmt.Print("fetching cluster status ...\n\n") + yamlBytes, err := yaml.Marshal(clusterConfig) + if err != nil { + return err + } + yamlString := string(yamlBytes) + operatorConfig := 
cluster.OperatorConfig{ Telemetry: isTelemetryEnabled(), ClientID: clientID(), @@ -688,42 +694,67 @@ func printInfoOperatorResponse(clusterConfig clusterconfig.Config, operatorEndpo infoResponse, err := cluster.Info(operatorConfig) if err != nil { - fmt.Println(clusterConfig.UserStr()) + fmt.Println(yamlString) return err } infoResponse.ClusterConfig.Config = clusterConfig - printInfoClusterConfig(infoResponse) + fmt.Println(console.Bold("metadata:")) + fmt.Println(fmt.Sprintf("aws access key id: %s", infoResponse.MaskedAWSAccessKeyID)) + fmt.Println(fmt.Sprintf("%s: %s", clusterconfig.APIVersionUserKey, infoResponse.ClusterConfig.APIVersion)) + + fmt.Println() + fmt.Println(console.Bold("cluster config:")) + fmt.Print(yamlString) + printInfoPricing(infoResponse, clusterConfig) printInfoNodes(infoResponse) return nil } -func printInfoClusterConfig(infoResponse *schema.InfoResponse) { - var items table.KeyValuePairs - items.Add("aws access key id", infoResponse.MaskedAWSAccessKeyID) - items.AddAll(infoResponse.ClusterConfig.UserTable()) - items.Print() -} - func printInfoPricing(infoResponse *schema.InfoResponse, clusterConfig clusterconfig.Config) { - numAPIInstances := len(infoResponse.NodeInfos) - - var totalAPIInstancePrice float64 - for _, nodeInfo := range infoResponse.NodeInfos { - totalAPIInstancePrice += nodeInfo.Price - } - eksPrice := aws.EKSPrices[clusterConfig.Region] operatorInstancePrice := aws.InstanceMetadatas[clusterConfig.Region]["t3.medium"].Price operatorEBSPrice := aws.EBSMetadatas[clusterConfig.Region]["gp2"].PriceGB * 20 / 30 / 24 metricsEBSPrice := aws.EBSMetadatas[clusterConfig.Region]["gp2"].PriceGB * 40 / 30 / 24 nlbPrice := aws.NLBMetadatas[clusterConfig.Region].Price natUnitPrice := aws.NATMetadatas[clusterConfig.Region].Price - apiEBSPrice := aws.EBSMetadatas[clusterConfig.Region][clusterConfig.InstanceVolumeType.String()].PriceGB * float64(clusterConfig.InstanceVolumeSize) / 30 / 24 - if clusterConfig.InstanceVolumeType.String() == 
"io1" && clusterConfig.InstanceVolumeIOPS != nil { - apiEBSPrice += aws.EBSMetadatas[clusterConfig.Region][clusterConfig.InstanceVolumeType.String()].PriceIOPS * float64(*clusterConfig.InstanceVolumeIOPS) / 30 / 24 + + headers := []table.Header{ + {Title: "aws resource"}, + {Title: "cost per hour"}, + } + + var rows [][]interface{} + rows = append(rows, []interface{}{"1 eks cluster", s.DollarsMaxPrecision(eksPrice)}) + + var totalNodeGroupsPrice float64 + for _, ng := range clusterConfig.NodeGroups { + var ngNamePrefix string + if ng.Spot { + ngNamePrefix = "cx-ws-" + } else { + ngNamePrefix = "cx-wd-" + } + nodesInfo := infoResponse.GetNodesWithNodeGroupName(ngNamePrefix + ng.Name) + numInstances := len(nodesInfo) + + ebsPrice := aws.EBSMetadatas[clusterConfig.Region][ng.InstanceVolumeType.String()].PriceGB * float64(ng.InstanceVolumeSize) / 30 / 24 + if ng.InstanceVolumeType.String() == "io1" && ng.InstanceVolumeIOPS != nil { + ebsPrice += aws.EBSMetadatas[clusterConfig.Region][ng.InstanceVolumeType.String()].PriceIOPS * float64(*ng.InstanceVolumeIOPS) / 30 / 24 + } + totalEBSPrice := ebsPrice * float64(numInstances) + + totalInstancePrice := float64(0) + for _, nodeInfo := range nodesInfo { + totalInstancePrice += nodeInfo.Price + } + + rows = append(rows, []interface{}{fmt.Sprintf("nodegroup %s: %d (out of %d) %s for your apis", ng.Name, numInstances, ng.MaxInstances, s.PluralS("instance", numInstances)), s.DollarsAndTenthsOfCents(totalInstancePrice) + " total"}) + rows = append(rows, []interface{}{fmt.Sprintf("nodegroup %s: %d (out of %d) %dgb ebs %s for your apis", ng.Name, numInstances, ng.MaxInstances, ng.InstanceVolumeSize, s.PluralS("volume", numInstances)), s.DollarsAndTenthsOfCents(totalEBSPrice) + " total"}) + + totalNodeGroupsPrice += totalEBSPrice + totalInstancePrice } var natTotalPrice float64 @@ -732,20 +763,9 @@ func printInfoPricing(infoResponse *schema.InfoResponse, clusterConfig clusterco } else if clusterConfig.NATGateway == 
clusterconfig.HighlyAvailableNATGateway { natTotalPrice = natUnitPrice * float64(len(clusterConfig.AvailabilityZones)) } - - totalPrice := eksPrice + totalAPIInstancePrice + apiEBSPrice*float64(numAPIInstances) + - operatorInstancePrice*2 + operatorEBSPrice + metricsEBSPrice + nlbPrice*2 + natTotalPrice + totalPrice := eksPrice + totalNodeGroupsPrice + operatorInstancePrice*2 + operatorEBSPrice + metricsEBSPrice + nlbPrice*2 + natTotalPrice fmt.Printf(console.Bold("\nyour cluster currently costs %s per hour\n\n"), s.DollarsAndCents(totalPrice)) - headers := []table.Header{ - {Title: "aws resource"}, - {Title: "cost per hour"}, - } - - var rows [][]interface{} - rows = append(rows, []interface{}{"1 eks cluster", s.DollarsMaxPrecision(eksPrice)}) - rows = append(rows, []interface{}{fmt.Sprintf("%d %s for your apis", numAPIInstances, s.PluralS("instance", numAPIInstances)), s.DollarsAndTenthsOfCents(totalAPIInstancePrice) + " total"}) - rows = append(rows, []interface{}{fmt.Sprintf("%d %dgb ebs %s for your apis", numAPIInstances, clusterConfig.InstanceVolumeSize, s.PluralS("volume", numAPIInstances)), s.DollarsAndTenthsOfCents(apiEBSPrice*float64(numAPIInstances)) + " total"}) rows = append(rows, []interface{}{"2 t3.medium instances for cortex", s.DollarsMaxPrecision(operatorInstancePrice * 2)}) rows = append(rows, []interface{}{"1 20gb ebs volume for the operator", s.DollarsAndTenthsOfCents(operatorEBSPrice)}) rows = append(rows, []interface{}{"1 40gb ebs volume for prometheus", s.DollarsAndTenthsOfCents(metricsEBSPrice)}) diff --git a/cli/cmd/cluster_gcp.go b/cli/cmd/cluster_gcp.go index e6db19ff7e..3c549ff82c 100644 --- a/cli/cmd/cluster_gcp.go +++ b/cli/cmd/cluster_gcp.go @@ -35,6 +35,7 @@ import ( "github.com/cortexlabs/cortex/pkg/lib/telemetry" "github.com/cortexlabs/cortex/pkg/types" "github.com/cortexlabs/cortex/pkg/types/clusterconfig" + "github.com/cortexlabs/yaml" "github.com/spf13/cobra" containerpb "google.golang.org/genproto/googleapis/container/v1" ) @@ 
-373,7 +374,18 @@ func printInfoOperatorResponseGCP(accessConfig *clusterconfig.GCPAccessConfig, o return err } - infoResponse.ClusterConfig.UserTable().Print() + yamlBytes, err := yaml.Marshal(infoResponse.ClusterConfig.GCPConfig) + if err != nil { + return err + } + yamlString := string(yamlBytes) + + fmt.Println(console.Bold("metadata:")) + fmt.Println(fmt.Sprintf("%s: %s", clusterconfig.APIVersionUserKey, infoResponse.ClusterConfig.APIVersion)) + + fmt.Println() + fmt.Println(console.Bold("cluster config:")) + fmt.Print(yamlString) return nil } @@ -448,25 +460,9 @@ func updateGCPCLIEnv(envName string, operatorEndpoint string, disallowPrompt boo func createGKECluster(clusterConfig *clusterconfig.GCPConfig, gcpClient *gcp.Client) error { fmt.Print("○ creating GKE cluster ") - nodeLabels := map[string]string{"workload": "true"} - var accelerators []*containerpb.AcceleratorConfig - - if clusterConfig.AcceleratorType != nil { - accelerators = append(accelerators, &containerpb.AcceleratorConfig{ - AcceleratorCount: *clusterConfig.AcceleratorsPerInstance, - AcceleratorType: *clusterConfig.AcceleratorType, - }) - nodeLabels["nvidia.com/gpu"] = "present" - } - gkeClusterParent := fmt.Sprintf("projects/%s/locations/%s", clusterConfig.Project, clusterConfig.Zone) gkeClusterName := fmt.Sprintf("%s/clusters/%s", gkeClusterParent, clusterConfig.ClusterName) - initialNodeCount := int64(1) - if clusterConfig.MinInstances > 0 { - initialNodeCount = clusterConfig.MinInstances - } - gkeClusterConfig := containerpb.Cluster{ Name: clusterConfig.ClusterName, InitialClusterVersion: "1.18", @@ -488,52 +484,68 @@ func createGKECluster(clusterConfig *clusterconfig.GCPConfig, gcpClient *gcp.Cli Locations: []string{clusterConfig.Zone}, } - if clusterConfig.Preemptible { - gkeClusterConfig.NodePools = append(gkeClusterConfig.NodePools, &containerpb.NodePool{ - Name: "ng-cortex-wk-preemp", - Config: &containerpb.NodeConfig{ - MachineType: clusterConfig.InstanceType, - Labels: nodeLabels, - 
Taints: []*containerpb.NodeTaint{ - { - Key: "workload", - Value: "true", - Effect: containerpb.NodeTaint_NO_SCHEDULE, + for _, nodePool := range clusterConfig.NodePools { + nodeLabels := map[string]string{"workload": "true"} + initialNodeCount := int64(1) + if nodePool.MinInstances > 0 { + initialNodeCount = nodePool.MinInstances + } + + var accelerators []*containerpb.AcceleratorConfig + if nodePool.AcceleratorType != nil { + accelerators = append(accelerators, &containerpb.AcceleratorConfig{ + AcceleratorCount: *nodePool.AcceleratorsPerInstance, + AcceleratorType: *nodePool.AcceleratorType, + }) + nodeLabels["nvidia.com/gpu"] = "present" + } + + if nodePool.Preemptible { + gkeClusterConfig.NodePools = append(gkeClusterConfig.NodePools, &containerpb.NodePool{ + Name: "cx-ws-" + nodePool.Name, + Config: &containerpb.NodeConfig{ + MachineType: nodePool.InstanceType, + Labels: nodeLabels, + Taints: []*containerpb.NodeTaint{ + { + Key: "workload", + Value: "true", + Effect: containerpb.NodeTaint_NO_SCHEDULE, + }, }, - }, - Accelerators: accelerators, - OauthScopes: []string{ - "https://www.googleapis.com/auth/compute", - "https://www.googleapis.com/auth/devstorage.read_only", - }, - ServiceAccount: gcpClient.ClientEmail, - Preemptible: true, - }, - InitialNodeCount: int32(initialNodeCount), - }) - } - if clusterConfig.OnDemandBackup || !clusterConfig.Preemptible { - gkeClusterConfig.NodePools = append(gkeClusterConfig.NodePools, &containerpb.NodePool{ - Name: "ng-cortex-wk-on-dmd", - Config: &containerpb.NodeConfig{ - MachineType: clusterConfig.InstanceType, - Labels: nodeLabels, - Taints: []*containerpb.NodeTaint{ - { - Key: "workload", - Value: "true", - Effect: containerpb.NodeTaint_NO_SCHEDULE, + Accelerators: accelerators, + OauthScopes: []string{ + "https://www.googleapis.com/auth/compute", + "https://www.googleapis.com/auth/devstorage.read_only", }, + ServiceAccount: gcpClient.ClientEmail, + Preemptible: true, }, - Accelerators: accelerators, - OauthScopes: 
[]string{ - "https://www.googleapis.com/auth/compute", - "https://www.googleapis.com/auth/devstorage.read_only", + InitialNodeCount: int32(initialNodeCount), + }) + } else { + gkeClusterConfig.NodePools = append(gkeClusterConfig.NodePools, &containerpb.NodePool{ + Name: "cx-wd-" + nodePool.Name, + Config: &containerpb.NodeConfig{ + MachineType: nodePool.InstanceType, + Labels: nodeLabels, + Taints: []*containerpb.NodeTaint{ + { + Key: "workload", + Value: "true", + Effect: containerpb.NodeTaint_NO_SCHEDULE, + }, + }, + Accelerators: accelerators, + OauthScopes: []string{ + "https://www.googleapis.com/auth/compute", + "https://www.googleapis.com/auth/devstorage.read_only", + }, + ServiceAccount: gcpClient.ClientEmail, }, - ServiceAccount: gcpClient.ClientEmail, - }, - InitialNodeCount: int32(initialNodeCount), - }) + InitialNodeCount: int32(initialNodeCount), + }) + } } if clusterConfig.Network != nil { diff --git a/cli/cmd/lib_cluster_config_aws.go b/cli/cmd/lib_cluster_config_aws.go index 0eb6772fb1..c2bde4b09d 100644 --- a/cli/cmd/lib_cluster_config_aws.go +++ b/cli/cmd/lib_cluster_config_aws.go @@ -20,7 +20,6 @@ import ( "fmt" "path" "path/filepath" - "reflect" "regexp" "github.com/cortexlabs/cortex/pkg/consts" @@ -28,9 +27,9 @@ import ( cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/files" + "github.com/cortexlabs/cortex/pkg/lib/maps" "github.com/cortexlabs/cortex/pkg/lib/pointer" "github.com/cortexlabs/cortex/pkg/lib/prompt" - "github.com/cortexlabs/cortex/pkg/lib/sets/strset" s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/lib/table" "github.com/cortexlabs/cortex/pkg/types/clusterconfig" @@ -135,7 +134,7 @@ func getInstallClusterConfig(awsClient *aws.Client, clusterConfigFile string, di return nil, err } - err = clusterConfig.Validate(awsClient) + err = clusterConfig.Validate(awsClient, false) if err != nil { err = errors.Append(err, 
fmt.Sprintf("\n\ncluster configuration schema can be found at https://docs.cortex.dev/v/%s/", consts.CortexVersionMinor)) return nil, errors.Wrap(err, clusterConfigFile) @@ -163,245 +162,54 @@ func getConfigureClusterConfig(cachedClusterConfig clusterconfig.Config, cluster } promptIfNotAdmin(awsClient, disallowPrompt) - err = setConfigFieldsFromCached(userClusterConfig, &cachedClusterConfig) - if err != nil { - return nil, err - } - userClusterConfig.Telemetry, err = readTelemetryConfig() if err != nil { return nil, err } - err = userClusterConfig.Validate(awsClient) + err = userClusterConfig.Validate(awsClient, true) if err != nil { err = errors.Append(err, fmt.Sprintf("\n\ncluster configuration schema can be found at https://docs.cortex.dev/v/%s/", consts.CortexVersionMinor)) return nil, errors.Wrap(err, clusterConfigFile) } - confirmConfigureClusterConfig(*userClusterConfig, disallowPrompt) - - return userClusterConfig, nil -} - -func setConfigFieldsFromCached(userClusterConfig *clusterconfig.Config, cachedClusterConfig *clusterconfig.Config) error { - if userClusterConfig.Bucket != "" && userClusterConfig.Bucket != cachedClusterConfig.Bucket { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.BucketKey, cachedClusterConfig.Bucket) - } - userClusterConfig.Bucket = cachedClusterConfig.Bucket - - if userClusterConfig.InstanceType != cachedClusterConfig.InstanceType { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.InstanceTypeKey, cachedClusterConfig.InstanceType) - } - userClusterConfig.InstanceType = cachedClusterConfig.InstanceType - - if _, ok := userClusterConfig.Tags[clusterconfig.ClusterNameTag]; !ok { - userClusterConfig.Tags[clusterconfig.ClusterNameTag] = userClusterConfig.ClusterName - } - if !reflect.DeepEqual(userClusterConfig.Tags, cachedClusterConfig.Tags) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.TagsKey, s.ObjFlat(cachedClusterConfig.Tags)) - } - - // The user doesn't 
have to specify AZs in their config - if len(userClusterConfig.AvailabilityZones) > 0 { - if !strset.New(userClusterConfig.AvailabilityZones...).IsEqual(strset.New(cachedClusterConfig.AvailabilityZones...)) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.AvailabilityZonesKey, cachedClusterConfig.AvailabilityZones) - } - } - userClusterConfig.AvailabilityZones = cachedClusterConfig.AvailabilityZones - - if len(userClusterConfig.Subnets) > 0 || len(cachedClusterConfig.Subnets) > 0 { - if !reflect.DeepEqual(userClusterConfig.Subnets, cachedClusterConfig.Subnets) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.SubnetsKey, cachedClusterConfig.Subnets) - } - } - userClusterConfig.Subnets = cachedClusterConfig.Subnets - - if s.Obj(cachedClusterConfig.SSLCertificateARN) != s.Obj(userClusterConfig.SSLCertificateARN) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.SSLCertificateARNKey, cachedClusterConfig.SSLCertificateARN) - } - userClusterConfig.SSLCertificateARN = cachedClusterConfig.SSLCertificateARN - - if userClusterConfig.CortexPolicyARN != "" && cachedClusterConfig.CortexPolicyARN != userClusterConfig.CortexPolicyARN { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.CortexPolicyARNKey, cachedClusterConfig.CortexPolicyARN) - } - userClusterConfig.CortexPolicyARN = cachedClusterConfig.CortexPolicyARN - - if !strset.New(cachedClusterConfig.IAMPolicyARNs...).IsEqual(strset.New(userClusterConfig.IAMPolicyARNs...)) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.IAMPolicyARNsKey, cachedClusterConfig.IAMPolicyARNs) - } - userClusterConfig.IAMPolicyARNs = cachedClusterConfig.IAMPolicyARNs - - if userClusterConfig.InstanceVolumeSize != cachedClusterConfig.InstanceVolumeSize { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.InstanceVolumeSizeKey, cachedClusterConfig.InstanceVolumeSize) - } - userClusterConfig.InstanceVolumeSize 
= cachedClusterConfig.InstanceVolumeSize - - if userClusterConfig.InstanceVolumeType != cachedClusterConfig.InstanceVolumeType { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.InstanceVolumeTypeKey, cachedClusterConfig.InstanceVolumeType) - } - userClusterConfig.InstanceVolumeType = cachedClusterConfig.InstanceVolumeType - - if userClusterConfig.InstanceVolumeIOPS != cachedClusterConfig.InstanceVolumeIOPS { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.InstanceVolumeIOPSKey, cachedClusterConfig.InstanceVolumeIOPS) - } - userClusterConfig.InstanceVolumeIOPS = cachedClusterConfig.InstanceVolumeIOPS - - if userClusterConfig.SubnetVisibility != cachedClusterConfig.SubnetVisibility { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.SubnetVisibilityKey, cachedClusterConfig.SubnetVisibility) - } - userClusterConfig.SubnetVisibility = cachedClusterConfig.SubnetVisibility - - if userClusterConfig.NATGateway != cachedClusterConfig.NATGateway { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.NATGatewayKey, cachedClusterConfig.NATGateway) - } - userClusterConfig.NATGateway = cachedClusterConfig.NATGateway - - if userClusterConfig.APILoadBalancerScheme != cachedClusterConfig.APILoadBalancerScheme { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.APILoadBalancerSchemeKey, cachedClusterConfig.APILoadBalancerScheme) - } - userClusterConfig.APILoadBalancerScheme = cachedClusterConfig.APILoadBalancerScheme - - if userClusterConfig.OperatorLoadBalancerScheme != cachedClusterConfig.OperatorLoadBalancerScheme { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.OperatorLoadBalancerSchemeKey, cachedClusterConfig.OperatorLoadBalancerScheme) - } - userClusterConfig.OperatorLoadBalancerScheme = cachedClusterConfig.OperatorLoadBalancerScheme - - if s.Obj(cachedClusterConfig.VPCCIDR) != s.Obj(userClusterConfig.VPCCIDR) { - return 
clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.VPCCIDRKey, cachedClusterConfig.VPCCIDR) - } - userClusterConfig.VPCCIDR = cachedClusterConfig.VPCCIDR - - if s.Obj(cachedClusterConfig.ImageDownloader) != s.Obj(userClusterConfig.ImageDownloader) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageDownloaderKey, cachedClusterConfig.ImageDownloader) - } - userClusterConfig.ImageDownloader = cachedClusterConfig.ImageDownloader - - if s.Obj(cachedClusterConfig.ImageRequestMonitor) != s.Obj(userClusterConfig.ImageRequestMonitor) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageRequestMonitorKey, cachedClusterConfig.ImageRequestMonitor) - } - userClusterConfig.ImageRequestMonitor = cachedClusterConfig.ImageRequestMonitor - - if s.Obj(cachedClusterConfig.ImageClusterAutoscaler) != s.Obj(userClusterConfig.ImageClusterAutoscaler) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageClusterAutoscalerKey, cachedClusterConfig.ImageClusterAutoscaler) - } - userClusterConfig.ImageClusterAutoscaler = cachedClusterConfig.ImageClusterAutoscaler - - if s.Obj(cachedClusterConfig.ImageMetricsServer) != s.Obj(userClusterConfig.ImageMetricsServer) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageMetricsServerKey, cachedClusterConfig.ImageMetricsServer) - } - userClusterConfig.ImageMetricsServer = cachedClusterConfig.ImageMetricsServer - - if s.Obj(cachedClusterConfig.ImageInferentia) != s.Obj(userClusterConfig.ImageInferentia) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageInferentiaKey, cachedClusterConfig.ImageInferentia) - } - userClusterConfig.ImageInferentia = cachedClusterConfig.ImageInferentia - - if s.Obj(cachedClusterConfig.ImageNeuronRTD) != s.Obj(userClusterConfig.ImageNeuronRTD) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageNeuronRTDKey, cachedClusterConfig.ImageNeuronRTD) - } - 
userClusterConfig.ImageNeuronRTD = cachedClusterConfig.ImageNeuronRTD - - if s.Obj(cachedClusterConfig.ImageNvidia) != s.Obj(userClusterConfig.ImageNvidia) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageNvidiaKey, cachedClusterConfig.ImageNvidia) - } - userClusterConfig.ImageNvidia = cachedClusterConfig.ImageNvidia - - if s.Obj(cachedClusterConfig.ImageFluentBit) != s.Obj(userClusterConfig.ImageFluentBit) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageFluentBitKey, cachedClusterConfig.ImageFluentBit) - } - userClusterConfig.ImageFluentBit = cachedClusterConfig.ImageFluentBit - - if s.Obj(cachedClusterConfig.ImageIstioProxy) != s.Obj(userClusterConfig.ImageIstioProxy) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageIstioProxyKey, cachedClusterConfig.ImageIstioProxy) - } - userClusterConfig.ImageIstioProxy = cachedClusterConfig.ImageIstioProxy - - if s.Obj(cachedClusterConfig.ImageIstioPilot) != s.Obj(userClusterConfig.ImageIstioPilot) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageIstioPilotKey, cachedClusterConfig.ImageIstioPilot) - } - userClusterConfig.ImageIstioPilot = cachedClusterConfig.ImageIstioPilot - - if s.Obj(cachedClusterConfig.ImagePrometheus) != s.Obj(userClusterConfig.ImagePrometheus) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImagePrometheusKey, cachedClusterConfig.ImagePrometheus) - } - - if s.Obj(cachedClusterConfig.ImagePrometheusConfigReloader) != s.Obj(userClusterConfig.ImagePrometheusConfigReloader) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImagePrometheusConfigReloaderKey, cachedClusterConfig.ImagePrometheusConfigReloader) - } - - if s.Obj(cachedClusterConfig.ImagePrometheusOperator) != s.Obj(userClusterConfig.ImagePrometheusOperator) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImagePrometheusOperatorKey, 
cachedClusterConfig.ImagePrometheusOperator) - } - - if s.Obj(cachedClusterConfig.ImagePrometheusStatsDExporter) != s.Obj(userClusterConfig.ImagePrometheusStatsDExporter) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImagePrometheusStatsDExporterKey, cachedClusterConfig.ImagePrometheusStatsDExporter) - } - - if s.Obj(cachedClusterConfig.ImagePrometheusDCGMExporter) != s.Obj(userClusterConfig.ImagePrometheusDCGMExporter) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImagePrometheusDCGMExporterKey, cachedClusterConfig.ImagePrometheusDCGMExporter) - } - - if s.Obj(cachedClusterConfig.ImagePrometheusKubeStateMetrics) != s.Obj(userClusterConfig.ImagePrometheusKubeStateMetrics) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImagePrometheusKubeStateMetricsKey, cachedClusterConfig.ImagePrometheusKubeStateMetrics) + clusterConfigCopy, err := userClusterConfig.DeepCopy() + if err != nil { + return nil, err } - if s.Obj(cachedClusterConfig.ImagePrometheusNodeExporter) != s.Obj(userClusterConfig.ImagePrometheusNodeExporter) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImagePrometheusNodeExporterKey, cachedClusterConfig.ImagePrometheusNodeExporter) + cachedConfigCopy, err := cachedClusterConfig.DeepCopy() + if err != nil { + return nil, err } - if s.Obj(cachedClusterConfig.ImageKubeRBACProxy) != s.Obj(userClusterConfig.ImageKubeRBACProxy) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageKubeRBACProxyKey, cachedClusterConfig.ImageKubeRBACProxy) + for idx := range clusterConfigCopy.NodeGroups { + clusterConfigCopy.NodeGroups[idx].MinInstances = 0 + clusterConfigCopy.NodeGroups[idx].MaxInstances = 0 } - - if s.Obj(cachedClusterConfig.ImageGrafana) != s.Obj(userClusterConfig.ImageGrafana) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageGrafanaKey, cachedClusterConfig.ImageGrafana) + for idx := range 
cachedConfigCopy.NodeGroups { + cachedConfigCopy.NodeGroups[idx].MinInstances = 0 + cachedConfigCopy.NodeGroups[idx].MaxInstances = 0 } - if s.Obj(cachedClusterConfig.ImageEventExporter) != s.Obj(userClusterConfig.ImageEventExporter) { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.ImageEventExporterKey, cachedClusterConfig.ImageEventExporter) + h1, err := clusterConfigCopy.Hash() + if err != nil { + return nil, err } - - if userClusterConfig.Spot != cachedClusterConfig.Spot { - return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.SpotKey, cachedClusterConfig.Spot) + h2, err := cachedConfigCopy.Hash() + if err != nil { + return nil, err } - userClusterConfig.Spot = cachedClusterConfig.Spot - if userClusterConfig.Spot { - userClusterConfig.FillEmptySpotFields() + if h1 != h2 { + return nil, clusterconfig.ErrorConfigCannotBeChangedOnUpdate() } - if userClusterConfig.SpotConfig != nil && s.Obj(userClusterConfig.SpotConfig) != s.Obj(cachedClusterConfig.SpotConfig) { - if cachedClusterConfig.SpotConfig == nil { - return clusterconfig.ErrorConfiguredWhenSpotIsNotEnabled(clusterconfig.SpotConfigKey) - } - - if !strset.New(userClusterConfig.SpotConfig.InstanceDistribution...).IsEqual(strset.New(cachedClusterConfig.SpotConfig.InstanceDistribution...)) { - return errors.Wrap(clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.InstanceDistributionKey, cachedClusterConfig.SpotConfig.InstanceDistribution), clusterconfig.SpotConfigKey) - } - - if userClusterConfig.SpotConfig.OnDemandBaseCapacity != nil && *userClusterConfig.SpotConfig.OnDemandBaseCapacity != *cachedClusterConfig.SpotConfig.OnDemandBaseCapacity { - return errors.Wrap(clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.OnDemandBaseCapacityKey, *cachedClusterConfig.SpotConfig.OnDemandBaseCapacity), clusterconfig.SpotConfigKey) - } - - if userClusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity != nil && 
*userClusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity != *cachedClusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity { - return errors.Wrap(clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.OnDemandPercentageAboveBaseCapacityKey, *cachedClusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity), clusterconfig.SpotConfigKey) - } - - if userClusterConfig.SpotConfig.MaxPrice != nil && *userClusterConfig.SpotConfig.MaxPrice != *cachedClusterConfig.SpotConfig.MaxPrice { - return errors.Wrap(clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.MaxPriceKey, *cachedClusterConfig.SpotConfig.MaxPrice), clusterconfig.SpotConfigKey) - } - - if userClusterConfig.SpotConfig.InstancePools != nil && *userClusterConfig.SpotConfig.InstancePools != *cachedClusterConfig.SpotConfig.InstancePools { - return errors.Wrap(clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.InstancePoolsKey, *cachedClusterConfig.SpotConfig.InstancePools), clusterconfig.SpotConfigKey) - } - - if userClusterConfig.SpotConfig.OnDemandBackup != cachedClusterConfig.SpotConfig.OnDemandBackup { - return errors.Wrap(clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.OnDemandBackupKey, cachedClusterConfig.SpotConfig.OnDemandBackup), clusterconfig.SpotConfigKey) - } + if !disallowPrompt { + exitMessage := fmt.Sprintf("cluster configuration can be modified via the cluster config file; see https://docs.cortex.dev/v/%s/ for more information", consts.CortexVersionMinor) + prompt.YesOrExit(fmt.Sprintf("your cluster named \"%s\" in %s will be updated according to the configuration above, are you sure you want to continue?", userClusterConfig.ClusterName, userClusterConfig.Region), "", exitMessage) } - userClusterConfig.SpotConfig = cachedClusterConfig.SpotConfig - return nil + return userClusterConfig, nil } func confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsClient *aws.Client, disallowPrompt bool) { @@ -411,11 +219,6 @@ func 
confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsClient metricsEBSPrice := aws.EBSMetadatas[clusterConfig.Region]["gp2"].PriceGB * 40 / 30 / 24 nlbPrice := aws.NLBMetadatas[clusterConfig.Region].Price natUnitPrice := aws.NATMetadatas[clusterConfig.Region].Price - apiInstancePrice := aws.InstanceMetadatas[clusterConfig.Region][clusterConfig.InstanceType].Price - apiEBSPrice := aws.EBSMetadatas[clusterConfig.Region][clusterConfig.InstanceVolumeType.String()].PriceGB * float64(clusterConfig.InstanceVolumeSize) / 30 / 24 - if clusterConfig.InstanceVolumeType.String() == "io1" && clusterConfig.InstanceVolumeIOPS != nil { - apiEBSPrice += aws.EBSMetadatas[clusterConfig.Region][clusterConfig.InstanceVolumeType.String()].PriceIOPS * float64(*clusterConfig.InstanceVolumeIOPS) / 30 / 24 - } var natTotalPrice float64 if clusterConfig.NATGateway == clusterconfig.SingleNATGateway { @@ -424,10 +227,6 @@ func confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsClient natTotalPrice = natUnitPrice * float64(len(clusterConfig.AvailabilityZones)) } - fixedPrice := eksPrice + 2*operatorInstancePrice + operatorEBSPrice + metricsEBSPrice + 2*nlbPrice + natTotalPrice - totalMinPrice := fixedPrice + float64(clusterConfig.MinInstances)*(apiInstancePrice+apiEBSPrice) - totalMaxPrice := fixedPrice + float64(clusterConfig.MaxInstances)*(apiInstancePrice+apiEBSPrice) - headers := []table.Header{ {Title: "aws resource"}, {Title: "cost per hour"}, @@ -436,31 +235,48 @@ func confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsClient var rows [][]interface{} rows = append(rows, []interface{}{"1 eks cluster", s.DollarsMaxPrecision(eksPrice)}) - instanceStr := "instances" - volumeStr := "volumes" - if clusterConfig.MinInstances == 1 && clusterConfig.MaxInstances == 1 { - instanceStr = "instance" - volumeStr = "volume" - } - workerInstanceStr := fmt.Sprintf("%d - %d %s %s for your apis", clusterConfig.MinInstances, clusterConfig.MaxInstances, 
clusterConfig.InstanceType, instanceStr) - ebsInstanceStr := fmt.Sprintf("%d - %d %dgb ebs %s for your apis", clusterConfig.MinInstances, clusterConfig.MaxInstances, clusterConfig.InstanceVolumeSize, volumeStr) - if clusterConfig.MinInstances == clusterConfig.MaxInstances { - workerInstanceStr = fmt.Sprintf("%d %s %s for your apis", clusterConfig.MinInstances, clusterConfig.InstanceType, instanceStr) - ebsInstanceStr = fmt.Sprintf("%d %dgb ebs %s for your apis", clusterConfig.MinInstances, clusterConfig.InstanceVolumeSize, volumeStr) - } - - workerPriceStr := s.DollarsMaxPrecision(apiInstancePrice) + " each" - if clusterConfig.Spot { - spotPrice, err := awsClient.SpotInstancePrice(clusterConfig.InstanceType) - workerPriceStr += " (spot pricing unavailable)" - if err == nil && spotPrice != 0 { - workerPriceStr = fmt.Sprintf("%s - %s each (varies based on spot price)", s.DollarsMaxPrecision(spotPrice), s.DollarsMaxPrecision(apiInstancePrice)) - totalMinPrice = fixedPrice + float64(clusterConfig.MinInstances)*(spotPrice+apiEBSPrice) + ngNameToSpotInstancesUsed := map[string]int{} + fixedPrice := eksPrice + 2*operatorInstancePrice + operatorEBSPrice + metricsEBSPrice + 2*nlbPrice + natTotalPrice + totalMinPrice := fixedPrice + totalMaxPrice := fixedPrice + for _, ng := range clusterConfig.NodeGroups { + apiInstancePrice := aws.InstanceMetadatas[clusterConfig.Region][ng.InstanceType].Price + apiEBSPrice := aws.EBSMetadatas[clusterConfig.Region][ng.InstanceVolumeType.String()].PriceGB * float64(ng.InstanceVolumeSize) / 30 / 24 + if ng.InstanceVolumeType.String() == "io1" && ng.InstanceVolumeIOPS != nil { + apiEBSPrice += aws.EBSMetadatas[clusterConfig.Region][ng.InstanceVolumeType.String()].PriceIOPS * float64(*ng.InstanceVolumeIOPS) / 30 / 24 + } + + totalMinPrice += float64(ng.MinInstances) * (apiInstancePrice + apiEBSPrice) + totalMaxPrice += float64(ng.MaxInstances) * (apiInstancePrice + apiEBSPrice) + + instanceStr := "instances" + volumeStr := "volumes" + if 
ng.MinInstances == 1 && ng.MaxInstances == 1 { + instanceStr = "instance" + volumeStr = "volume" + } + workerInstanceStr := fmt.Sprintf("nodegroup %s: %d - %d %s %s for your apis", ng.Name, ng.MinInstances, ng.MaxInstances, ng.InstanceType, instanceStr) + ebsInstanceStr := fmt.Sprintf("nodegroup %s: %d - %d %dgb ebs %s for your apis", ng.Name, ng.MinInstances, ng.MaxInstances, ng.InstanceVolumeSize, volumeStr) + if ng.MinInstances == ng.MaxInstances { + workerInstanceStr = fmt.Sprintf("nodegroup %s: %d %s %s for your apis", ng.Name, ng.MinInstances, ng.InstanceType, instanceStr) + ebsInstanceStr = fmt.Sprintf("nodegroup %s: %d %dgb ebs %s for your apis", ng.Name, ng.MinInstances, ng.InstanceVolumeSize, volumeStr) } + + workerPriceStr := s.DollarsMaxPrecision(apiInstancePrice) + " each" + if ng.Spot { + ngNameToSpotInstancesUsed[ng.Name]++ + spotPrice, err := awsClient.SpotInstancePrice(ng.InstanceType) + workerPriceStr += " (spot pricing unavailable)" + if err == nil && spotPrice != 0 { + workerPriceStr = fmt.Sprintf("%s - %s each (varies based on spot price)", s.DollarsMaxPrecision(spotPrice), s.DollarsMaxPrecision(apiInstancePrice)) + totalMinPrice += float64(ng.MinInstances)*(spotPrice-apiInstancePrice) + } + } + + rows = append(rows, []interface{}{workerInstanceStr, workerPriceStr}) + rows = append(rows, []interface{}{ebsInstanceStr, s.DollarsAndTenthsOfCents(apiEBSPrice) + " each"}) } - rows = append(rows, []interface{}{workerInstanceStr, workerPriceStr}) - rows = append(rows, []interface{}{ebsInstanceStr, s.DollarsAndTenthsOfCents(apiEBSPrice) + " each"}) rows = append(rows, []interface{}{"2 t3.medium instances for cortex", s.DollarsMaxPrecision(operatorInstancePrice * 2)}) rows = append(rows, []interface{}{"1 20gb ebs volume for the operator", s.DollarsAndTenthsOfCents(operatorEBSPrice)}) rows = append(rows, []interface{}{"1 40gb ebs volume for prometheus", s.DollarsAndTenthsOfCents(metricsEBSPrice)}) @@ -478,16 +294,15 @@ func 
confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsClient } fmt.Println(items.MustFormat(&table.Opts{Sort: pointer.Bool(false)})) - suffix := "" priceStr := s.DollarsAndCents(totalMaxPrice) - + suffix := "" if totalMinPrice != totalMaxPrice { priceStr = fmt.Sprintf("%s - %s", s.DollarsAndCents(totalMinPrice), s.DollarsAndCents(totalMaxPrice)) - if clusterConfig.Spot && clusterConfig.MinInstances != clusterConfig.MaxInstances { + if len(ngNameToSpotInstancesUsed) > 0 && len(ngNameToSpotInstancesUsed) < len(clusterConfig.NodeGroups) { suffix = " based on cluster size and spot instance pricing/availability" - } else if clusterConfig.Spot && clusterConfig.MinInstances == clusterConfig.MaxInstances { + } else if len(ngNameToSpotInstancesUsed) == len(clusterConfig.NodeGroups) { suffix = " based on spot instance pricing/availability" - } else if !clusterConfig.Spot && clusterConfig.MinInstances != clusterConfig.MaxInstances { + } else if len(ngNameToSpotInstancesUsed) == 0 { suffix = " based on cluster size" } } @@ -508,12 +323,8 @@ func confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsClient fmt.Print("warning: you've configured your cluster to be installed in an existing VPC; if your cluster doesn't spin up or function as expected, please double-check your VPC configuration (here are the requirements: https://eksctl.io/usage/vpc-networking/#use-existing-vpc-other-custom-configuration)\n\n") } - if clusterConfig.Spot && clusterConfig.SpotConfig.OnDemandBackup != nil && !*clusterConfig.SpotConfig.OnDemandBackup { - if *clusterConfig.SpotConfig.OnDemandBaseCapacity == 0 && *clusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity == 0 { - fmt.Printf("warning: you've disabled on-demand instances (%s=0 and %s=0); spot instances are not guaranteed to be available so please take that into account for production clusters; see https://docs.cortex.dev/v/%s/ for more information\n\n", clusterconfig.OnDemandBaseCapacityKey, 
clusterconfig.OnDemandPercentageAboveBaseCapacityKey, consts.CortexVersionMinor) - } else { - fmt.Printf("warning: you've enabled spot instances; spot instances are not guaranteed to be available so please take that into account for production clusters; see https://docs.cortex.dev/v/%s/ for more information\n\n", consts.CortexVersionMinor) - } + if len(clusterConfig.NodeGroups) > 1 && len(ngNameToSpotInstancesUsed) > 0 { + fmt.Printf("warning: you've enabled spot instances for %s %s; spot instances are not guaranteed to be available so please take that into account for production clusters; see https://docs.cortex.dev/v/%s/ for more information\n\n", s.PluralS("nodegroup", len(ngNameToSpotInstancesUsed)), s.StrsAnd(maps.StrMapKeysInt(ngNameToSpotInstancesUsed)), consts.CortexVersionMinor) } if !disallowPrompt { @@ -521,161 +332,3 @@ func confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsClient prompt.YesOrExit("would you like to continue?", "", exitMessage) } } - -func confirmConfigureClusterConfig(clusterConfig clusterconfig.Config, disallowPrompt bool) { - fmt.Println(clusterConfigConfirmationStr(clusterConfig)) - - if !disallowPrompt { - exitMessage := fmt.Sprintf("cluster configuration can be modified via the cluster config file; see https://docs.cortex.dev/v/%s/ for more information", consts.CortexVersionMinor) - prompt.YesOrExit(fmt.Sprintf("your cluster named \"%s\" in %s will be updated according to the configuration above, are you sure you want to continue?", clusterConfig.ClusterName, clusterConfig.Region), "", exitMessage) - } -} - -func clusterConfigConfirmationStr(clusterConfig clusterconfig.Config) string { - defaultConfig, _ := clusterconfig.GetDefaults() - - var items table.KeyValuePairs - - items.Add(clusterconfig.RegionUserKey, clusterConfig.Region) - if len(clusterConfig.AvailabilityZones) > 0 { - items.Add(clusterconfig.AvailabilityZonesUserKey, clusterConfig.AvailabilityZones) - } - for _, subnetConfig := range 
clusterConfig.Subnets { - items.Add("subnet in "+subnetConfig.AvailabilityZone, subnetConfig.SubnetID) - } - items.Add(clusterconfig.BucketUserKey, clusterConfig.Bucket) - items.Add(clusterconfig.ClusterNameUserKey, clusterConfig.ClusterName) - - items.Add(clusterconfig.InstanceTypeUserKey, clusterConfig.InstanceType) - items.Add(clusterconfig.MinInstancesUserKey, clusterConfig.MinInstances) - items.Add(clusterconfig.MaxInstancesUserKey, clusterConfig.MaxInstances) - items.Add(clusterconfig.TagsUserKey, s.ObjFlatNoQuotes(clusterConfig.Tags)) - if clusterConfig.SSLCertificateARN != nil { - items.Add(clusterconfig.SSLCertificateARNUserKey, *clusterConfig.SSLCertificateARN) - } - - items.Add(clusterconfig.CortexPolicyARNUserKey, clusterConfig.CortexPolicyARN) - items.Add(clusterconfig.IAMPolicyARNsUserKey, s.ObjFlatNoQuotes(clusterConfig.IAMPolicyARNs)) - - if clusterConfig.InstanceVolumeSize != defaultConfig.InstanceVolumeSize { - items.Add(clusterconfig.InstanceVolumeSizeUserKey, clusterConfig.InstanceVolumeSize) - } - if clusterConfig.InstanceVolumeType != defaultConfig.InstanceVolumeType { - items.Add(clusterconfig.InstanceVolumeTypeUserKey, clusterConfig.InstanceVolumeType) - } - if clusterConfig.InstanceVolumeIOPS != nil { - items.Add(clusterconfig.InstanceVolumeIOPSUserKey, *clusterConfig.InstanceVolumeIOPS) - } - - if clusterConfig.SubnetVisibility != defaultConfig.SubnetVisibility { - items.Add(clusterconfig.SubnetVisibilityUserKey, clusterConfig.SubnetVisibility) - } - if clusterConfig.NATGateway != defaultConfig.NATGateway { - items.Add(clusterconfig.NATGatewayUserKey, clusterConfig.NATGateway) - } - if clusterConfig.APILoadBalancerScheme != defaultConfig.APILoadBalancerScheme { - items.Add(clusterconfig.APILoadBalancerSchemeUserKey, clusterConfig.APILoadBalancerScheme) - } - if clusterConfig.OperatorLoadBalancerScheme != defaultConfig.OperatorLoadBalancerScheme { - items.Add(clusterconfig.OperatorLoadBalancerSchemeUserKey, 
clusterConfig.OperatorLoadBalancerScheme) - } - - if clusterConfig.Spot != defaultConfig.Spot { - items.Add(clusterconfig.SpotUserKey, s.YesNo(clusterConfig.Spot)) - - if clusterConfig.SpotConfig != nil { - defaultSpotConfig := clusterconfig.SpotConfig{} - clusterconfig.AutoGenerateSpotConfig(&defaultSpotConfig, clusterConfig.Region, clusterConfig.InstanceType) - - if !strset.New(clusterConfig.SpotConfig.InstanceDistribution...).IsEqual(strset.New(defaultSpotConfig.InstanceDistribution...)) { - items.Add(clusterconfig.InstanceDistributionUserKey, clusterConfig.SpotConfig.InstanceDistribution) - } - - if *clusterConfig.SpotConfig.OnDemandBaseCapacity != *defaultSpotConfig.OnDemandBaseCapacity { - items.Add(clusterconfig.OnDemandBaseCapacityUserKey, *clusterConfig.SpotConfig.OnDemandBaseCapacity) - } - - if *clusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity != *defaultSpotConfig.OnDemandPercentageAboveBaseCapacity { - items.Add(clusterconfig.OnDemandPercentageAboveBaseCapacityUserKey, *clusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity) - } - - if *clusterConfig.SpotConfig.MaxPrice != *defaultSpotConfig.MaxPrice { - items.Add(clusterconfig.MaxPriceUserKey, *clusterConfig.SpotConfig.MaxPrice) - } - - if *clusterConfig.SpotConfig.InstancePools != *defaultSpotConfig.InstancePools { - items.Add(clusterconfig.InstancePoolsUserKey, *clusterConfig.SpotConfig.InstancePools) - } - - if *clusterConfig.SpotConfig.OnDemandBackup != *defaultSpotConfig.OnDemandBackup { - items.Add(clusterconfig.OnDemandBackupUserKey, s.YesNo(*clusterConfig.SpotConfig.OnDemandBackup)) - } - } - } - - if clusterConfig.VPCCIDR != nil { - items.Add(clusterconfig.VPCCIDRUserKey, clusterConfig.VPCCIDR) - } - - if clusterConfig.Telemetry != defaultConfig.Telemetry { - items.Add(clusterconfig.TelemetryUserKey, clusterConfig.Telemetry) - } - if clusterConfig.ImageOperator != defaultConfig.ImageOperator { - items.Add(clusterconfig.ImageOperatorUserKey, clusterConfig.ImageOperator) - } - 
if clusterConfig.ImageManager != defaultConfig.ImageManager { - items.Add(clusterconfig.ImageManagerUserKey, clusterConfig.ImageManager) - } - if clusterConfig.ImageDownloader != defaultConfig.ImageDownloader { - items.Add(clusterconfig.ImageDownloaderUserKey, clusterConfig.ImageDownloader) - } - if clusterConfig.ImageRequestMonitor != defaultConfig.ImageRequestMonitor { - items.Add(clusterconfig.ImageRequestMonitorUserKey, clusterConfig.ImageRequestMonitor) - } - if clusterConfig.ImageClusterAutoscaler != defaultConfig.ImageClusterAutoscaler { - items.Add(clusterconfig.ImageClusterAutoscalerUserKey, clusterConfig.ImageClusterAutoscaler) - } - if clusterConfig.ImageMetricsServer != defaultConfig.ImageMetricsServer { - items.Add(clusterconfig.ImageMetricsServerUserKey, clusterConfig.ImageMetricsServer) - } - if clusterConfig.ImageInferentia != defaultConfig.ImageInferentia { - items.Add(clusterconfig.ImageInferentiaUserKey, clusterConfig.ImageInferentia) - } - if clusterConfig.ImageNeuronRTD != defaultConfig.ImageNeuronRTD { - items.Add(clusterconfig.ImageNeuronRTDUserKey, clusterConfig.ImageNeuronRTD) - } - if clusterConfig.ImageNvidia != defaultConfig.ImageNvidia { - items.Add(clusterconfig.ImageNvidiaUserKey, clusterConfig.ImageNvidia) - } - if clusterConfig.ImageFluentBit != defaultConfig.ImageFluentBit { - items.Add(clusterconfig.ImageFluentBitUserKey, clusterConfig.ImageFluentBit) - } - if clusterConfig.ImageIstioProxy != defaultConfig.ImageIstioProxy { - items.Add(clusterconfig.ImageIstioProxyUserKey, clusterConfig.ImageIstioProxy) - } - if clusterConfig.ImageIstioPilot != defaultConfig.ImageIstioPilot { - items.Add(clusterconfig.ImageIstioPilotUserKey, clusterConfig.ImageIstioPilot) - } - if clusterConfig.ImagePrometheus != defaultConfig.ImagePrometheus { - items.Add(clusterconfig.ImagePrometheusUserKey, clusterConfig.ImagePrometheus) - } - if clusterConfig.ImagePrometheusConfigReloader != defaultConfig.ImagePrometheusConfigReloader { - 
items.Add(clusterconfig.ImagePrometheusConfigReloaderUserKey, clusterConfig.ImagePrometheusConfigReloader) - } - if clusterConfig.ImagePrometheusOperator != defaultConfig.ImagePrometheusOperator { - items.Add(clusterconfig.ImagePrometheusOperatorUserKey, clusterConfig.ImagePrometheusOperator) - } - if clusterConfig.ImagePrometheusStatsDExporter != defaultConfig.ImagePrometheusStatsDExporter { - items.Add(clusterconfig.ImagePrometheusStatsDExporterUserKey, clusterConfig.ImagePrometheusStatsDExporter) - } - if clusterConfig.ImagePrometheusDCGMExporter != defaultConfig.ImagePrometheusDCGMExporter { - items.Add(clusterconfig.ImagePrometheusDCGMExporterUserKey, clusterConfig.ImagePrometheusDCGMExporter) - } - if clusterConfig.ImagePrometheusKubeStateMetrics != defaultConfig.ImagePrometheusKubeStateMetrics { - items.Add(clusterconfig.ImagePrometheusKubeStateMetricsUserKey, clusterConfig.ImagePrometheusKubeStateMetrics) - } - if clusterConfig.ImageGrafana != defaultConfig.ImageGrafana { - items.Add(clusterconfig.ImageGrafanaUserKey, clusterConfig.ImageGrafana) - } - return items.String() -} diff --git a/manager/debug.sh b/manager/debug.sh index b35ae9aa22..9484cb599e 100755 --- a/manager/debug.sh +++ b/manager/debug.sh @@ -51,37 +51,10 @@ echo -n "." mkdir -p /cortex-debug/aws/amis -asg_on_demand_info=$(aws autoscaling describe-auto-scaling-groups --region $CORTEX_REGION --query "AutoScalingGroups[?contains(Tags[?Key==\`alpha.eksctl.io/cluster-name\`].Value, \`$CORTEX_CLUSTER_NAME\`)]|[?contains(Tags[?Key==\`alpha.eksctl.io/nodegroup-name\`].Value, \`ng-cortex-worker-on-demand\`)]") +aws autoscaling describe-auto-scaling-groups --region=$CORTEX_REGION --output json > "/cortex-debug/aws/asgs" 2>&1 echo -n "." 
-asg_on_demand_name="" -asg_on_demand_length=$(echo "$asg_on_demand_info" | jq -r 'length') -if (( "$asg_on_demand_length" > "0" )); then - asg_on_demand_name=$(echo "$asg_on_demand_info" | jq -r 'first | .AutoScalingGroupName') - aws autoscaling describe-auto-scaling-groups --auto-scaling-group-names $asg_on_demand_name --region=$CORTEX_REGION --output json > "/cortex-debug/aws/asg-on-demand" 2>&1 - echo -n "." - aws autoscaling describe-scaling-activities --max-items 1000 --auto-scaling-group-name $asg_on_demand_name --region=$CORTEX_REGION --output json > "/cortex-debug/aws/asg-activities-on-demand" 2>&1 - echo -n "." -fi - -asg_spot_info=$(aws autoscaling describe-auto-scaling-groups --region $CORTEX_REGION --query "AutoScalingGroups[?contains(Tags[?Key==\`alpha.eksctl.io/cluster-name\`].Value, \`$CORTEX_CLUSTER_NAME\`)]|[?contains(Tags[?Key==\`alpha.eksctl.io/nodegroup-name\`].Value, \`ng-cortex-worker-spot\`)]") +aws autoscaling describe-scaling-activities --max-items 1000 --region=$CORTEX_REGION --output json > "/cortex-debug/aws/asg-activities" 2>&1 echo -n "." -asg_spot_name="" -asg_spot_length=$(echo "$asg_spot_info" | jq -r 'length') -if (( "$asg_spot_length" > "0" )); then - asg_spot_name=$(echo "$asg_spot_info" | jq -r 'first | .AutoScalingGroupName') - aws autoscaling describe-auto-scaling-groups --auto-scaling-group-names $asg_spot_name --region=$CORTEX_REGION --output json > "/cortex-debug/aws/asg-spot" 2>&1 - echo -n "." - aws autoscaling describe-scaling-activities --max-items 1000 --auto-scaling-group-name $asg_spot_name --region=$CORTEX_REGION --output json > "/cortex-debug/aws/asg-activities-spot" 2>&1 - echo -n "." -fi - -# failsafe in case the asg(s) could not be located -if [ "$asg_on_demand_name" == "" ] && [ "$asg_spot_name" == "" ]; then - aws autoscaling describe-auto-scaling-groups --region=$CORTEX_REGION --output json > "/cortex-debug/aws/asgs" 2>&1 - echo -n "." 
- aws autoscaling describe-scaling-activities --max-items 1000 --region=$CORTEX_REGION --output json > "/cortex-debug/aws/asg-activities" 2>&1 - echo -n "." -fi aws ec2 describe-instances --filters Name=tag:cortex.dev/cluster-name,Values=$CORTEX_CLUSTER_NAME --region=$CORTEX_REGION --output json > "/cortex-debug/aws/instances" 2>&1 echo -n "." diff --git a/manager/generate_eks.py b/manager/generate_eks.py index 2f44d8be34..3dd59df684 100644 --- a/manager/generate_eks.py +++ b/manager/generate_eks.py @@ -56,9 +56,9 @@ def merge_override(a, b): return a -def apply_worker_settings(nodegroup): +def apply_worker_settings(nodegroup, config): worker_settings = { - "name": "ng-cortex-worker-on-demand", + "name": "cx-wd-" + config["name"], "labels": {"workload": "true"}, "taints": {"workload": "true:NoSchedule"}, "tags": { @@ -88,7 +88,7 @@ def apply_clusterconfig(nodegroup, config): def apply_spot_settings(nodegroup, config): spot_settings = { - "name": "ng-cortex-worker-spot", + "name": "cx-ws-" + config["name"], "instanceType": "mixed", "instancesDistribution": { "instanceTypes": config["spot_config"]["instance_distribution"], @@ -126,9 +126,8 @@ def is_gpu(instance_type): return instance_type.startswith("g") or instance_type.startswith("p") -def apply_inf_settings(nodegroup, cluster_config): - instance_type = cluster_config["instance_type"] - instance_region = cluster_config["region"] +def apply_inf_settings(nodegroup, config): + instance_type = config["instance_type"] num_chips, hugepages_mem = get_inf_resources(instance_type) inf_settings = { @@ -162,13 +161,34 @@ def get_inf_resources(instance_type): return num_chips, f"{128 * num_chips}Mi" +def get_all_worker_nodegroups(cluster_config: dict) -> list: + worker_nodegroups = [] + for ng in cluster_config["node_groups"]: + worker_nodegroup = default_nodegroup(cluster_config) + apply_worker_settings(worker_nodegroup, ng) + apply_clusterconfig(worker_nodegroup, ng) + + if ng["spot"]: + 
apply_spot_settings(worker_nodegroup, ng) + + if is_gpu(ng["instance_type"]): + apply_gpu_settings(worker_nodegroup) + + if is_inf(ng["instance_type"]): + apply_inf_settings(worker_nodegroup, ng) + + worker_nodegroups.append(worker_nodegroup) + + return worker_nodegroups + + def generate_eks(cluster_config_path): with open(cluster_config_path, "r") as f: cluster_config = yaml.safe_load(f) operator_nodegroup = default_nodegroup(cluster_config) operator_settings = { - "name": "ng-cortex-operator", + "name": "cx-operator", "instanceType": "t3.medium", "minSize": 2, "maxSize": 2, @@ -176,19 +196,7 @@ def generate_eks(cluster_config_path): } operator_nodegroup = merge_override(operator_nodegroup, operator_settings) - worker_nodegroup = default_nodegroup(cluster_config) - apply_worker_settings(worker_nodegroup) - - apply_clusterconfig(worker_nodegroup, cluster_config) - - if cluster_config["spot"]: - apply_spot_settings(worker_nodegroup, cluster_config) - - if is_gpu(cluster_config["instance_type"]): - apply_gpu_settings(worker_nodegroup) - - if is_inf(cluster_config["instance_type"]): - apply_inf_settings(worker_nodegroup, cluster_config) + worker_nodegroups = get_all_worker_nodegroups(cluster_config) nat_gateway = "Disable" if cluster_config["nat_gateway"] == "single": @@ -206,7 +214,7 @@ def generate_eks(cluster_config_path): "tags": cluster_config["tags"], }, "vpc": {"nat": {"gateway": nat_gateway}}, - "nodeGroups": [operator_nodegroup, worker_nodegroup], + "nodeGroups": [operator_nodegroup] + worker_nodegroups, } if ( @@ -230,22 +238,6 @@ def generate_eks(cluster_config_path): if cluster_config.get("vpc_cidr", "") != "": eks["vpc"]["cidr"] = cluster_config["vpc_cidr"] - if cluster_config.get("spot_config") is not None and cluster_config["spot_config"].get( - "on_demand_backup", False - ): - backup_nodegroup = default_nodegroup(cluster_config) - apply_worker_settings(backup_nodegroup) - apply_clusterconfig(backup_nodegroup, cluster_config) - if 
is_gpu(cluster_config["instance_type"]): - apply_gpu_settings(backup_nodegroup) - if is_inf(cluster_config["instance_type"]): - apply_inf_settings(backup_nodegroup, cluster_config) - - backup_nodegroup["minSize"] = 0 - backup_nodegroup["desiredCapacity"] = 0 - - eks["nodeGroups"].append(backup_nodegroup) - print(yaml.dump(eks, Dumper=IgnoreAliases, default_flow_style=False, default_style="")) diff --git a/manager/install.sh b/manager/install.sh index 087d09026b..6f7d8ef5b4 100755 --- a/manager/install.sh +++ b/manager/install.sh @@ -71,18 +71,14 @@ function cluster_up_aws() { setup_grafana echo "✓" - if [[ "$CORTEX_INSTANCE_TYPE" == p* ]] || [[ "$CORTEX_INSTANCE_TYPE" == g* ]]; then - echo -n "○ configuring gpu support " - envsubst < manifests/nvidia_aws.yaml | kubectl apply -f - >/dev/null - envsubst < manifests/prometheus-dcgm-exporter.yaml | kubectl apply -f - >/dev/null - echo "✓" - fi + echo -n "○ configuring gpu support (for the nodegroups that may require it)" + envsubst < manifests/nvidia_aws.yaml | kubectl apply -f - >/dev/null + NVIDIA_COM_GPU_VALUE=true envsubst < manifests/prometheus-dcgm-exporter.yaml | kubectl apply -f - >/dev/null + echo "✓" - if [[ "$CORTEX_INSTANCE_TYPE" == inf* ]]; then - echo -n "○ configuring inf support " - envsubst < manifests/inferentia.yaml | kubectl apply -f - >/dev/null - echo "✓" - fi + echo -n "○ configuring inf support (for the nodegroups that may require it)" + envsubst < manifests/inferentia.yaml | kubectl apply -f - >/dev/null + echo "✓" restart_operator @@ -130,12 +126,10 @@ function cluster_up_gcp() { setup_grafana echo "✓" - if [ -n "$CORTEX_ACCELERATOR_TYPE" ]; then - echo -n "○ configuring gpu support " - envsubst < manifests/nvidia_gcp.yaml | kubectl apply -f - >/dev/null - envsubst < manifests/prometheus-dcgm-exporter.yaml | kubectl apply -f - >/dev/null - echo "✓" - fi + echo -n "○ configuring gpu support (for the nodepools that may require it)" + envsubst < manifests/nvidia_gcp.yaml | kubectl apply -f - 
>/dev/null + NVIDIA_COM_GPU_VALUE=present envsubst < manifests/prometheus-dcgm-exporter.yaml | kubectl apply -f - >/dev/null + echo "✓" restart_operator @@ -334,85 +328,52 @@ function restart_operator() { } function resize_nodegroup() { - # check for change in min/max instances - asg_on_demand_info=$(aws autoscaling describe-auto-scaling-groups --region $CORTEX_REGION --query "AutoScalingGroups[?contains(Tags[?Key==\`alpha.eksctl.io/cluster-name\`].Value, \`$CORTEX_CLUSTER_NAME\`)]|[?contains(Tags[?Key==\`alpha.eksctl.io/nodegroup-name\`].Value, \`ng-cortex-worker-on-demand\`)]") - asg_on_demand_length=$(echo "$asg_on_demand_info" | jq -r 'length') - asg_on_demand_name="" - if (( "$asg_on_demand_length" > "0" )); then - asg_on_demand_name=$(echo "$asg_on_demand_info" | jq -r 'first | .AutoScalingGroupName') - fi - - asg_spot_info=$(aws autoscaling describe-auto-scaling-groups --region $CORTEX_REGION --query "AutoScalingGroups[?contains(Tags[?Key==\`alpha.eksctl.io/cluster-name\`].Value, \`$CORTEX_CLUSTER_NAME\`)]|[?contains(Tags[?Key==\`alpha.eksctl.io/nodegroup-name\`].Value, \`ng-cortex-worker-spot\`)]") - asg_spot_length=$(echo "$asg_spot_info" | jq -r 'length') - asg_spot_name="" - if (( "$asg_spot_length" > "0" )); then - asg_spot_name=$(echo "$asg_spot_info" | jq -r 'first | .AutoScalingGroupName') - fi - - if [[ -z "$asg_spot_name" ]] && [[ -z "$asg_on_demand_name" ]]; then - echo "error: unable to find valid autoscaling groups" - exit 1 - fi + eksctl get nodegroup --cluster=$CORTEX_CLUSTER_NAME --region=$CORTEX_REGION -o json > nodegroups.json + ng_len=$(cat nodegroups.json | jq -r length) + num_resizes=0 - if [[ -z $asg_spot_name ]]; then - asg_min_size=$(echo "$asg_on_demand_info" | jq -r 'first | .MinSize') - asg_max_size=$(echo "$asg_on_demand_info" | jq -r 'first | .MaxSize') - if [ "$asg_min_size" = "" ] || [ "$asg_min_size" = "null" ] || [ "$asg_max_size" = "" ] || [ "$asg_max_size" = "null" ]; then - echo -e "unable to find on-demand autoscaling 
group size from info:\n$asg_on_demand_info" - exit 1 - fi - else - asg_min_size=$(echo "$asg_spot_info" | jq -r 'first | .MinSize') - asg_max_size=$(echo "$asg_spot_info" | jq -r 'first | .MaxSize') - if [ "$asg_min_size" = "" ] || [ "$asg_min_size" = "null" ] || [ "$asg_max_size" = "" ] || [ "$asg_max_size" = "null" ]; then - echo -e "unable to find spot autoscaling group size from info:\n$asg_spot_info" - exit 1 + for idx in $(seq 0 $(($ng_len-1))); do + stack_ng=$(cat nodegroups.json | jq -r .[$idx].Name) + if [ "$stack_ng" = "cx-operator" ]; then + continue fi - fi - asg_on_demand_resize_flags="" - asg_spot_resize_flags="" + config_ng=$(cat /in/cluster_${CORTEX_CLUSTER_NAME}_${CORTEX_REGION}.yaml | yq -r .node_groups[$idx].name) - if [ "$asg_min_size" != "$CORTEX_MIN_INSTANCES" ]; then - # only update min for on-demand nodegroup if it's not a backup - if [[ -n $asg_on_demand_name ]] && [[ "$CORTEX_SPOT_CONFIG_ON_DEMAND_BACKUP" != "True" ]]; then - asg_on_demand_resize_flags+=" --min-size=$CORTEX_MIN_INSTANCES" + desired=$(cat nodegroups.json | jq -r .[$idx].DesiredCapacity) + existing_min=$(cat nodegroups.json | jq -r .[$idx].MinSize) + existing_max=$(cat nodegroups.json | jq -r .[$idx].MaxSize) + updating_min=$(cat /in/cluster_${CORTEX_CLUSTER_NAME}_${CORTEX_REGION}.yaml | yq -r .node_groups[$idx].min_instances) + updating_max=$(cat /in/cluster_${CORTEX_CLUSTER_NAME}_${CORTEX_REGION}.yaml | yq -r .node_groups[$idx].max_instances) + if [ $updating_min = "null" ]; then + updating_min=1 fi - if [[ -n $asg_spot_name ]]; then - asg_spot_resize_flags+=" --min-size=$CORTEX_MIN_INSTANCES" + if [ $updating_max = "null" ]; then + updating_max=5 fi - fi - if [ "$asg_max_size" != "$CORTEX_MAX_INSTANCES" ]; then - if [[ -n $asg_on_demand_name ]]; then - asg_on_demand_resize_flags+=" --max-size=$CORTEX_MAX_INSTANCES" - fi - if [[ -n $asg_spot_name ]]; then - asg_spot_resize_flags+=" --max-size=$CORTEX_MAX_INSTANCES" + if [ "$existing_min" != "$updating_min" ] && [ 
"$existing_max" != "$updating_max" ]; then + echo "○ nodegroup $idx ($config_ng): updating min instances to $updating_min and max instances to $updating_max " + eksctl scale nodegroup --cluster=$CORTEX_CLUSTER_NAME --region=$CORTEX_REGION $stack_ng --nodes $desired --nodes-min $updating_min --nodes-max $updating_max + num_resizes=$(($num_resizes+1)) + echo + elif [ "$existing_min" != "$updating_min" ]; then + echo "○ nodegroup $idx ($config_ng): updating min instances to $updating_min " + eksctl scale nodegroup --cluster=$CORTEX_CLUSTER_NAME --region=$CORTEX_REGION $stack_ng --nodes $desired --nodes-min $updating_min + num_resizes=$(($num_resizes+1)) + echo + elif [ "$existing_max" != "$updating_max" ]; then + echo "○ nodegroup $idx ($config_ng): updating max instances to $updating_max " + eksctl scale nodegroup --cluster=$CORTEX_CLUSTER_NAME --region=$CORTEX_REGION $stack_ng --nodes $desired --nodes-max $updating_max + num_resizes=$(($num_resizes+1)) + echo fi - fi - - is_resizing="false" - if [ "$asg_min_size" != "$CORTEX_MIN_INSTANCES" ] && [ "$asg_max_size" != "$CORTEX_MAX_INSTANCES" ]; then - echo -n "○ updating min instances to $CORTEX_MIN_INSTANCES and max instances to $CORTEX_MAX_INSTANCES " - is_resizing="true" - elif [ "$asg_min_size" != "$CORTEX_MIN_INSTANCES" ]; then - echo -n "○ updating min instances to $CORTEX_MIN_INSTANCES " - is_resizing="true" - elif [ "$asg_max_size" != "$CORTEX_MAX_INSTANCES" ]; then - echo -n "○ updating max instances to $CORTEX_MAX_INSTANCES " - is_resizing="true" - fi - - if [ "$asg_on_demand_resize_flags" != "" ]; then - aws autoscaling update-auto-scaling-group --region $CORTEX_REGION --auto-scaling-group-name $asg_on_demand_name $asg_on_demand_resize_flags - fi - if [ "$asg_spot_resize_flags" != "" ]; then - aws autoscaling update-auto-scaling-group --region $CORTEX_REGION --auto-scaling-group-name $asg_spot_name $asg_spot_resize_flags - fi + done + rm nodegroups.json - if [ "$is_resizing" == "true" ]; then - echo "✓" + if 
[ "$num_resizes" -eq "0" ]; then + echo "no changes to node group sizes detected in the cluster config" + exit 0 fi } diff --git a/manager/manifests/cluster-autoscaler.yaml.j2 b/manager/manifests/cluster-autoscaler.yaml.j2 index db6cdaba8c..8c96b663fd 100644 --- a/manager/manifests/cluster-autoscaler.yaml.j2 +++ b/manager/manifests/cluster-autoscaler.yaml.j2 @@ -131,7 +131,6 @@ subjects: name: cluster-autoscaler namespace: kube-system --- -{% if (config.get('spot_config') and config['spot_config'].get('on_demand_backup', false)) or config.get('on_demand_backup') %} apiVersion: v1 kind: ConfigMap metadata: @@ -139,19 +138,26 @@ metadata: namespace: kube-system data: priorities: |- - {% if config.get('spot_config') %} - 10: - - .*ng-cortex-worker-on-demand.* - 50: - - .*ng-cortex-worker-spot.* + {% if config['provider'] == 'aws' %} + {% for ng in config['node_groups']|reverse %} + {{ (loop.index0+1) * 10 }}: + {% if ng['spot'] %} + - .*{{ 'cx-ws-' + ng['name'] }}.* {% else %} - 10: - - .*ng-cortex-wk-on-dmd.* - 50: - - .*ng-cortex-wk-preemp.* + - .*{{ 'cx-wd-' + ng['name'] }}.* {% endif %} + {% endfor %} + {% else %} + {% for np in config['node_pools']|reverse %} + {{ (loop.index0+1) * 10 }}: + {% if np['preemptible'] %} + - .*{{ 'cx-ws-' + np['name'] }}.* + {% else %} + - .*{{ 'cx-wd-' + np['name'] }}.* + {% endif %} + {% endfor %} + {% endif %} --- -{% endif %} apiVersion: apps/v1 kind: Deployment metadata: @@ -171,7 +177,7 @@ spec: spec: serviceAccountName: cluster-autoscaler containers: - - image: {{ config["image_cluster_autoscaler"] }} + - image: {{ config['image_cluster_autoscaler'] }} name: cluster-autoscaler resources: limits: @@ -184,26 +190,27 @@ spec: - ./cluster-autoscaler - --v=4 - --stderrthreshold=info - {% if config["provider"] == "aws" %} + {% if config['provider'] == 'aws' %} - --cloud-provider=aws {% else %} - --cloud-provider=gce {% endif %} - --skip-nodes-with-local-storage=false - {% if (config.get('spot_config') and 
config['spot_config'].get('on_demand_backup', false)) or config.get('on_demand_backup') %} - --expander=priority - {% else %} - - --expander=least-waste - {% endif %} - - --max-nodes-total={{ config['max_instances'] + 2 }} - --max-total-unready-percentage=5 - --ok-total-unready-count=30 - --max-node-provision-time=5m - --scan-interval=20s - {% if config["provider"] == "aws" %} + {% if config['provider'] == 'aws' %} - --node-group-auto-discovery=asg:tag=k8s.io/cluster-autoscaler/enabled,k8s.io/cluster-autoscaler/{{ config['cluster_name'] }} {% else %} - - --node-group-auto-discovery=mig:namePrefix=gke-{{ config['cluster_name'] }}-ng-cortex-wk,min={{ config["min_instances"] }},max={{ config["max_instances"] }} + {% for np in config['node_pools'] %} + {% if np['preemptible'] %} + - --node-group-auto-discovery=mig:namePrefix=gke-{{ config['cluster_name'] }}-cx-ws-{{ np['name'] }},min={{ np['min_instances'] }},max={{ np['max_instances'] }} + {% else %} + - --node-group-auto-discovery=mig:namePrefix=gke-{{ config['cluster_name'] }}-cx-wd-{{ np['name'] }},min={{ np['min_instances'] }},max={{ np['max_instances'] }} + {% endif %} + {% endfor %} {% endif %} volumeMounts: - name: ssl-certs @@ -213,7 +220,7 @@ spec: volumes: - name: ssl-certs hostPath: - {% if config["provider"] == "aws" %} + {% if config['provider'] == 'aws' %} path: "/etc/ssl/certs/ca-bundle.crt" {% else %} path: "/etc/ssl/certs/ca-certificates.crt" diff --git a/manager/manifests/inferentia.yaml b/manager/manifests/inferentia.yaml index c820e8828f..666ea78c0a 100644 --- a/manager/manifests/inferentia.yaml +++ b/manager/manifests/inferentia.yaml @@ -155,6 +155,7 @@ spec: memory: 100Mi nodeSelector: workload: "true" + aws.amazon.com/neuron: "true" volumes: - name: device-plugin hostPath: diff --git a/manager/manifests/nvidia_aws.yaml b/manager/manifests/nvidia_aws.yaml index b9fa5540a4..66a3f1a48e 100644 --- a/manager/manifests/nvidia_aws.yaml +++ b/manager/manifests/nvidia_aws.yaml @@ -71,6 +71,7 @@ spec:
memory: 100Mi nodeSelector: workload: "true" + nvidia.com/gpu: "true" volumes: - name: device-plugin hostPath: diff --git a/manager/manifests/nvidia_gcp.yaml b/manager/manifests/nvidia_gcp.yaml index 0f2288c113..e4648c4cd9 100644 --- a/manager/manifests/nvidia_gcp.yaml +++ b/manager/manifests/nvidia_gcp.yaml @@ -51,6 +51,7 @@ spec: hostPID: true nodeSelector: workload: "true" + nvidia.com/gpu: "present" volumes: - name: dev hostPath: diff --git a/manager/manifests/prometheus-dcgm-exporter.yaml b/manager/manifests/prometheus-dcgm-exporter.yaml index 27ec085d6f..3db31338ce 100644 --- a/manager/manifests/prometheus-dcgm-exporter.yaml +++ b/manager/manifests/prometheus-dcgm-exporter.yaml @@ -107,6 +107,9 @@ spec: path: /home/kubernetes/bin/nvidia type: "" name: nvidia-install-dir-host + nodeSelector: + workload: "true" + nvidia.com/gpu: "$NVIDIA_COM_GPU_VALUE" --- apiVersion: monitoring.coreos.com/v1 kind: PodMonitor diff --git a/manager/refresh.sh b/manager/refresh.sh index da75db44fc..126591d65e 100755 --- a/manager/refresh.sh +++ b/manager/refresh.sh @@ -29,11 +29,4 @@ fi eksctl utils write-kubeconfig --cluster=$CORTEX_CLUSTER_NAME --region=$CORTEX_REGION | (grep -v "saved kubeconfig as" | grep -v "using region" | grep -v "eksctl version" || true) out=$(kubectl get pods 2>&1 || true); if [[ "$out" == *"must be logged in to the server"* ]]; then echo "error: your aws iam user does not have access to this cluster; to grant access, see https://docs.cortex.dev/v/${CORTEX_VERSION_MINOR}/"; exit 1; fi -kubectl get -n=default configmap cluster-config -o yaml >> cluster_configmap.yaml -python refresh_cluster_config.py cluster_configmap.yaml tmp_cluster_config.yaml - -kubectl -n=default create configmap 'cluster-config' \ - --from-file='cluster.yaml'=tmp_cluster_config.yaml \ - -o yaml --dry-run=client | kubectl apply -f - >/dev/null - -cp tmp_cluster_config.yaml $cluster_config_out_path +kubectl get -n=default configmap cluster-config -o json | jq -r '.data."cluster.yaml"' 
>> $cluster_config_out_path diff --git a/manager/requirements.txt b/manager/requirements.txt index 5e5c7e5536..b9c8c79b97 100644 --- a/manager/requirements.txt +++ b/manager/requirements.txt @@ -1,3 +1,4 @@ boto3 jinja2 pyyaml +yq diff --git a/pkg/lib/aws/cloudformation.go b/pkg/lib/aws/cloudformation.go index 92a84c47d1..105c295895 100644 --- a/pkg/lib/aws/cloudformation.go +++ b/pkg/lib/aws/cloudformation.go @@ -25,11 +25,13 @@ import ( func (c *Client) ListEKSStacks(controlPlaneStackName string, nodegroupStackNames strset.Set) ([]*cloudformation.StackSummary, error) { var stackSummaries []*cloudformation.StackSummary stackSet := strset.Union(nodegroupStackNames, strset.New(controlPlaneStackName)) + err := c.CloudFormation().ListStacksPages( &cloudformation.ListStacksInput{}, func(listStackOutput *cloudformation.ListStacksOutput, lastPage bool) bool { for _, stackSummary := range listStackOutput.StackSummaries { - if stackSet.Has(*stackSummary.StackName) { + + if stackSet.HasWithPrefix(*stackSummary.StackName) { stackSummaries = append(stackSummaries, stackSummary) } diff --git a/pkg/lib/aws/errors.go b/pkg/lib/aws/errors.go index 3493c2df73..bc66c0c35c 100644 --- a/pkg/lib/aws/errors.go +++ b/pkg/lib/aws/errors.go @@ -151,11 +151,12 @@ func ErrorBucketNotFound(bucket string) error { }) } -func ErrorInsufficientInstanceQuota(instanceType string, lifecycle string, region string, requiredInstances int64, vCPUPerInstance int64, vCPUQuota int64, quotaCode string) error { +func ErrorInsufficientInstanceQuota(instanceTypes []string, lifecycle string, region string, requiredVCPUs int64, vCPUQuota int64, quotaCode string) error { url := fmt.Sprintf("https://%s.console.aws.amazon.com/servicequotas/home?region=%s#!/services/ec2/quotas/%s", region, region, quotaCode) + andInstanceTypes := s.StrsAnd(instanceTypes) return errors.WithStack(&errors.Error{ Kind: ErrInsufficientInstanceQuota, - Message: fmt.Sprintf("your cluster may require up to %d %s %s instances, but your AWS 
quota for %s %s instances in %s is only %d vCPU (there are %d vCPUs per %s instance); please reduce the maximum number of %s %s instances your cluster may use (e.g. by changing max_instances and/or spot_config if applicable), or request a quota increase to at least %d vCPU here: %s (if your request was recently approved, please allow ~30 minutes for AWS to reflect this change)", requiredInstances, lifecycle, instanceType, lifecycle, instanceType, region, vCPUQuota, vCPUPerInstance, instanceType, lifecycle, instanceType, requiredInstances*vCPUPerInstance, url), + Message: fmt.Sprintf("your cluster may require up to %d vCPU %s %s instances, but your AWS quota for %s %s instances in %s is only %d vCPU; please reduce the maximum number of %s %s instances your cluster may use (e.g. by changing max_instances and/or spot_config if applicable), or request a quota increase to at least %d vCPU here: %s (if your request was recently approved, please allow ~30 minutes for AWS to reflect this change)", requiredVCPUs, lifecycle, andInstanceTypes, lifecycle, andInstanceTypes, region, vCPUQuota, lifecycle, andInstanceTypes, requiredVCPUs, url), }) } diff --git a/pkg/lib/aws/servicequotas.go b/pkg/lib/aws/servicequotas.go index df9715b15a..750c27b61b 100644 --- a/pkg/lib/aws/servicequotas.go +++ b/pkg/lib/aws/servicequotas.go @@ -38,26 +38,65 @@ const ( _vpcQuotaCode = "L-F678F1CE" ) -func (c *Client) VerifyInstanceQuota(instanceType string, requiredOnDemandInstances int64, requiredSpotInstances int64) error { - if requiredOnDemandInstances == 0 && requiredSpotInstances == 0 { - return nil - } +type InstanceTypeRequests struct { + InstanceType string + RequiredOnDemandInstances int64 + RequiredSpotInstances int64 +} - instanceCategory := _instanceCategoryRegex.FindString(instanceType) +type instanceCategoryRequests struct { + InstanceTypes []string - // Allow the instance if we don't recognize the type - if !_knownInstanceCategories.Has(instanceCategory) { - return nil - } + 
InstanceCategory string + RequiredOnDemandCPUs int64 + RequiredSpotCPUs int64 + + OnDemandCPUQuota *int64 + OnDemandQuotaCode string + + SpotCPUQuota *int64 + SpotQuotaCode string +} + +func (c *Client) VerifyInstanceQuota(instances []InstanceTypeRequests) error { + instanceCategories := []instanceCategoryRequests{} + for _, instance := range instances { + if instance.RequiredOnDemandInstances == 0 && instance.RequiredSpotInstances == 0 { + continue + } + + instanceCategoryStr := _instanceCategoryRegex.FindString(instance.InstanceType) + // Allow the instance if we don't recognize the type + if !_knownInstanceCategories.Has(instanceCategoryStr) { + continue + } + if _standardInstanceCategories.Has(instanceCategoryStr) { + instanceCategoryStr = "standard" + } + + cpusPerInstance := InstanceMetadatas[c.Region][instance.InstanceType].CPU + + categoryFound := false + for idx, instanceCategory := range instanceCategories { + if instanceCategory.InstanceCategory == instanceCategoryStr { + instanceCategories[idx].InstanceTypes = append(instanceCategories[idx].InstanceTypes, instance.InstanceType) + instanceCategories[idx].RequiredOnDemandCPUs += instance.RequiredOnDemandInstances * cpusPerInstance.Value() + instanceCategories[idx].RequiredSpotCPUs += instance.RequiredSpotInstances * cpusPerInstance.Value() - if _standardInstanceCategories.Has(instanceCategory) { - instanceCategory = "standard" + categoryFound = true + } + } + + if !categoryFound { + instanceCategories = append(instanceCategories, instanceCategoryRequests{ + InstanceTypes: []string{instance.InstanceType}, + InstanceCategory: instanceCategoryStr, + RequiredOnDemandCPUs: instance.RequiredOnDemandInstances * cpusPerInstance.Value(), + RequiredSpotCPUs: instance.RequiredSpotInstances * cpusPerInstance.Value(), + }) + } } - var onDemandCPUQuota *int64 - var onDemandQuotaCode string - var spotCPUQuota *int64 - var spotQuotaCode string err := c.ServiceQuotas().ListServiceQuotasPages( 
&servicequotas.ListServiceQuotasInput{ ServiceCode: aws.String("ec2"), @@ -76,17 +115,15 @@ func (c *Client) VerifyInstanceQuota(instanceType string, requiredOnDemandInstan continue } - // quota is specified in number of vCPU permitted per family - if strings.ToLower(*metricClass) == instanceCategory+"/ondemand" { - onDemandCPUQuota = pointer.Int64(int64(*quota.Value)) - onDemandQuotaCode = *quota.QuotaCode - } else if strings.ToLower(*metricClass) == instanceCategory+"/spot" { - spotCPUQuota = pointer.Int64(int64(*quota.Value)) - spotQuotaCode = *quota.QuotaCode - } - - if onDemandCPUQuota != nil && spotCPUQuota != nil { - return false + for idx, instanceCategory := range instanceCategories { + // quota is specified in number of vCPU permitted per family + if strings.ToLower(*metricClass) == instanceCategory.InstanceCategory+"/ondemand" { + instanceCategories[idx].OnDemandCPUQuota = pointer.Int64(int64(*quota.Value)) + instanceCategories[idx].OnDemandQuotaCode = *quota.QuotaCode + } else if strings.ToLower(*metricClass) == instanceCategory.InstanceCategory+"/spot" { + instanceCategories[idx].SpotCPUQuota = pointer.Int64(int64(*quota.Value)) + instanceCategories[idx].SpotQuotaCode = *quota.QuotaCode + } } } return true @@ -96,16 +133,13 @@ func (c *Client) VerifyInstanceQuota(instanceType string, requiredOnDemandInstan return errors.WithStack(err) } - cpuPerInstance := InstanceMetadatas[c.Region][instanceType].CPU - requiredOnDemandCPU := requiredOnDemandInstances * cpuPerInstance.Value() - requiredSpotCPU := requiredSpotInstances * cpuPerInstance.Value() - - if onDemandCPUQuota != nil && *onDemandCPUQuota < requiredOnDemandCPU { - return ErrorInsufficientInstanceQuota(instanceType, "on-demand", c.Region, requiredOnDemandInstances, cpuPerInstance.Value(), *onDemandCPUQuota, onDemandQuotaCode) - } - - if spotCPUQuota != nil && *spotCPUQuota < requiredSpotCPU { - return ErrorInsufficientInstanceQuota(instanceType, "spot", c.Region, requiredSpotInstances, 
cpuPerInstance.Value(), *spotCPUQuota, spotQuotaCode) + for _, ic := range instanceCategories { + if ic.OnDemandCPUQuota != nil && *ic.OnDemandCPUQuota < ic.RequiredOnDemandCPUs { + return ErrorInsufficientInstanceQuota(strset.FromSlice(ic.InstanceTypes).Slice(), "on-demand", c.Region, ic.RequiredOnDemandCPUs, *ic.OnDemandCPUQuota, ic.OnDemandQuotaCode) + } + if ic.SpotCPUQuota != nil && *ic.SpotCPUQuota < ic.RequiredSpotCPUs { + return ErrorInsufficientInstanceQuota(strset.FromSlice(ic.InstanceTypes).Slice(), "spot", c.Region, ic.RequiredSpotCPUs, *ic.SpotCPUQuota, ic.SpotQuotaCode) + } } return nil diff --git a/pkg/lib/configreader/reader.go b/pkg/lib/configreader/reader.go index 285db8ce80..843e849cf5 100644 --- a/pkg/lib/configreader/reader.go +++ b/pkg/lib/configreader/reader.go @@ -975,7 +975,7 @@ func StructFromStringMap(dest interface{}, strMap map[string]string, v *StructVa } if !v.AllowExtraFields { - extraFields := slices.SubtractStrSlice(maps.StrMapKeys(strMap), allowedFields) + extraFields := slices.SubtractStrSlice(maps.StrMapKeysString(strMap), allowedFields) for _, extraField := range extraFields { allErrs = append(allErrs, ErrorUnsupportedKey(extraField)) } diff --git a/pkg/lib/configreader/string.go b/pkg/lib/configreader/string.go index 98a69f5026..f5de1d07c5 100644 --- a/pkg/lib/configreader/string.go +++ b/pkg/lib/configreader/string.go @@ -47,6 +47,7 @@ type StringValidation struct { DisallowTrailingWhitespace bool AlphaNumericDashDotUnderscoreOrEmpty bool AlphaNumericDashDotUnderscore bool + AlphaNumericDashUnderscoreOrEmpty bool AlphaNumericDashUnderscore bool AWSTag bool DNS1035 bool @@ -291,6 +292,12 @@ func ValidateStringVal(val string, v *StringValidation) error { } } + if v.AlphaNumericDashUnderscoreOrEmpty { + if !regex.IsAlphaNumericDashUnderscore(val) && val != "" { + return ErrorAlphaNumericDashUnderscore(val) + } + } + if v.AlphaNumericDashDotUnderscoreOrEmpty { if !regex.IsAlphaNumericDashDotUnderscore(val) && val != "" { return 
ErrorAlphaNumericDashDotUnderscore(val) diff --git a/pkg/lib/hash/hash.go b/pkg/lib/hash/hash.go index 09f6661cff..8d8f429d4f 100644 --- a/pkg/lib/hash/hash.go +++ b/pkg/lib/hash/hash.go @@ -19,6 +19,7 @@ package hash import ( "crypto/sha256" "encoding/hex" + "strings" "github.com/cortexlabs/cortex/pkg/lib/files" s "github.com/cortexlabs/cortex/pkg/lib/strings" @@ -36,6 +37,10 @@ func String(str string) string { return Bytes([]byte(str)) } +func Strings(strs ...string) string { + return String(strings.Join(strs, ",")) +} + func Any(obj interface{}) string { return String(s.Obj(obj)) } diff --git a/pkg/lib/maps/string.go b/pkg/lib/maps/string.go index 884ce2eb8b..e20c0715ee 100644 --- a/pkg/lib/maps/string.go +++ b/pkg/lib/maps/string.go @@ -16,7 +16,7 @@ limitations under the License. package maps -func StrMapKeys(myMap map[string]string) []string { +func StrMapKeysString(myMap map[string]string) []string { keys := make([]string, len(myMap)) i := 0 for key := range myMap { @@ -26,7 +26,7 @@ func StrMapKeys(myMap map[string]string) []string { return keys } -func StrMapValues(myMap map[string]string) []string { +func StrMapValuesString(myMap map[string]string) []string { values := make([]string, len(myMap)) i := 0 for _, value := range myMap { @@ -36,7 +36,7 @@ func StrMapValues(myMap map[string]string) []string { return values } -func MergeStrMaps(maps ...map[string]string) map[string]string { +func MergeStrMapsString(maps ...map[string]string) map[string]string { merged := map[string]string{} for _, m := range maps { for k, v := range m { @@ -46,7 +46,59 @@ func MergeStrMaps(maps ...map[string]string) map[string]string { return merged } -func StrMapsEqual(m1, m2 map[string]string) bool { +func StrMapsEqualString(m1, m2 map[string]string) bool { + if len(m1) != len(m2) { + return false + } + + if len(m1) == 0 && len(m2) == 0 { + return true + } + + if len(m1) == 0 || len(m2) == 0 { + return false + } + + for k, v1 := range m1 { + if v2, ok := m2[k]; !ok || v2 != v1 
{ + return false + } + } + + return true +} + +func StrMapKeysInt(myMap map[string]int) []string { + keys := make([]string, len(myMap)) + i := 0 + for key := range myMap { + keys[i] = key + i++ + } + return keys +} + +func StrMapValuesInt(myMap map[string]int) []int { + values := make([]int, len(myMap)) + i := 0 + for _, value := range myMap { + values[i] = value + i++ + } + return values +} + +func MergeStrMapsInt(maps ...map[string]int) map[string]int { + merged := map[string]int{} + for _, m := range maps { + for k, v := range m { + merged[k] = v + } + } + return merged +} + +func StrMapsEqualInt(m1, m2 map[string]int) bool { if len(m1) != len(m2) { return false } diff --git a/pkg/lib/sets/strset/strset.go b/pkg/lib/sets/strset/strset.go index 797c5d4232..4a862683c4 100644 --- a/pkg/lib/sets/strset/strset.go +++ b/pkg/lib/sets/strset/strset.go @@ -110,6 +110,19 @@ func (s Set) Has(items ...string) bool { return has } +// HasWithPrefix checks if at least one element of the set is the prefix of any of the passed items. +// It returns false if nothing is passed. +func (s Set) HasWithPrefix(items ...string) bool { + for _, prefix := range items { + for k := range s { + if strings.HasPrefix(prefix, k) { + return true + } + } + } + return false +} + // HasAny looks for the existence of any of the items passed. // It returns false if nothing is passed. // For multiple items it returns true if any of the items exist. diff --git a/pkg/lib/telemetry/telemetry.go b/pkg/lib/telemetry/telemetry.go index 5f17fbb48e..df67f8e044 100644 --- a/pkg/lib/telemetry/telemetry.go +++ b/pkg/lib/telemetry/telemetry.go @@ -174,12 +174,12 @@ func Error(err error, tags ...map[string]string) { return } - mergedTags := maps.MergeStrMaps(tags...) + mergedTags := maps.MergeStrMapsString(tags...) 
sentry.WithScope(func(scope *sentry.Scope) { e := EventFromException(err) scope.SetUser(sentry.User{ID: _config.UserID}) - scope.SetTags(maps.MergeStrMaps(_config.Properties, mergedTags)) + scope.SetTags(maps.MergeStrMapsString(_config.Properties, mergedTags)) scope.SetTags(map[string]string{"error_type": e.Exception[0].Type}) sentry.CaptureEvent(e) diff --git a/pkg/operator/config/config.go b/pkg/operator/config/config.go index d62b29701b..1e49c1417a 100644 --- a/pkg/operator/config/config.go +++ b/pkg/operator/config/config.go @@ -44,9 +44,9 @@ var ( Provider types.ProviderType OperatorMetadata *clusterconfig.OperatorMetadata - CoreConfig *clusterconfig.CoreConfig - managedConfig *clusterconfig.ManagedConfig - instanceMetadata *aws.InstanceMetadata + CoreConfig *clusterconfig.CoreConfig + managedConfig *clusterconfig.ManagedConfig + instancesMetadata []aws.InstanceMetadata GCPCoreConfig *clusterconfig.GCPCoreConfig gcpManagedConfig *clusterconfig.GCPManagedConfig @@ -66,9 +66,9 @@ func ManagedConfigOrNil() *clusterconfig.ManagedConfig { return nil } -func AWSInstanceMetadataOrNil() *aws.InstanceMetadata { +func AWSInstancesMetadata() []aws.InstanceMetadata { if CoreConfig.IsManaged { - return instanceMetadata + return instancesMetadata } return nil } @@ -109,8 +109,10 @@ func Init() error { if errors.HasError(errs) { return errors.FirstError(errs...) 
} - awsInstanceMetadata := aws.InstanceMetadatas[CoreConfig.Region][managedConfig.InstanceType] - instanceMetadata = &awsInstanceMetadata + + for _, instanceType := range managedConfig.GetAllInstanceTypes() { + instancesMetadata = append(instancesMetadata, aws.InstanceMetadatas[CoreConfig.Region][instanceType]) + } } AWS, err = aws.NewForRegion(CoreConfig.Region) diff --git a/pkg/operator/endpoints/info.go b/pkg/operator/endpoints/info.go index 54fd3c8b71..8d762a9e3d 100644 --- a/pkg/operator/endpoints/info.go +++ b/pkg/operator/endpoints/info.go @@ -50,7 +50,7 @@ func Info(w http.ResponseWriter, r *http.Request) { if config.IsManaged() { fullClusterConfig.Config.ManagedConfig = *config.ManagedConfigOrNil() - fullClusterConfig.InstanceMetadata = *config.AWSInstanceMetadataOrNil() + fullClusterConfig.InstancesMetadata = config.AWSInstancesMetadata() } response := schema.InfoResponse{ @@ -95,6 +95,7 @@ func getNodeInfos() ([]schema.NodeInfo, int, error) { for _, node := range nodes { instanceType := node.Labels["beta.kubernetes.io/instance-type"] + nodeGroupName := node.Labels["alpha.eksctl.io/nodegroup-name"] isSpot := strings.Contains(strings.ToLower(node.Labels["lifecycle"]), "spot") price := aws.InstanceMetadatas[config.CoreConfig.Region][instanceType].Price @@ -114,6 +115,7 @@ func getNodeInfos() ([]schema.NodeInfo, int, error) { nodeInfoMap[node.Name] = &schema.NodeInfo{ Name: node.Name, + NodeGroupName: nodeGroupName, InstanceType: instanceType, IsSpot: isSpot, Price: price, diff --git a/pkg/operator/operator/cron.go b/pkg/operator/operator/cron.go index eb84e2abf9..ad4d911fc9 100644 --- a/pkg/operator/operator/cron.go +++ b/pkg/operator/operator/cron.go @@ -98,10 +98,12 @@ func managedClusterTelemetry() (map[string]interface{}, error) { } instanceInfos := make(map[string]*instanceInfo) - var totalInstances int - managedConfig := config.ManagedConfigOrNil() + var totalInstances int + var totalInstancePrice float64 + var totalInstancePriceIfOnDemand float64 + 
for _, node := range nodes { if node.Labels["workload"] != "true" { continue @@ -138,6 +140,11 @@ func managedClusterTelemetry() (map[string]interface{}, error) { } } + ngName := node.Labels["alpha.eksctl.io/nodegroup-name"] + ebsPricePerVolume := getEBSPriceForNodeGroupInstance(managedConfig.NodeGroups, ngName) + onDemandPrice += ebsPricePerVolume + price += ebsPricePerVolume + gpuQty := node.Status.Capacity["nvidia.com/gpu"] infQty := node.Status.Capacity["aws.amazon.com/neuron"] @@ -161,18 +168,8 @@ func managedClusterTelemetry() (map[string]interface{}, error) { } instanceInfos[instanceInfosKey] = &info - } - - apiEBSPrice := aws.EBSMetadatas[config.CoreConfig.Region][managedConfig.InstanceVolumeType.String()].PriceGB * float64(managedConfig.InstanceVolumeSize) / 30 / 24 - if managedConfig.InstanceVolumeType.String() == "io1" && managedConfig.InstanceVolumeIOPS != nil { - apiEBSPrice += aws.EBSMetadatas[config.CoreConfig.Region][managedConfig.InstanceVolumeType.String()].PriceIOPS * float64(*managedConfig.InstanceVolumeIOPS) / 30 / 24 - } - - var totalInstancePrice float64 - var totalInstancePriceIfOnDemand float64 - for _, info := range instanceInfos { - totalInstancePrice += (info.Price + apiEBSPrice) * float64(info.Count) - totalInstancePriceIfOnDemand += (info.OnDemandPrice + apiEBSPrice) * float64(info.Count) + totalInstancePrice += info.Price + totalInstancePriceIfOnDemand += info.OnDemandPrice } fixedPrice := clusterFixedPriceAWS() @@ -189,6 +186,26 @@ func managedClusterTelemetry() (map[string]interface{}, error) { }, nil } +func getEBSPriceForNodeGroupInstance(ngs []*clusterconfig.NodeGroup, ngName string) float64 { + var ebsPrice float64 + for _, ng := range ngs { + var ngNamePrefix string + if ng.Spot { + ngNamePrefix = "cx-ws-" + } else { + ngNamePrefix = "cx-wd-" + } + if ngNamePrefix+ng.Name == ngName { + ebsPrice = aws.EBSMetadatas[config.CoreConfig.Region][ng.InstanceVolumeType.String()].PriceGB * float64(ng.InstanceVolumeSize) / 30 / 24 + if
ng.InstanceVolumeType.String() == "io1" && ng.InstanceVolumeIOPS != nil { + ebsPrice += aws.EBSMetadatas[config.CoreConfig.Region][ng.InstanceVolumeType.String()].PriceIOPS * float64(*ng.InstanceVolumeIOPS) / 30 / 24 + } + break + } + } + return ebsPrice +} + func clusterFixedPriceAWS() float64 { eksPrice := aws.EKSPrices[config.CoreConfig.Region] operatorInstancePrice := aws.InstanceMetadatas[config.CoreConfig.Region]["t3.medium"].Price diff --git a/pkg/operator/operator/k8s.go b/pkg/operator/operator/k8s.go index 10ad7b99b1..57be71b540 100644 --- a/pkg/operator/operator/k8s.go +++ b/pkg/operator/operator/k8s.go @@ -1133,24 +1133,91 @@ func NodeSelectors() map[string]string { return nodeSelectors } -var Tolerations = []kcore.Toleration{ - { - Key: "workload", - Operator: kcore.TolerationOpEqual, - Value: "true", - Effect: kcore.TaintEffectNoSchedule, - }, - { - Key: "nvidia.com/gpu", - Operator: kcore.TolerationOpExists, - Effect: kcore.TaintEffectNoSchedule, - }, - { - Key: "aws.amazon.com/neuron", - Operator: kcore.TolerationOpEqual, - Value: "true", - Effect: kcore.TaintEffectNoSchedule, - }, +func GenerateResourceTolerations() []kcore.Toleration { + tolerations := []kcore.Toleration{ + { + Key: "workload", + Operator: kcore.TolerationOpEqual, + Value: "true", + Effect: kcore.TaintEffectNoSchedule, + }, + { + Key: "nvidia.com/gpu", + Operator: kcore.TolerationOpExists, + Effect: kcore.TaintEffectNoSchedule, + }, + { + Key: "aws.amazon.com/neuron", + Operator: kcore.TolerationOpEqual, + Value: "true", + Effect: kcore.TaintEffectNoSchedule, + }, + } + + return tolerations +} + +func GeneratePreferredNodeAffinities() []kcore.PreferredSchedulingTerm { + affinities := []kcore.PreferredSchedulingTerm{} + + if config.Provider == types.AWSProviderType { + clusterConfig := config.ManagedConfigOrNil() + if clusterConfig == nil { + return nil + } + + numNodeGroups := len(clusterConfig.NodeGroups) + for idx, nodeGroup := range clusterConfig.NodeGroups { + var 
nodeGroupPrefix string + if nodeGroup.Spot { + nodeGroupPrefix = "cx-ws-" + } else { + nodeGroupPrefix = "cx-wd-" + } + affinities = append(affinities, kcore.PreferredSchedulingTerm{ + Weight: int32(100 * (1 - float64(idx)/float64(numNodeGroups))), + Preference: kcore.NodeSelectorTerm{ + MatchExpressions: []kcore.NodeSelectorRequirement{ + { + Key: "alpha.eksctl.io/nodegroup-name", + Operator: kcore.NodeSelectorOpIn, + Values: []string{nodeGroupPrefix + nodeGroup.Name}, + }, + }, + }, + }) + } + } + if config.Provider == types.GCPProviderType { + clusterConfig := config.GCPManagedConfigOrNil() + if clusterConfig == nil { + return nil + } + + numNodePools := len(clusterConfig.NodePools) + for idx, nodePool := range clusterConfig.NodePools { + var nodePoolPrefix string + if nodePool.Preemptible { + nodePoolPrefix = "cx-ws-" + } else { + nodePoolPrefix = "cx-wd-" + } + affinities = append(affinities, kcore.PreferredSchedulingTerm{ + Weight: int32(100 * (1 - float64(idx)/float64(numNodePools))), + Preference: kcore.NodeSelectorTerm{ + MatchExpressions: []kcore.NodeSelectorRequirement{ + { + Key: "cloud.google.com/gke-nodepool", + Operator: kcore.NodeSelectorOpIn, + Values: []string{nodePoolPrefix + nodePool.Name}, + }, + }, + }, + }) + } + } + + return affinities } func K8sName(apiName string) string { diff --git a/pkg/operator/operator/memory_capacity.go b/pkg/operator/operator/memory_capacity.go index 1c2d48f0c8..040fe5af6d 100644 --- a/pkg/operator/operator/memory_capacity.go +++ b/pkg/operator/operator/memory_capacity.go @@ -17,21 +17,19 @@ limitations under the License. 
package operator import ( - "math" - "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/k8s" + "github.com/cortexlabs/cortex/pkg/lib/slices" "github.com/cortexlabs/cortex/pkg/operator/config" - "github.com/cortexlabs/cortex/pkg/types" kresource "k8s.io/apimachinery/pkg/api/resource" kmeta "k8s.io/apimachinery/pkg/apis/meta/v1" klabels "k8s.io/apimachinery/pkg/labels" ) const _memConfigMapName = "cortex-instance-memory" -const _memConfigMapKey = "capacity" +const _configKeyPrefix = "memory-capacity-" -func getMemoryCapacityFromNodes() (*kresource.Quantity, error) { +func getMemoryCapacityFromNodes(primaryInstances []string) (map[string]*kresource.Quantity, error) { opts := kmeta.ListOptions{ LabelSelector: klabels.SelectorFromSet(map[string]string{ "workload": "true", @@ -42,23 +40,40 @@ func getMemoryCapacityFromNodes() (*kresource.Quantity, error) { return nil, err } - var minMem *kresource.Quantity + minMemMap := map[string]*kresource.Quantity{} + for _, primaryInstance := range primaryInstances { + minMemMap[primaryInstance] = nil + } + for _, node := range nodes { + isPrimaryInstance := false + var primaryInstanceType string + for k, v := range node.Labels { + if k == "beta.kubernetes.io/instance-type" && slices.HasString(primaryInstances, v) { + isPrimaryInstance = true + primaryInstanceType = v + break + } + } + if !isPrimaryInstance { + continue + } + curMem := node.Status.Capacity.Memory() if curMem == nil || curMem.IsZero() { continue } - if minMem == nil || minMem.Cmp(*curMem) < 0 { - minMem = curMem + if minMemMap[primaryInstanceType] == nil || minMemMap[primaryInstanceType].Cmp(*curMem) > 0 { + minMemMap[primaryInstanceType] = curMem } } - return minMem, nil + return minMemMap, nil } -func getMemoryCapacityFromConfigMap() (*kresource.Quantity, error) { +func getMemoryCapacityFromConfigMap() (map[string]*kresource.Quantity, error) { configMapData, err := config.K8s.GetConfigMapData(_memConfigMapName) if err != nil { 
return nil, err @@ -68,63 +83,81 @@ func getMemoryCapacityFromConfigMap() (*kresource.Quantity, error) { return nil, nil } - memoryUserStr := configMapData[_memConfigMapKey] - mem, err := kresource.ParseQuantity(memoryUserStr) - if err != nil { - return nil, err - } - if mem.IsZero() { - return nil, nil + memoryCapacitiesMap := map[string]*kresource.Quantity{} + for k := range configMapData { + memoryUserStr := configMapData[k] + mem, err := kresource.ParseQuantity(memoryUserStr) + if err != nil { + return nil, err + } + instanceType := k[len(_configKeyPrefix):] + if mem.IsZero() { + memoryCapacitiesMap[instanceType] = nil + } else { + memoryCapacitiesMap[instanceType] = &mem + } } - return &mem, nil + + return memoryCapacitiesMap, nil } -func UpdateMemoryCapacityConfigMap() (kresource.Quantity, error) { +func UpdateMemoryCapacityConfigMap() (map[string]kresource.Quantity, error) { if !config.IsManaged() { - return kresource.Quantity{}, nil + return nil, nil } - minMem := *kresource.NewQuantity(math.MaxInt64, kresource.DecimalSI) + instancesMetadata := config.AWSInstancesMetadata() + if len(instancesMetadata) == 0 { + return nil, errors.ErrorUnexpected("unable to find instances metadata; likely because this is not a cortex managed cluster") + } + primaryInstances := []string{} - if config.Provider == types.AWSProviderType { - instanceMetadata := config.AWSInstanceMetadataOrNil() - if instanceMetadata == nil { - return kresource.Quantity{}, errors.ErrorUnexpected("unable to find instance metadata; likely because this is not a cortex managed cluster") - } - minMem = instanceMetadata.Memory + minMemMap := map[string]kresource.Quantity{} + for _, instanceMetadata := range instancesMetadata { + minMemMap[instanceMetadata.Type] = instanceMetadata.Memory + primaryInstances = append(primaryInstances, instanceMetadata.Type) } - nodeMemCapacity, err := getMemoryCapacityFromNodes() + nodeMemCapacityMap, err := getMemoryCapacityFromNodes(primaryInstances) if err != nil { - 
return kresource.Quantity{}, err + return nil, err } - previousMinMem, err := getMemoryCapacityFromConfigMap() + previousMinMemMap, err := getMemoryCapacityFromConfigMap() if err != nil { - return kresource.Quantity{}, err + return nil, err } - if nodeMemCapacity != nil && minMem.Cmp(*nodeMemCapacity) > 0 { - minMem = *nodeMemCapacity - } + configMapData := map[string]string{} + for _, primaryInstance := range primaryInstances { + minMem := minMemMap[primaryInstance] - if previousMinMem != nil && minMem.Cmp(*previousMinMem) > 0 { - minMem = *previousMinMem - } + if nodeMemCapacityMap[primaryInstance] != nil && minMem.Cmp(*nodeMemCapacityMap[primaryInstance]) > 0 { + minMem = *nodeMemCapacityMap[primaryInstance] + } - if previousMinMem == nil || minMem.Cmp(*previousMinMem) < 0 { - configMap := k8s.ConfigMap(&k8s.ConfigMapSpec{ - Name: _memConfigMapName, - Data: map[string]string{ - _memConfigMapKey: minMem.String(), - }, - }) + if previousMinMemMap[primaryInstance] != nil && minMem.Cmp(*previousMinMemMap[primaryInstance]) > 0 { + minMem = *previousMinMemMap[primaryInstance] + } - _, err := config.K8s.ApplyConfigMap(configMap) - if err != nil { - return kresource.Quantity{}, err + if previousMinMemMap[primaryInstance] == nil || minMem.Cmp(*previousMinMemMap[primaryInstance]) < 0 { + configMapData[_configKeyPrefix+primaryInstance] = minMem.String() + } else { + configMapData[_configKeyPrefix+primaryInstance] = previousMinMemMap[primaryInstance].String() } + + minMemMap[primaryInstance] = minMem + } + + configMap := k8s.ConfigMap(&k8s.ConfigMapSpec{ + Name: _memConfigMapName, + Data: configMapData, + }) + + _, err = config.K8s.ApplyConfigMap(configMap) + if err != nil { + return nil, err } - return minMem, nil + return minMemMap, nil } diff --git a/pkg/operator/resources/job/batchapi/k8s_specs.go b/pkg/operator/resources/job/batchapi/k8s_specs.go index 4f0cb82ca3..84053f19f1 100644 --- a/pkg/operator/resources/job/batchapi/k8s_specs.go +++ 
b/pkg/operator/resources/job/batchapi/k8s_specs.go @@ -88,9 +88,14 @@ func pythonPredictorJobSpec(api *spec.API, job *spec.BatchJob) (*kbatch.Job, err InitContainers: []kcore.Container{ operator.InitContainer(api), }, - Containers: containers, - NodeSelector: operator.NodeSelectors(), - Tolerations: operator.Tolerations, + Containers: containers, + NodeSelector: operator.NodeSelectors(), + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, Volumes: volumes, ServiceAccountName: operator.ServiceAccountName, }, @@ -138,9 +143,14 @@ func tensorFlowPredictorJobSpec(api *spec.API, job *spec.BatchJob) (*kbatch.Job, InitContainers: []kcore.Container{ operator.InitContainer(api), }, - Containers: containers, - NodeSelector: operator.NodeSelectors(), - Tolerations: operator.Tolerations, + Containers: containers, + NodeSelector: operator.NodeSelectors(), + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, Volumes: volumes, ServiceAccountName: operator.ServiceAccountName, }, @@ -189,9 +199,14 @@ func onnxPredictorJobSpec(api *spec.API, job *spec.BatchJob) (*kbatch.Job, error InitContainers: []kcore.Container{ operator.InitContainer(api), }, - Containers: containers, - NodeSelector: operator.NodeSelectors(), - Tolerations: operator.Tolerations, + Containers: containers, + NodeSelector: operator.NodeSelectors(), + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, Volumes: volumes, ServiceAccountName: operator.ServiceAccountName, }, diff --git 
a/pkg/operator/resources/job/taskapi/k8s_specs.go b/pkg/operator/resources/job/taskapi/k8s_specs.go index b8e0dc0848..4a18c25c8a 100644 --- a/pkg/operator/resources/job/taskapi/k8s_specs.go +++ b/pkg/operator/resources/job/taskapi/k8s_specs.go @@ -116,9 +116,14 @@ func k8sJobSpec(api *spec.API, job *spec.TaskJob) (*kbatch.Job, error) { InitContainers: []kcore.Container{ operator.TaskInitContainer(api), }, - Containers: containers, - NodeSelector: operator.NodeSelectors(), - Tolerations: operator.Tolerations, + Containers: containers, + NodeSelector: operator.NodeSelectors(), + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, Volumes: volumes, ServiceAccountName: operator.ServiceAccountName, }, diff --git a/pkg/operator/resources/realtimeapi/k8s_specs.go b/pkg/operator/resources/realtimeapi/k8s_specs.go index 141825d345..ac64ed03ae 100644 --- a/pkg/operator/resources/realtimeapi/k8s_specs.go +++ b/pkg/operator/resources/realtimeapi/k8s_specs.go @@ -82,9 +82,14 @@ func tensorflowAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.D InitContainers: []kcore.Container{ operator.InitContainer(api), }, - Containers: containers, - NodeSelector: operator.NodeSelectors(), - Tolerations: operator.Tolerations, + Containers: containers, + NodeSelector: operator.NodeSelectors(), + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, Volumes: volumes, ServiceAccountName: operator.ServiceAccountName, }, @@ -132,9 +137,14 @@ func pythonAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.Deplo InitContainers: []kcore.Container{ operator.InitContainer(api), }, - Containers: containers, - NodeSelector: 
operator.NodeSelectors(), - Tolerations: operator.Tolerations, + Containers: containers, + NodeSelector: operator.NodeSelectors(), + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, Volumes: volumes, ServiceAccountName: operator.ServiceAccountName, }, @@ -183,9 +193,14 @@ func onnxAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.Deploym TerminationGracePeriodSeconds: pointer.Int64(_terminationGracePeriodSeconds), Containers: containers, NodeSelector: operator.NodeSelectors(), - Tolerations: operator.Tolerations, - Volumes: volumes, - ServiceAccountName: operator.ServiceAccountName, + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, + Volumes: volumes, + ServiceAccountName: operator.ServiceAccountName, }, }, }) diff --git a/pkg/operator/resources/validations.go b/pkg/operator/resources/validations.go index d71eb6bd61..c479373160 100644 --- a/pkg/operator/resources/validations.go +++ b/pkg/operator/resources/validations.go @@ -23,7 +23,6 @@ import ( "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/files" "github.com/cortexlabs/cortex/pkg/lib/k8s" - "github.com/cortexlabs/cortex/pkg/lib/parallel" "github.com/cortexlabs/cortex/pkg/lib/sets/strset" s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/config" @@ -116,7 +115,7 @@ func ValidateClusterAPIs(apis []userconfig.API, projectFiles spec.ProjectFiles) } if config.IsManaged() && config.Provider == types.AWSProviderType { - maxMem, err := operator.UpdateMemoryCapacityConfigMap() + maxMemMap, err := operator.UpdateMemoryCapacityConfigMap() if err != nil { return err } @@ -124,7 +123,7 @@ 
func ValidateClusterAPIs(apis []userconfig.API, projectFiles spec.ProjectFiles) for i := range apis { api := &apis[i] if api.Kind == userconfig.RealtimeAPIKind || api.Kind == userconfig.BatchAPIKind || api.Kind == userconfig.TaskAPIKind { - if err := awsManagedValidateK8sCompute(api.Compute, maxMem); err != nil { + if err := awsManagedValidateK8sCompute(api.Compute, maxMemMap); err != nil { return err } } @@ -174,45 +173,60 @@ var _nvidiaDCGMExporterMemReserve = kresource.MustParse("50Mi") var _inferentiaCPUReserve = kresource.MustParse("100m") var _inferentiaMemReserve = kresource.MustParse("100Mi") -func awsManagedValidateK8sCompute(compute *userconfig.Compute, maxMem kresource.Quantity) error { - instanceMetadata := config.AWSInstanceMetadataOrNil() - if instanceMetadata == nil { +func awsManagedValidateK8sCompute(compute *userconfig.Compute, maxMemMap map[string]kresource.Quantity) error { + instancesMetadata := config.AWSInstancesMetadata() + if len(instancesMetadata) == 0 { return errors.ErrorUnexpected("unable to find instance metadata; likely because this is not a cortex managed cluster") } - maxMem.Sub(_cortexMemReserve) - - maxCPU := instanceMetadata.CPU - maxCPU.Sub(_cortexCPUReserve) + allErrors := []error{} + successfulLoops := 0 + for _, instanceMetadata := range instancesMetadata { + maxMemLoop := maxMemMap[instanceMetadata.Type] + maxMemLoop.Sub(_cortexMemReserve) + + maxCPU := instanceMetadata.CPU + maxCPU.Sub(_cortexCPUReserve) + + maxGPU := instanceMetadata.GPU + if maxGPU > 0 { + // Reserve resources for nvidia device plugin daemonset + maxCPU.Sub(_nvidiaCPUReserve) + maxMemLoop.Sub(_nvidiaMemReserve) + // Reserve resources for nvidia dcgm prometheus exporter + maxCPU.Sub(_nvidiaDCGMExporterCPUReserve) + maxMemLoop.Sub(_nvidiaDCGMExporterMemReserve) + } - maxGPU := instanceMetadata.GPU - if maxGPU > 0 { - // Reserve resources for nvidia device plugin daemonset - maxCPU.Sub(_nvidiaCPUReserve) - maxMem.Sub(_nvidiaMemReserve) - // Reserve resources 
for nvidia dcgm prometheus exporter - maxCPU.Sub(_nvidiaDCGMExporterCPUReserve) - maxMem.Sub(_nvidiaDCGMExporterMemReserve) - } + maxInf := instanceMetadata.Inf + if maxInf > 0 { + // Reserve resources for inferentia device plugin daemonset + maxCPU.Sub(_inferentiaCPUReserve) + maxMemLoop.Sub(_inferentiaMemReserve) + } - maxInf := instanceMetadata.Inf - if maxInf > 0 { - // Reserve resources for inferentia device plugin daemonset - maxCPU.Sub(_inferentiaCPUReserve) - maxMem.Sub(_inferentiaMemReserve) + loopErrors := []error{} + if compute.CPU != nil && maxCPU.Cmp(compute.CPU.Quantity) < 0 { + loopErrors = append(loopErrors, ErrorNoAvailableNodeComputeLimit("CPU", compute.CPU.String(), maxCPU.String())) + } + if compute.Mem != nil && maxMemLoop.Cmp(compute.Mem.Quantity) < 0 { + loopErrors = append(loopErrors, ErrorNoAvailableNodeComputeLimit("memory", compute.Mem.String(), maxMemLoop.String())) + } + if compute.GPU > maxGPU { + loopErrors = append(loopErrors, ErrorNoAvailableNodeComputeLimit("GPU", fmt.Sprintf("%d", compute.GPU), fmt.Sprintf("%d", maxGPU))) + } + if compute.Inf > maxInf { + loopErrors = append(loopErrors, ErrorNoAvailableNodeComputeLimit("Inf", fmt.Sprintf("%d", compute.Inf), fmt.Sprintf("%d", maxInf))) + } + if errors.HasError(loopErrors) { + allErrors = append(allErrors, errors.FirstError(loopErrors...)) + } else { + successfulLoops++ + } } - if compute.CPU != nil && maxCPU.Cmp(compute.CPU.Quantity) < 0 { - return ErrorNoAvailableNodeComputeLimit("CPU", compute.CPU.String(), maxCPU.String()) - } - if compute.Mem != nil && maxMem.Cmp(compute.Mem.Quantity) < 0 { - return ErrorNoAvailableNodeComputeLimit("memory", compute.Mem.String(), maxMem.String()) - } - if compute.GPU > maxGPU { - return ErrorNoAvailableNodeComputeLimit("GPU", fmt.Sprintf("%d", compute.GPU), fmt.Sprintf("%d", maxGPU)) - } - if compute.Inf > maxInf { - return ErrorNoAvailableNodeComputeLimit("Inf", fmt.Sprintf("%d", compute.Inf), fmt.Sprintf("%d", maxInf)) + if successfulLoops == 
0 { + return errors.FirstError(allErrors...) } return nil @@ -252,26 +266,6 @@ func findDuplicateEndpoints(apis []userconfig.API) []userconfig.API { return nil } -func getValidationK8sResources() ([]istioclientnetworking.VirtualService, kresource.Quantity, error) { - var virtualServices []istioclientnetworking.VirtualService - var maxMem kresource.Quantity - - err := parallel.RunFirstErr( - func() error { - var err error - virtualServices, err = config.K8s.ListVirtualServices(nil) - return err - }, - func() error { - var err error - maxMem, err = operator.UpdateMemoryCapacityConfigMap() - return err - }, - ) - - return virtualServices, maxMem, err -} - // InclusiveFilterAPIsByKind includes only provided Kinds func InclusiveFilterAPIsByKind(apis []userconfig.API, kindsToInclude ...userconfig.Kind) []userconfig.API { kindsToIncludeSet := strset.New() diff --git a/pkg/operator/schema/schema.go b/pkg/operator/schema/schema.go index 5d3386606a..38cb4705e8 100644 --- a/pkg/operator/schema/schema.go +++ b/pkg/operator/schema/schema.go @@ -38,6 +38,7 @@ type InfoGCPResponse struct { type NodeInfo struct { Name string `json:"name"` + NodeGroupName string `json:"nodegroup_name"` InstanceType string `json:"instance_type"` IsSpot bool `json:"is_spot"` Price float64 `json:"price"` @@ -129,3 +130,13 @@ type APIVersion struct { type VerifyCortexResponse struct { Provider types.ProviderType `json:"provider"` } + +func (ir InfoResponse) GetNodesWithNodeGroupName(ngName string) []NodeInfo { + nodesInfo := []NodeInfo{} + for _, nodeInfo := range ir.NodeInfos { + if nodeInfo.NodeGroupName == ngName { + nodesInfo = append(nodesInfo, nodeInfo) + } + } + return nodesInfo +} diff --git a/pkg/types/clusterconfig/availability_zones.go b/pkg/types/clusterconfig/availability_zones.go index fe73523958..2ce8f43e83 100644 --- a/pkg/types/clusterconfig/availability_zones.go +++ b/pkg/types/clusterconfig/availability_zones.go @@ -40,8 +40,14 @@ func (cc *Config) setAvailabilityZones(awsClient 
*aws.Client) error { return nil } -func (cc *Config) setDefaultAvailabilityZones(awsClient *aws.Client, extraInstances ...string) error { - zones, err := awsClient.ListSupportedAvailabilityZones(cc.InstanceType, extraInstances...) +func (cc *Config) setDefaultAvailabilityZones(awsClient *aws.Client) error { + instanceTypes := strset.New() + for _, ng := range cc.NodeGroups { + instanceTypes.Add(ng.InstanceType) + } + instanceTypesSlice := instanceTypes.Slice() + + zones, err := awsClient.ListSupportedAvailabilityZones(instanceTypesSlice[0], instanceTypesSlice[1:]...) if err != nil { // Try again without checking instance types zones, err = awsClient.ListAvailabilityZonesInRegion() @@ -53,7 +59,7 @@ func (cc *Config) setDefaultAvailabilityZones(awsClient *aws.Client, extraInstan zones.Subtract(_azBlacklist) if len(zones) < 2 { - return ErrorNotEnoughDefaultSupportedZones(awsClient.Region, zones, cc.InstanceType, extraInstances...) + return ErrorNotEnoughDefaultSupportedZones(awsClient.Region, zones, instanceTypesSlice[0], instanceTypesSlice[1:]...) } // See https://github.com/weaveworks/eksctl/blob/master/pkg/eks/api.go @@ -69,6 +75,12 @@ func (cc *Config) setDefaultAvailabilityZones(awsClient *aws.Client, extraInstan } func (cc *Config) validateUserAvailabilityZones(awsClient *aws.Client, extraInstances ...string) error { + instanceTypes := strset.New() + for _, ng := range cc.NodeGroups { + instanceTypes.Add(ng.InstanceType) + } + instanceTypesSlice := instanceTypes.Slice() + allZones, err := awsClient.ListAvailabilityZonesInRegion() if err != nil { return nil // Skip validation @@ -80,7 +92,7 @@ func (cc *Config) validateUserAvailabilityZones(awsClient *aws.Client, extraInst } } - supportedZones, err := awsClient.ListSupportedAvailabilityZones(cc.InstanceType, extraInstances...) + supportedZones, err := awsClient.ListSupportedAvailabilityZones(instanceTypesSlice[0], instanceTypesSlice[1:]...) 
if err != nil { // Skip validation instance-based validation supportedZones = strset.Difference(allZones, _azBlacklist) @@ -88,7 +100,7 @@ func (cc *Config) validateUserAvailabilityZones(awsClient *aws.Client, extraInst for _, userZone := range cc.AvailabilityZones { if !supportedZones.Has(userZone) { - return ErrorUnsupportedAvailabilityZone(userZone, cc.InstanceType, extraInstances...) + return ErrorUnsupportedAvailabilityZone(userZone, instanceTypesSlice[0], instanceTypesSlice[1:]...) } } diff --git a/pkg/types/clusterconfig/cluster_config.go b/pkg/types/clusterconfig/cluster_config.go index 1c76009e6e..6a63160f67 100644 --- a/pkg/types/clusterconfig/cluster_config.go +++ b/pkg/types/clusterconfig/cluster_config.go @@ -24,7 +24,8 @@ import ( ) const ( - ClusterNameTag = "cortex.dev/cluster-name" + ClusterNameTag = "cortex.dev/cluster-name" + MaxNodePoolsOrGroups = 100 ) var ( diff --git a/pkg/types/clusterconfig/cluster_config_aws.go b/pkg/types/clusterconfig/cluster_config_aws.go index 050eeac7c4..6a4473318a 100644 --- a/pkg/types/clusterconfig/cluster_config_aws.go +++ b/pkg/types/clusterconfig/cluster_config_aws.go @@ -17,6 +17,8 @@ limitations under the License. 
package clusterconfig import ( + "crypto/sha256" + "encoding/hex" "fmt" "io/ioutil" "math" @@ -34,9 +36,9 @@ import ( libmath "github.com/cortexlabs/cortex/pkg/lib/math" "github.com/cortexlabs/cortex/pkg/lib/pointer" "github.com/cortexlabs/cortex/pkg/lib/sets/strset" - s "github.com/cortexlabs/cortex/pkg/lib/strings" - "github.com/cortexlabs/cortex/pkg/lib/table" + "github.com/cortexlabs/cortex/pkg/lib/slices" "github.com/cortexlabs/cortex/pkg/types" + "github.com/cortexlabs/yaml" ) const ( @@ -45,8 +47,10 @@ const ( ) var ( - _maxInstancePools = 20 - _cachedCNISupportedInstances *string + _maxNodeGroupLengthWithPrefix = 19 // node pool length name limit on GKE, using the same on AWS for consistency reasons + _maxNodeGroupLength = _maxNodeGroupLengthWithPrefix - len("cx-wd-") // or cx-ws- + _maxInstancePools = 20 + _cachedCNISupportedInstances *string // This regex is stricter than the actual S3 rules _strictS3BucketRegex = regexp.MustCompile(`^([a-z0-9])+(-[a-z0-9]+)*$`) _defaultIAMPolicies = []string{"arn:aws:iam::aws:policy/AmazonS3FullAccess"} @@ -87,15 +91,8 @@ type CoreConfig struct { } type ManagedConfig struct { - InstanceType string `json:"instance_type" yaml:"instance_type"` - MinInstances int64 `json:"min_instances" yaml:"min_instances"` - MaxInstances int64 `json:"max_instances" yaml:"max_instances"` - InstanceVolumeSize int64 `json:"instance_volume_size" yaml:"instance_volume_size"` - InstanceVolumeType VolumeType `json:"instance_volume_type" yaml:"instance_volume_type"` - InstanceVolumeIOPS *int64 `json:"instance_volume_iops" yaml:"instance_volume_iops"` + NodeGroups []*NodeGroup `json:"node_groups" yaml:"node_groups"` Tags map[string]string `json:"tags" yaml:"tags"` - Spot bool `json:"spot" yaml:"spot"` - SpotConfig *SpotConfig `json:"spot_config" yaml:"spot_config"` AvailabilityZones []string `json:"availability_zones" yaml:"availability_zones"` SSLCertificateARN *string `json:"ssl_certificate_arn,omitempty" yaml:"ssl_certificate_arn,omitempty"` 
IAMPolicyARNs []string `json:"iam_policy_arns" yaml:"iam_policy_arns"` @@ -108,13 +105,24 @@ type ManagedConfig struct { CortexPolicyARN string `json:"cortex_policy_arn" yaml:"cortex_policy_arn"` // this field is not user facing } +type NodeGroup struct { + Name string `json:"name" yaml:"name"` + InstanceType string `json:"instance_type" yaml:"instance_type"` + MinInstances int64 `json:"min_instances" yaml:"min_instances"` + MaxInstances int64 `json:"max_instances" yaml:"max_instances"` + InstanceVolumeSize int64 `json:"instance_volume_size" yaml:"instance_volume_size"` + InstanceVolumeType VolumeType `json:"instance_volume_type" yaml:"instance_volume_type"` + InstanceVolumeIOPS *int64 `json:"instance_volume_iops" yaml:"instance_volume_iops"` + Spot bool `json:"spot" yaml:"spot"` + SpotConfig *SpotConfig `json:"spot_config" yaml:"spot_config"` +} + type SpotConfig struct { InstanceDistribution []string `json:"instance_distribution" yaml:"instance_distribution"` OnDemandBaseCapacity *int64 `json:"on_demand_base_capacity" yaml:"on_demand_base_capacity"` OnDemandPercentageAboveBaseCapacity *int64 `json:"on_demand_percentage_above_base_capacity" yaml:"on_demand_percentage_above_base_capacity"` MaxPrice *float64 `json:"max_price" yaml:"max_price"` InstancePools *int64 `json:"instance_pools" yaml:"instance_pools"` - OnDemandBackup *bool `json:"on_demand_backup" yaml:"on_demand_backup"` } type Subnet struct { @@ -140,7 +148,7 @@ type InternalConfig struct { // Populated by operator OperatorMetadata - InstanceMetadata aws.InstanceMetadata `json:"instance_metadata"` + InstancesMetadata []aws.InstanceMetadata `json:"instance_metadata"` } // The bare minimum to identify a cluster @@ -164,6 +172,32 @@ func RegionValidator(region string) (string, error) { return region, nil } +func (cc *Config) DeepCopy() (Config, error) { + bytes, err := yaml.Marshal(cc) + if err != nil { + return Config{}, err + } + + deepCopied := Config{} + err = yaml.Unmarshal(bytes, &deepCopied) + if err 
!= nil { + return Config{}, err + } + + return deepCopied, nil +} + +func (cc *Config) Hash() (string, error) { + bytes, err := yaml.Marshal(cc) + if err != nil { + return "", err + } + + hash := sha256.New() + hash.Write(bytes) + return hex.EncodeToString(hash.Sum(nil)), nil +} + var CoreConfigStructFieldValidations = []*cr.StructFieldValidation{ { StructField: "Provider", @@ -382,43 +416,122 @@ var CoreConfigStructFieldValidations = []*cr.StructFieldValidation{ var ManagedConfigStructFieldValidations = []*cr.StructFieldValidation{ { - StructField: "InstanceType", - StringValidation: &cr.StringValidation{ - Required: true, - MinLength: 1, - Validator: validateInstanceType, - }, - }, - { - StructField: "MinInstances", - Int64Validation: &cr.Int64Validation{ - Default: int64(1), - GreaterThanOrEqualTo: pointer.Int64(0), - }, - }, - { - StructField: "MaxInstances", - Int64Validation: &cr.Int64Validation{ - Default: int64(5), - GreaterThan: pointer.Int64(0), - }, - }, - { - StructField: "InstanceVolumeSize", - Int64Validation: &cr.Int64Validation{ - Default: 50, - GreaterThanOrEqualTo: pointer.Int64(20), // large enough to fit docker images and any other overhead - LessThanOrEqualTo: pointer.Int64(16384), - }, - }, - { - StructField: "InstanceVolumeType", - StringValidation: &cr.StringValidation{ - AllowedValues: VolumeTypesStrings(), - Default: GP2VolumeType.String(), - }, - Parser: func(str string) (interface{}, error) { - return VolumeTypeFromString(str), nil + StructField: "NodeGroups", + StructListValidation: &cr.StructListValidation{ + Required: true, + StructValidation: &cr.StructValidation{ + StructFieldValidations: []*cr.StructFieldValidation{ + { + StructField: "Name", + StringValidation: &cr.StringValidation{ + Required: true, + AlphaNumericDashUnderscore: true, + MaxLength: _maxNodeGroupLength, + }, + }, + { + StructField: "InstanceType", + StringValidation: &cr.StringValidation{ + Required: true, + MinLength: 1, + Validator: validateInstanceType, + }, + 
}, + { + StructField: "MinInstances", + Int64Validation: &cr.Int64Validation{ + Default: int64(1), + GreaterThanOrEqualTo: pointer.Int64(0), + }, + }, + { + StructField: "MaxInstances", + Int64Validation: &cr.Int64Validation{ + Default: int64(5), + GreaterThan: pointer.Int64(0), + }, + }, + { + StructField: "InstanceVolumeSize", + Int64Validation: &cr.Int64Validation{ + Default: 50, + GreaterThanOrEqualTo: pointer.Int64(20), // large enough to fit docker images and any other overhead + LessThanOrEqualTo: pointer.Int64(16384), + }, + }, + { + StructField: "InstanceVolumeType", + StringValidation: &cr.StringValidation{ + AllowedValues: VolumeTypesStrings(), + Default: GP2VolumeType.String(), + }, + Parser: func(str string) (interface{}, error) { + return VolumeTypeFromString(str), nil + }, + }, + { + StructField: "InstanceVolumeIOPS", + Int64PtrValidation: &cr.Int64PtrValidation{ + GreaterThanOrEqualTo: pointer.Int64(100), + LessThanOrEqualTo: pointer.Int64(64000), + AllowExplicitNull: true, + }, + }, + { + StructField: "Spot", + BoolValidation: &cr.BoolValidation{ + Default: false, + }, + }, + { + StructField: "SpotConfig", + StructValidation: &cr.StructValidation{ + DefaultNil: true, + AllowExplicitNull: true, + StructFieldValidations: []*cr.StructFieldValidation{ + { + StructField: "InstanceDistribution", + StringListValidation: &cr.StringListValidation{ + DisallowDups: true, + Validator: validateInstanceDistribution, + AllowExplicitNull: true, + }, + }, + { + StructField: "OnDemandBaseCapacity", + Int64PtrValidation: &cr.Int64PtrValidation{ + GreaterThanOrEqualTo: pointer.Int64(0), + AllowExplicitNull: true, + }, + }, + { + StructField: "OnDemandPercentageAboveBaseCapacity", + Int64PtrValidation: &cr.Int64PtrValidation{ + GreaterThanOrEqualTo: pointer.Int64(0), + LessThanOrEqualTo: pointer.Int64(100), + AllowExplicitNull: true, + }, + }, + { + StructField: "MaxPrice", + Float64PtrValidation: &cr.Float64PtrValidation{ + GreaterThan: pointer.Float64(0), + 
AllowExplicitNull: true, + }, + }, + { + StructField: "InstancePools", + Int64PtrValidation: &cr.Int64PtrValidation{ + GreaterThanOrEqualTo: pointer.Int64(1), + LessThanOrEqualTo: pointer.Int64(int64(_maxInstancePools)), + AllowExplicitNull: true, + }, + }, + }, + }, + }, + }, + }, }, }, { @@ -467,73 +580,6 @@ var ManagedConfigStructFieldValidations = []*cr.StructFieldValidation{ AllowExplicitNull: true, }, }, - { - StructField: "InstanceVolumeIOPS", - Int64PtrValidation: &cr.Int64PtrValidation{ - GreaterThanOrEqualTo: pointer.Int64(100), - LessThanOrEqualTo: pointer.Int64(64000), - AllowExplicitNull: true, - }, - }, - { - StructField: "Spot", - BoolValidation: &cr.BoolValidation{ - Default: false, - }, - }, - { - StructField: "SpotConfig", - StructValidation: &cr.StructValidation{ - DefaultNil: true, - AllowExplicitNull: true, - StructFieldValidations: []*cr.StructFieldValidation{ - { - StructField: "InstanceDistribution", - StringListValidation: &cr.StringListValidation{ - DisallowDups: true, - Validator: validateInstanceDistribution, - AllowExplicitNull: true, - }, - }, - { - StructField: "OnDemandBaseCapacity", - Int64PtrValidation: &cr.Int64PtrValidation{ - GreaterThanOrEqualTo: pointer.Int64(0), - AllowExplicitNull: true, - }, - }, - { - StructField: "OnDemandPercentageAboveBaseCapacity", - Int64PtrValidation: &cr.Int64PtrValidation{ - GreaterThanOrEqualTo: pointer.Int64(0), - LessThanOrEqualTo: pointer.Int64(100), - AllowExplicitNull: true, - }, - }, - { - StructField: "MaxPrice", - Float64PtrValidation: &cr.Float64PtrValidation{ - GreaterThan: pointer.Float64(0), - AllowExplicitNull: true, - }, - }, - { - StructField: "InstancePools", - Int64PtrValidation: &cr.Int64PtrValidation{ - GreaterThanOrEqualTo: pointer.Int64(1), - LessThanOrEqualTo: pointer.Int64(int64(_maxInstancePools)), - AllowExplicitNull: true, - }, - }, - { - StructField: "OnDemandBackup", - BoolPtrValidation: &cr.BoolPtrValidation{ - Default: pointer.Bool(true), - }, - }, - }, - }, - }, { 
StructField: "AvailabilityZones", StringListValidation: &cr.StringListValidation{ @@ -697,11 +743,45 @@ func (cc *CoreConfig) SQSNamePrefix() string { } // this validates the user-provided cluster config -func (cc *Config) Validate(awsClient *aws.Client) error { +func (cc *Config) Validate(awsClient *aws.Client, skipQuotaVerification bool) error { fmt.Print("verifying your configuration ...\n\n") - if cc.MinInstances > cc.MaxInstances { - return ErrorMinInstancesGreaterThanMax(cc.MinInstances, cc.MaxInstances) + numNodeGroups := len(cc.NodeGroups) + if numNodeGroups == 0 { + return ErrorNoNodeGroupSpecified() + } + if numNodeGroups > MaxNodePoolsOrGroups { + return ErrorMaxNumOfNodeGroupsReached(MaxNodePoolsOrGroups) + } + + ngNames := []string{} + instances := []aws.InstanceTypeRequests{} + for _, nodeGroup := range cc.NodeGroups { + if !slices.HasString(ngNames, nodeGroup.Name) { + ngNames = append(ngNames, nodeGroup.Name) + } else { + return errors.Wrap(ErrorDuplicateNodeGroupName(nodeGroup.Name), NodeGroupsKey) + } + + err := nodeGroup.validateNodeGroup(awsClient, cc.Region) + if err != nil { + return errors.Wrap(err, NodeGroupsKey, nodeGroup.Name) + } + + instances = append(instances, aws.InstanceTypeRequests{ + InstanceType: nodeGroup.InstanceType, + RequiredOnDemandInstances: nodeGroup.MaxPossibleOnDemandInstances(), + RequiredSpotInstances: nodeGroup.MaxPossibleSpotInstances(), + }) + } + + if !skipQuotaVerification { + if err := awsClient.VerifyInstanceQuota(instances); err != nil { + // Skip AWS errors, since some regions (e.g. 
eu-north-1) do not support this API + if !aws.IsAWSError(err) { + return errors.Wrap(err, NodeGroupsKey) + } + } } if len(cc.AvailabilityZones) > 0 && len(cc.Subnets) > 0 { @@ -754,11 +834,6 @@ func (cc *Config) Validate(awsClient *aws.Client) error { } } - primaryInstanceType := cc.InstanceType - if _, ok := aws.InstanceMetadatas[cc.Region][primaryInstanceType]; !ok { - return errors.Wrap(ErrorInstanceTypeNotSupportedInRegion(primaryInstanceType, cc.Region), InstanceTypeKey) - } - if cc.SSLCertificateARN != nil { exists, err := awsClient.DoesCertificateExist(*cc.SSLCertificateARN) if err != nil { @@ -770,28 +845,6 @@ func (cc *Config) Validate(awsClient *aws.Client) error { } } - // Throw error if IOPS defined for other storage than io1 - if cc.InstanceVolumeType != IO1VolumeType && cc.InstanceVolumeIOPS != nil { - return ErrorIOPSNotSupported(cc.InstanceVolumeType) - } - - if cc.InstanceVolumeType == IO1VolumeType && cc.InstanceVolumeIOPS != nil { - if *cc.InstanceVolumeIOPS > cc.InstanceVolumeSize*50 { - return ErrorIOPSTooLarge(*cc.InstanceVolumeIOPS, cc.InstanceVolumeSize) - } - } - - if aws.EBSMetadatas[cc.Region][cc.InstanceVolumeType.String()].IOPSConfigurable && cc.InstanceVolumeIOPS == nil { - cc.InstanceVolumeIOPS = pointer.Int64(libmath.MinInt64(cc.InstanceVolumeSize*50, 3000)) - } - - if err := awsClient.VerifyInstanceQuota(primaryInstanceType, cc.MaxPossibleOnDemandInstances(), cc.MaxPossibleSpotInstances()); err != nil { - // Skip AWS errors, since some regions (e.g. 
eu-north-1) do not support this API - if !aws.IsAWSError(err) { - return err - } - } - for tagName, tagValue := range cc.Tags { if strings.HasPrefix(tagName, "cortex.dev/") { if tagName != ClusterNameTag { @@ -814,31 +867,61 @@ func (cc *Config) Validate(awsClient *aws.Client) error { } } - var requiredVPCs int - if len(cc.Subnets) == 0 { - requiredVPCs = 1 + if !skipQuotaVerification { + var requiredVPCs int + if len(cc.Subnets) == 0 { + requiredVPCs = 1 + } + if err := awsClient.VerifyNetworkQuotas(1, cc.NATGateway != NoneNATGateway, cc.NATGateway == HighlyAvailableNATGateway, requiredVPCs, strset.FromSlice(cc.AvailabilityZones)); err != nil { + // Skip AWS errors, since some regions (e.g. eu-north-1) do not support this API + if !aws.IsAWSError(err) { + return err + } + } } - if err := awsClient.VerifyNetworkQuotas(1, cc.NATGateway != NoneNATGateway, cc.NATGateway == HighlyAvailableNATGateway, requiredVPCs, strset.FromSlice(cc.AvailabilityZones)); err != nil { - // Skip AWS errors, since some regions (e.g. 
eu-north-1) do not support this API - if !aws.IsAWSError(err) { - return err + + return nil +} + +func (ng *NodeGroup) validateNodeGroup(awsClient *aws.Client, region string) error { + if ng.MinInstances > ng.MaxInstances { + return ErrorMinInstancesGreaterThanMax(ng.MinInstances, ng.MaxInstances) + } + + primaryInstanceType := ng.InstanceType + if _, ok := aws.InstanceMetadatas[region][primaryInstanceType]; !ok { + return errors.Wrap(ErrorInstanceTypeNotSupportedInRegion(primaryInstanceType, region), InstanceTypeKey) + } + + // Throw error if IOPS defined for other storage than io1 + if ng.InstanceVolumeType != IO1VolumeType && ng.InstanceVolumeIOPS != nil { + return ErrorIOPSNotSupported(ng.InstanceVolumeType) + } + + if ng.InstanceVolumeType == IO1VolumeType && ng.InstanceVolumeIOPS != nil { + if *ng.InstanceVolumeIOPS > ng.InstanceVolumeSize*50 { + return ErrorIOPSTooLarge(*ng.InstanceVolumeIOPS, ng.InstanceVolumeSize) } } - if cc.Spot { - cc.FillEmptySpotFields() + if aws.EBSMetadatas[region][ng.InstanceVolumeType.String()].IOPSConfigurable && ng.InstanceVolumeIOPS == nil { + ng.InstanceVolumeIOPS = pointer.Int64(libmath.MinInt64(ng.InstanceVolumeSize*50, 3000)) + } + + if ng.Spot { + ng.FillEmptySpotFields(region) - primaryInstance := aws.InstanceMetadatas[cc.Region][primaryInstanceType] + primaryInstance := aws.InstanceMetadatas[region][primaryInstanceType] - for _, instanceType := range cc.SpotConfig.InstanceDistribution { + for _, instanceType := range ng.SpotConfig.InstanceDistribution { if instanceType == primaryInstanceType { continue } - if _, ok := aws.InstanceMetadatas[cc.Region][instanceType]; !ok { - return errors.Wrap(ErrorInstanceTypeNotSupportedInRegion(instanceType, cc.Region), SpotConfigKey, InstanceDistributionKey) + if _, ok := aws.InstanceMetadatas[region][instanceType]; !ok { + return errors.Wrap(ErrorInstanceTypeNotSupportedInRegion(instanceType, region), SpotConfigKey, InstanceDistributionKey) } - instanceMetadata := 
aws.InstanceMetadatas[cc.Region][instanceType] + instanceMetadata := aws.InstanceMetadatas[region][instanceType] err := CheckSpotInstanceCompatibility(primaryInstance, instanceMetadata) if err != nil { return errors.Wrap(err, SpotConfigKey, InstanceDistributionKey) @@ -846,17 +929,17 @@ func (cc *Config) Validate(awsClient *aws.Client) error { spotInstancePrice, awsErr := awsClient.SpotInstancePrice(instanceMetadata.Type) if awsErr == nil { - if err := CheckSpotInstancePriceCompatibility(primaryInstance, instanceMetadata, cc.SpotConfig.MaxPrice, spotInstancePrice); err != nil { + if err := CheckSpotInstancePriceCompatibility(primaryInstance, instanceMetadata, ng.SpotConfig.MaxPrice, spotInstancePrice); err != nil { return errors.Wrap(err, SpotConfigKey, InstanceDistributionKey) } } } - if cc.SpotConfig.OnDemandBaseCapacity != nil && *cc.SpotConfig.OnDemandBaseCapacity > cc.MaxInstances { - return ErrorOnDemandBaseCapacityGreaterThanMax(*cc.SpotConfig.OnDemandBaseCapacity, cc.MaxInstances) + if ng.SpotConfig.OnDemandBaseCapacity != nil && *ng.SpotConfig.OnDemandBaseCapacity > ng.MaxInstances { + return ErrorOnDemandBaseCapacityGreaterThanMax(*ng.SpotConfig.OnDemandBaseCapacity, ng.MaxInstances) } } else { - if cc.SpotConfig != nil { + if ng.SpotConfig != nil { return ErrorConfiguredWhenSpotIsNotEnabled(SpotConfigKey) } } @@ -971,10 +1054,6 @@ func AutoGenerateSpotConfig(spotConfig *SpotConfig, region string, instanceType spotConfig.OnDemandPercentageAboveBaseCapacity = pointer.Int64(0) } - if spotConfig.OnDemandBackup == nil { - spotConfig.OnDemandBackup = pointer.Bool(true) - } - if spotConfig.InstancePools == nil { if len(spotConfig.InstanceDistribution) < _maxInstancePools { spotConfig.InstancePools = pointer.Int64(int64(len(spotConfig.InstanceDistribution))) @@ -984,11 +1063,11 @@ func AutoGenerateSpotConfig(spotConfig *SpotConfig, region string, instanceType } } -func (cc *Config) FillEmptySpotFields() { - if cc.SpotConfig == nil { - cc.SpotConfig = 
&SpotConfig{} +func (ng *NodeGroup) FillEmptySpotFields(region string) { + if ng.SpotConfig == nil { + ng.SpotConfig = &SpotConfig{} } - AutoGenerateSpotConfig(cc.SpotConfig, cc.Region, cc.InstanceType) + AutoGenerateSpotConfig(ng.SpotConfig, region, ng.InstanceType) } func validateBucketNameOrEmpty(bucket string) (string, error) { @@ -1066,147 +1145,44 @@ func GetDefaults() (*Config, error) { return cc, nil } -func (cc *Config) MaxPossibleOnDemandInstances() int64 { - if cc.Spot == false || cc.SpotConfig == nil || cc.SpotConfig.OnDemandBackup == nil || *cc.SpotConfig.OnDemandBackup == true { - return cc.MaxInstances +func (ng *NodeGroup) MaxPossibleOnDemandInstances() int64 { + if ng.Spot == false || ng.SpotConfig == nil { + return ng.MaxInstances } - onDemandBaseCap, onDemandPctAboveBaseCap := cc.SpotConfigOnDemandValues() - return onDemandBaseCap + int64(math.Ceil(float64(onDemandPctAboveBaseCap)/100*float64(cc.MaxInstances-onDemandBaseCap))) + onDemandBaseCap, onDemandPctAboveBaseCap := ng.SpotConfigOnDemandValues() + return onDemandBaseCap + int64(math.Ceil(float64(onDemandPctAboveBaseCap)/100*float64(ng.MaxInstances-onDemandBaseCap))) } -func (cc *Config) MaxPossibleSpotInstances() int64 { - if cc.Spot == false { +func (ng *NodeGroup) MaxPossibleSpotInstances() int64 { + if ng.Spot == false { return 0 } - if cc.SpotConfig == nil { - return cc.MaxInstances + if ng.SpotConfig == nil { + return ng.MaxInstances } - onDemandBaseCap, onDemandPctAboveBaseCap := cc.SpotConfigOnDemandValues() - return cc.MaxInstances - onDemandBaseCap - int64(math.Floor(float64(onDemandPctAboveBaseCap)/100*float64(cc.MaxInstances-onDemandBaseCap))) + onDemandBaseCap, onDemandPctAboveBaseCap := ng.SpotConfigOnDemandValues() + return ng.MaxInstances - onDemandBaseCap - int64(math.Floor(float64(onDemandPctAboveBaseCap)/100*float64(ng.MaxInstances-onDemandBaseCap))) } -func (cc *Config) SpotConfigOnDemandValues() (int64, int64) { +func (ng *NodeGroup) SpotConfigOnDemandValues() (int64, 
int64) { // default OnDemandBaseCapacity is 0 var onDemandBaseCapacity int64 = 0 - if cc.SpotConfig.OnDemandBaseCapacity != nil { - onDemandBaseCapacity = *cc.SpotConfig.OnDemandBaseCapacity + if ng.SpotConfig.OnDemandBaseCapacity != nil { + onDemandBaseCapacity = *ng.SpotConfig.OnDemandBaseCapacity } // default OnDemandPercentageAboveBaseCapacity is 0 var onDemandPercentageAboveBaseCapacity int64 = 0 - if cc.SpotConfig.OnDemandPercentageAboveBaseCapacity != nil { - onDemandPercentageAboveBaseCapacity = *cc.SpotConfig.OnDemandPercentageAboveBaseCapacity + if ng.SpotConfig.OnDemandPercentageAboveBaseCapacity != nil { + onDemandPercentageAboveBaseCapacity = *ng.SpotConfig.OnDemandPercentageAboveBaseCapacity } return onDemandBaseCapacity, onDemandPercentageAboveBaseCapacity } -func (cc *InternalConfig) UserTable() table.KeyValuePairs { - var items *table.KeyValuePairs = &table.KeyValuePairs{} - - items.Add(APIVersionUserKey, cc.APIVersion) - items.AddAll(cc.Config.UserTable()) - - return *items -} - -func (cc *InternalConfig) UserStr() string { - return cc.UserTable().String() -} - -func (cc *CoreConfig) UserTable() table.KeyValuePairs { - var items table.KeyValuePairs - - items.Add(ClusterNameUserKey, cc.ClusterName) - items.Add(RegionUserKey, cc.Region) - items.Add(BucketUserKey, cc.Bucket) - items.Add(TelemetryUserKey, cc.Telemetry) - items.Add(ImageOperatorUserKey, cc.ImageOperator) - items.Add(ImageManagerUserKey, cc.ImageManager) - items.Add(ImageDownloaderUserKey, cc.ImageDownloader) - items.Add(ImageRequestMonitorUserKey, cc.ImageRequestMonitor) - items.Add(ImageClusterAutoscalerUserKey, cc.ImageClusterAutoscaler) - items.Add(ImageMetricsServerUserKey, cc.ImageMetricsServer) - items.Add(ImageInferentiaUserKey, cc.ImageInferentia) - items.Add(ImageNeuronRTDUserKey, cc.ImageNeuronRTD) - items.Add(ImageNvidiaUserKey, cc.ImageNvidia) - items.Add(ImageFluentBitUserKey, cc.ImageFluentBit) - items.Add(ImageIstioProxyUserKey, cc.ImageIstioProxy) - 
items.Add(ImageIstioPilotUserKey, cc.ImageIstioPilot) - items.Add(ImagePrometheusUserKey, cc.ImagePrometheus) - items.Add(ImagePrometheusConfigReloaderUserKey, cc.ImagePrometheusConfigReloader) - items.Add(ImagePrometheusOperatorUserKey, cc.ImagePrometheusOperator) - items.Add(ImagePrometheusStatsDExporterUserKey, cc.ImagePrometheusStatsDExporter) - items.Add(ImagePrometheusDCGMExporterUserKey, cc.ImagePrometheusDCGMExporter) - items.Add(ImagePrometheusKubeStateMetricsUserKey, cc.ImagePrometheusKubeStateMetrics) - items.Add(ImagePrometheusNodeExporterUserKey, cc.ImagePrometheusNodeExporter) - items.Add(ImageKubeRBACProxyUserKey, cc.ImageKubeRBACProxy) - items.Add(ImageGrafanaUserKey, cc.ImageGrafana) - items.Add(ImageEventExporterUserKey, cc.ImageEventExporter) - - return items -} - -func (mc *ManagedConfig) UserTable() table.KeyValuePairs { - var items table.KeyValuePairs - - if len(mc.AvailabilityZones) > 0 { - items.Add(AvailabilityZonesUserKey, mc.AvailabilityZones) - } - for _, subnetConfig := range mc.Subnets { - items.Add("subnet in "+subnetConfig.AvailabilityZone, subnetConfig.SubnetID) - } - items.Add(InstanceTypeUserKey, mc.InstanceType) - items.Add(MinInstancesUserKey, mc.MinInstances) - items.Add(MaxInstancesUserKey, mc.MaxInstances) - items.Add(TagsUserKey, s.ObjFlat(mc.Tags)) - if mc.SSLCertificateARN != nil { - items.Add(SSLCertificateARNUserKey, *mc.SSLCertificateARN) - } - items.Add(CortexPolicyARNUserKey, mc.CortexPolicyARN) - items.Add(IAMPolicyARNsUserKey, s.ObjFlatNoQuotes(mc.IAMPolicyARNs)) - - items.Add(InstanceVolumeSizeUserKey, mc.InstanceVolumeSize) - items.Add(InstanceVolumeTypeUserKey, mc.InstanceVolumeType) - items.Add(InstanceVolumeIOPSUserKey, mc.InstanceVolumeIOPS) - items.Add(SpotUserKey, s.YesNo(mc.Spot)) - if mc.Spot { - items.Add(InstanceDistributionUserKey, mc.SpotConfig.InstanceDistribution) - items.Add(OnDemandBaseCapacityUserKey, *mc.SpotConfig.OnDemandBaseCapacity) - items.Add(OnDemandPercentageAboveBaseCapacityUserKey, 
*mc.SpotConfig.OnDemandPercentageAboveBaseCapacity) - items.Add(MaxPriceUserKey, *mc.SpotConfig.MaxPrice) - items.Add(InstancePoolsUserKey, *mc.SpotConfig.InstancePools) - items.Add(OnDemandBackupUserKey, s.YesNo(*mc.SpotConfig.OnDemandBackup)) - } - items.Add(SubnetVisibilityUserKey, mc.SubnetVisibility) - items.Add(NATGatewayUserKey, mc.NATGateway) - items.Add(APILoadBalancerSchemeUserKey, mc.APILoadBalancerScheme) - items.Add(OperatorLoadBalancerSchemeUserKey, mc.OperatorLoadBalancerScheme) - if mc.VPCCIDR != nil { - items.Add(VPCCIDRKey, *mc.VPCCIDR) - } - - return items -} - -func (cc *Config) UserTable() table.KeyValuePairs { - var items *table.KeyValuePairs = &table.KeyValuePairs{} - items.AddAll(cc.CoreConfig.UserTable()) - - if cc.CoreConfig.IsManaged { - items.AddAll(cc.ManagedConfig.UserTable()) - } - - return *items -} - -func (cc *Config) UserStr() string { - return cc.UserTable().String() -} - func (cc *CoreConfig) TelemetryEvent() map[string]interface{} { event := map[string]interface{}{ "provider": types.AWSProviderType, @@ -1298,16 +1274,6 @@ func (cc *CoreConfig) TelemetryEvent() map[string]interface{} { func (mc *ManagedConfig) TelemetryEvent() map[string]interface{} { event := map[string]interface{}{} - - event["instance_type"] = mc.InstanceType - event["min_instances"] = mc.MinInstances - event["max_instances"] = mc.MaxInstances - event["instance_volume_size"] = mc.InstanceVolumeSize - event["instance_volume_type"] = mc.InstanceVolumeType - if mc.InstanceVolumeIOPS != nil { - event["instance_volume_iops._is_defined"] = true - event["instance_volume_iops"] = *mc.InstanceVolumeIOPS - } if len(mc.Tags) > 0 { event["tags._is_defined"] = true event["tags._len"] = len(mc.Tags) @@ -1340,35 +1306,82 @@ func (mc *ManagedConfig) TelemetryEvent() map[string]interface{} { if mc.VPCCIDR != nil { event["vpc_cidr._is_defined"] = true } - event["spot"] = mc.Spot - if mc.SpotConfig != nil { - event["spot_config._is_defined"] = true - if 
len(mc.SpotConfig.InstanceDistribution) > 0 { - event["spot_config.instance_distribution._is_defined"] = true - event["spot_config.instance_distribution._len"] = len(mc.SpotConfig.InstanceDistribution) - event["spot_config.instance_distribution"] = mc.SpotConfig.InstanceDistribution - } - if mc.SpotConfig.OnDemandBaseCapacity != nil { - event["spot_config.on_demand_base_capacity._is_defined"] = true - event["spot_config.on_demand_base_capacity"] = *mc.SpotConfig.OnDemandBaseCapacity - } - if mc.SpotConfig.OnDemandPercentageAboveBaseCapacity != nil { - event["spot_config.on_demand_percentage_above_base_capacity._is_defined"] = true - event["spot_config.on_demand_percentage_above_base_capacity"] = *mc.SpotConfig.OnDemandPercentageAboveBaseCapacity + + onDemandInstanceTypes := strset.New() + spotInstanceTypes := strset.New() + var totalMinSize, totalMaxSize int + + event["node_groups._len"] = len(mc.NodeGroups) + for _, ng := range mc.NodeGroups { + nodeGroupKey := func(field string) string { + lifecycle := "on_demand" + if ng.Spot { + lifecycle = "spot" + } + return fmt.Sprintf("node_groups.%s-%s.%s", ng.InstanceType, lifecycle, field) } - if mc.SpotConfig.MaxPrice != nil { - event["spot_config.max_price._is_defined"] = true - event["spot_config.max_price"] = *mc.SpotConfig.MaxPrice + event[nodeGroupKey("_is_defined")] = true + event[nodeGroupKey("name")] = ng.Name + event[nodeGroupKey("instance_type")] = ng.InstanceType + event[nodeGroupKey("min_instances")] = ng.MinInstances + event[nodeGroupKey("max_instances")] = ng.MaxInstances + event[nodeGroupKey("instance_volume_size")] = ng.InstanceVolumeSize + event[nodeGroupKey("instance_volume_type")] = ng.InstanceVolumeType + if ng.InstanceVolumeIOPS != nil { + event[nodeGroupKey("instance_volume_iops.is_defined")] = true + event[nodeGroupKey("instance_volume_iops")] = ng.InstanceVolumeIOPS } - if mc.SpotConfig.InstancePools != nil { - event["spot_config.instance_pools._is_defined"] = true - 
event["spot_config.instance_pools"] = *mc.SpotConfig.InstancePools + + event[nodeGroupKey("spot")] = ng.Spot + if !ng.Spot { + onDemandInstanceTypes.Add(ng.InstanceType) + } else { + spotInstanceTypes.Add(ng.InstanceType) } - if mc.SpotConfig.OnDemandBackup != nil { - event["spot_config.on_demand_backup._is_defined"] = true - event["spot_config.on_demand_backup"] = *mc.SpotConfig.OnDemandBackup + if ng.SpotConfig != nil { + event[nodeGroupKey("spot_config._is_defined")] = true + if len(ng.SpotConfig.InstanceDistribution) > 0 { + event[nodeGroupKey("spot_config.instance_distribution._is_defined")] = true + event[nodeGroupKey("spot_config.instance_distribution._len")] = len(ng.SpotConfig.InstanceDistribution) + event[nodeGroupKey("spot_config.instance_distribution")] = ng.SpotConfig.InstanceDistribution + spotInstanceTypes.Add(ng.SpotConfig.InstanceDistribution...) + } + if ng.SpotConfig.OnDemandBaseCapacity != nil { + event[nodeGroupKey("spot_config.on_demand_base_capacity._is_defined")] = true + event[nodeGroupKey("spot_config.on_demand_base_capacity")] = *ng.SpotConfig.OnDemandBaseCapacity + } + if ng.SpotConfig.OnDemandPercentageAboveBaseCapacity != nil { + event[nodeGroupKey("spot_config.on_demand_percentage_above_base_capacity._is_defined")] = true + event[nodeGroupKey("spot_config.on_demand_percentage_above_base_capacity")] = *ng.SpotConfig.OnDemandPercentageAboveBaseCapacity + } + if ng.SpotConfig.MaxPrice != nil { + event[nodeGroupKey("spot_config.max_price._is_defined")] = true + event[nodeGroupKey("spot_config.max_price")] = *ng.SpotConfig.MaxPrice + } + if ng.SpotConfig.InstancePools != nil { + event[nodeGroupKey("spot_config.instance_pools._is_defined")] = true + event[nodeGroupKey("spot_config.instance_pools")] = *ng.SpotConfig.InstancePools + } } + + totalMinSize += int(ng.MinInstances) + totalMaxSize += int(ng.MaxInstances) } + event["node_groups._total_min_size"] = totalMinSize + event["node_groups._total_max_size"] = totalMaxSize + 
event["node_groups._on_demand_instances"] = onDemandInstanceTypes.Slice() + event["node_groups._spot_instances"] = spotInstanceTypes.Slice() + event["node_groups._instances"] = strset.Union(onDemandInstanceTypes, spotInstanceTypes).Slice() + return event } + +func (mc *ManagedConfig) GetAllInstanceTypes() []string { + allInstanceTypes := strset.New() + for _, ng := range mc.NodeGroups { + allInstanceTypes.Add(ng.InstanceType) + } + + return allInstanceTypes.Slice() +} diff --git a/pkg/types/clusterconfig/cluster_config_gcp.go b/pkg/types/clusterconfig/cluster_config_gcp.go index d752b1ebf7..1469a7d9dd 100644 --- a/pkg/types/clusterconfig/cluster_config_gcp.go +++ b/pkg/types/clusterconfig/cluster_config_gcp.go @@ -22,16 +22,21 @@ import ( "github.com/cortexlabs/cortex/pkg/consts" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/gcp" "github.com/cortexlabs/cortex/pkg/lib/hash" "github.com/cortexlabs/cortex/pkg/lib/pointer" "github.com/cortexlabs/cortex/pkg/lib/prompt" + "github.com/cortexlabs/cortex/pkg/lib/sets/strset" "github.com/cortexlabs/cortex/pkg/lib/slices" - s "github.com/cortexlabs/cortex/pkg/lib/strings" - "github.com/cortexlabs/cortex/pkg/lib/table" "github.com/cortexlabs/cortex/pkg/types" ) +var ( + _maxNodePoolLengthWithPrefix = 19 // node pool length name limit on GKE + _maxNodePoolLength = _maxNodePoolLengthWithPrefix - len("cx-wd-") // or cx-ws- +) + type GCPCoreConfig struct { Provider types.ProviderType `json:"provider" yaml:"provider"` Project string `json:"project" yaml:"project"` @@ -65,17 +70,21 @@ type GCPCoreConfig struct { } type GCPManagedConfig struct { - InstanceType string `json:"instance_type" yaml:"instance_type"` - AcceleratorType *string `json:"accelerator_type,omitempty" yaml:"accelerator_type,omitempty"` - AcceleratorsPerInstance *int64 `json:"accelerators_per_instance,omitempty" yaml:"accelerators_per_instance,omitempty"` + NodePools 
[]*NodePool `json:"node_pools" yaml:"node_pools"` Network *string `json:"network,omitempty" yaml:"network,omitempty"` Subnet *string `json:"subnet,omitempty" yaml:"subnet,omitempty"` APILoadBalancerScheme LoadBalancerScheme `json:"api_load_balancer_scheme" yaml:"api_load_balancer_scheme"` OperatorLoadBalancerScheme LoadBalancerScheme `json:"operator_load_balancer_scheme" yaml:"operator_load_balancer_scheme"` - MinInstances int64 `json:"min_instances" yaml:"min_instances"` - MaxInstances int64 `json:"max_instances" yaml:"max_instances"` - Preemptible bool `json:"preemptible" yaml:"preemptible"` - OnDemandBackup bool `json:"on_demand_backup" yaml:"on_demand_backup"` +} + +type NodePool struct { + Name string `json:"name" yaml:"name"` + InstanceType string `json:"instance_type" yaml:"instance_type"` + AcceleratorType *string `json:"accelerator_type,omitempty" yaml:"accelerator_type,omitempty"` + AcceleratorsPerInstance *int64 `json:"accelerators_per_instance,omitempty" yaml:"accelerators_per_instance,omitempty"` + MinInstances int64 `json:"min_instances" yaml:"min_instances"` + MaxInstances int64 `json:"max_instances" yaml:"max_instances"` + Preemptible bool `json:"preemptible" yaml:"preemptible"` } type GCPConfig struct { @@ -298,29 +307,67 @@ var GCPCoreConfigStructFieldValidations = []*cr.StructFieldValidation{ var GCPManagedConfigStructFieldValidations = []*cr.StructFieldValidation{ { - StructField: "InstanceType", - StringValidation: &cr.StringValidation{ + StructField: "NodePools", + StructListValidation: &cr.StructListValidation{ Required: true, - }, - }, - { - StructField: "AcceleratorType", - StringPtrValidation: &cr.StringPtrValidation{ - AllowExplicitNull: true, - }, - }, - { - StructField: "AcceleratorsPerInstance", - Int64PtrValidation: &cr.Int64PtrValidation{ - AllowExplicitNull: true, - }, - DefaultDependentFields: []string{"AcceleratorType"}, - DefaultDependentFieldsFunc: func(vals []interface{}) interface{} { - acceleratorType := vals[0].(*string) - 
if acceleratorType == nil { - return nil - } - return pointer.Int64(1) + StructValidation: &cr.StructValidation{ + StructFieldValidations: []*cr.StructFieldValidation{ + { + StructField: "Name", + StringValidation: &cr.StringValidation{ + Required: true, + AlphaNumericDashUnderscore: true, + MaxLength: _maxNodePoolLength, + }, + }, + { + StructField: "InstanceType", + StringValidation: &cr.StringValidation{ + Required: true, + }, + }, + { + StructField: "AcceleratorType", + StringPtrValidation: &cr.StringPtrValidation{ + AllowExplicitNull: true, + }, + }, + { + StructField: "AcceleratorsPerInstance", + Int64PtrValidation: &cr.Int64PtrValidation{ + AllowExplicitNull: true, + }, + DefaultDependentFields: []string{"AcceleratorType"}, + DefaultDependentFieldsFunc: func(vals []interface{}) interface{} { + acceleratorType := vals[0].(*string) + if acceleratorType == nil { + return nil + } + return pointer.Int64(1) + }, + }, + { + StructField: "MinInstances", + Int64Validation: &cr.Int64Validation{ + Default: int64(1), + GreaterThanOrEqualTo: pointer.Int64(0), + }, + }, + { + StructField: "MaxInstances", + Int64Validation: &cr.Int64Validation{ + Default: int64(5), + GreaterThan: pointer.Int64(0), + }, + }, + { + StructField: "Preemptible", + BoolValidation: &cr.BoolValidation{ + Default: false, + }, + }, + }, + }, }, }, { @@ -355,34 +402,6 @@ var GCPManagedConfigStructFieldValidations = []*cr.StructFieldValidation{ return LoadBalancerSchemeFromString(str), nil }, }, - { - StructField: "MinInstances", - Int64Validation: &cr.Int64Validation{ - Default: int64(1), - GreaterThanOrEqualTo: pointer.Int64(0), - }, - }, - { - StructField: "MaxInstances", - Int64Validation: &cr.Int64Validation{ - Default: int64(5), - GreaterThan: pointer.Int64(0), - }, - }, - { - StructField: "Preemptible", - BoolValidation: &cr.BoolValidation{ - Default: false, - }, - }, - { - StructField: "OnDemandBackup", - DefaultDependentFields: []string{"Preemptible"}, - DefaultDependentFieldsFunc: func(vals 
[]interface{}) interface{} { - return vals[0].(bool) - }, - BoolValidation: &cr.BoolValidation{}, - }, } var GCPAccessValidation = &cr.StructValidation{ @@ -509,28 +528,55 @@ func (cc *GCPConfig) Validate(GCP *gcp.Client) error { cc.Bucket = GCPBucketName(cc.ClusterName, cc.Project, cc.Zone) } - if validInstanceType, err := GCP.IsInstanceTypeAvailable(cc.InstanceType, cc.Zone); err != nil { + numNodePools := len(cc.NodePools) + if numNodePools == 0 { + return ErrorGCPNoNodePoolSpecified() + } + if numNodePools > MaxNodePoolsOrGroups { + return ErrorGCPMaxNumOfNodePoolsReached(MaxNodePoolsOrGroups) + } + + npNames := []string{} + for _, nodePool := range cc.NodePools { + if !slices.HasString(npNames, nodePool.Name) { + npNames = append(npNames, nodePool.Name) + } else { + return errors.Wrap(ErrorGCPDuplicateNodePoolName(nodePool.Name), NodePoolsKey) + } + + err := nodePool.validateNodePool(GCP, cc.Zone) + if err != nil { + return errors.Wrap(err, NodeGroupsKey, nodePool.Name) + } + + } + + return nil +} + +func (np *NodePool) validateNodePool(GCP *gcp.Client, zone string) error { + if validInstanceType, err := GCP.IsInstanceTypeAvailable(np.InstanceType, zone); err != nil { return err } else if !validInstanceType { - instanceTypes, err := GCP.GetAvailableInstanceTypes(cc.Zone) + instanceTypes, err := GCP.GetAvailableInstanceTypes(zone) if err != nil { return err } - return ErrorGCPInvalidInstanceType(cc.InstanceType, instanceTypes...) + return ErrorGCPInvalidInstanceType(np.InstanceType, instanceTypes...) 
} - if cc.AcceleratorType == nil && cc.AcceleratorsPerInstance != nil { + if np.AcceleratorType == nil && np.AcceleratorsPerInstance != nil { return ErrorDependentFieldMustBeSpecified(AcceleratorsPerInstanceKey, AcceleratorTypeKey) } - if cc.AcceleratorType != nil { - if cc.AcceleratorsPerInstance == nil { + if np.AcceleratorType != nil { + if np.AcceleratorsPerInstance == nil { return ErrorDependentFieldMustBeSpecified(AcceleratorTypeKey, AcceleratorsPerInstanceKey) } - if validAccelerator, err := GCP.IsAcceleratorTypeAvailable(*cc.AcceleratorType, cc.Zone); err != nil { + if validAccelerator, err := GCP.IsAcceleratorTypeAvailable(*np.AcceleratorType, zone); err != nil { return err } else if !validAccelerator { - availableAcceleratorsInZone, err := GCP.GetAvailableAcceleratorTypes(cc.Zone) + availableAcceleratorsInZone, err := GCP.GetAvailableAcceleratorTypes(zone) if err != nil { return err } @@ -540,120 +586,34 @@ func (cc *GCPConfig) Validate(GCP *gcp.Client) error { } var availableZonesForAccelerator []string - if slices.HasString(allAcceleratorTypes, *cc.AcceleratorType) { - availableZonesForAccelerator, err = GCP.GetAvailableZonesForAccelerator(*cc.AcceleratorType) + if slices.HasString(allAcceleratorTypes, *np.AcceleratorType) { + availableZonesForAccelerator, err = GCP.GetAvailableZonesForAccelerator(*np.AcceleratorType) if err != nil { return err } } - return ErrorGCPInvalidAcceleratorType(*cc.AcceleratorType, cc.Zone, availableAcceleratorsInZone, availableZonesForAccelerator) + return ErrorGCPInvalidAcceleratorType(*np.AcceleratorType, zone, availableAcceleratorsInZone, availableZonesForAccelerator) } // according to https://cloud.google.com/kubernetes-engine/docs/how-to/gpus var compatibleInstances []string var err error - if strings.HasSuffix(*cc.AcceleratorType, "a100") { - compatibleInstances, err = GCP.GetInstanceTypesWithPrefix("a2", cc.Zone) + if strings.HasSuffix(*np.AcceleratorType, "a100") { + compatibleInstances, err = 
GCP.GetInstanceTypesWithPrefix("a2", zone) } else { - compatibleInstances, err = GCP.GetInstanceTypesWithPrefix("n1", cc.Zone) + compatibleInstances, err = GCP.GetInstanceTypesWithPrefix("n1", zone) } if err != nil { return err } - if !slices.HasString(compatibleInstances, cc.InstanceType) { - return ErrorGCPIncompatibleInstanceTypeWithAccelerator(cc.InstanceType, *cc.AcceleratorType, cc.Zone, compatibleInstances) + if !slices.HasString(compatibleInstances, np.InstanceType) { + return ErrorGCPIncompatibleInstanceTypeWithAccelerator(np.InstanceType, *np.AcceleratorType, zone, compatibleInstances) } } - if !cc.Preemptible && cc.OnDemandBackup { - return ErrorFieldConfigurationDependentOnCondition(OnDemandBackupKey, s.Bool(cc.OnDemandBackup), PreemptibleKey, s.Bool(cc.Preemptible)) - } - return nil } -func (cc *InternalGCPConfig) UserTable() table.KeyValuePairs { - var items table.KeyValuePairs - - items.Add(APIVersionUserKey, cc.APIVersion) - items.AddAll(cc.GCPConfig.UserTable()) - return items -} - -func (cc *InternalGCPConfig) UserStr() string { - return cc.UserTable().String() -} - -func (cc *GCPCoreConfig) UserTable() table.KeyValuePairs { - var items table.KeyValuePairs - - items.Add(ClusterNameUserKey, cc.ClusterName) - items.Add(ProjectUserKey, cc.Project) - items.Add(ZoneUserKey, cc.Zone) - items.Add(TelemetryUserKey, cc.Telemetry) - items.Add(ImageOperatorUserKey, cc.ImageOperator) - items.Add(ImageManagerUserKey, cc.ImageManager) - items.Add(ImageDownloaderUserKey, cc.ImageDownloader) - items.Add(ImageRequestMonitorUserKey, cc.ImageRequestMonitor) - items.Add(ImageClusterAutoscalerUserKey, cc.ImageClusterAutoscaler) - items.Add(ImageFluentBitUserKey, cc.ImageFluentBit) - items.Add(ImageIstioProxyUserKey, cc.ImageIstioProxy) - items.Add(ImageIstioPilotUserKey, cc.ImageIstioPilot) - items.Add(ImageGooglePauseUserKey, cc.ImageGooglePause) - items.Add(ImagePrometheusUserKey, cc.ImagePrometheus) - items.Add(ImagePrometheusConfigReloaderUserKey, 
cc.ImagePrometheusConfigReloader) - items.Add(ImagePrometheusOperatorUserKey, cc.ImagePrometheusOperator) - items.Add(ImagePrometheusStatsDExporterUserKey, cc.ImagePrometheusStatsDExporter) - items.Add(ImagePrometheusDCGMExporterUserKey, cc.ImagePrometheusDCGMExporter) - items.Add(ImagePrometheusKubeStateMetricsUserKey, cc.ImagePrometheusKubeStateMetrics) - items.Add(ImagePrometheusNodeExporterUserKey, cc.ImagePrometheusNodeExporter) - items.Add(ImageKubeRBACProxyUserKey, cc.ImageKubeRBACProxy) - items.Add(ImageGrafanaUserKey, cc.ImageGrafana) - items.Add(ImageEventExporterUserKey, cc.ImageEventExporter) - - return items -} - -func (cc *GCPManagedConfig) UserTable() table.KeyValuePairs { - var items table.KeyValuePairs - - items.Add(InstanceTypeUserKey, cc.InstanceType) - items.Add(MinInstancesUserKey, cc.MinInstances) - items.Add(MaxInstancesUserKey, cc.MaxInstances) - if cc.AcceleratorType != nil { - items.Add(AcceleratorTypeUserKey, *cc.AcceleratorType) - } - if cc.AcceleratorsPerInstance != nil { - items.Add(AcceleratorsPerInstanceUserKey, *cc.AcceleratorsPerInstance) - } - items.Add(PreemptibleUserKey, s.YesNo(cc.Preemptible)) - items.Add(OnDemandBackupUserKey, s.YesNo(cc.OnDemandBackup)) - if cc.Network != nil { - items.Add(NetworkUserKey, *cc.Network) - } - if cc.Subnet != nil { - items.Add(SubnetUserKey, *cc.Subnet) - } - items.Add(APILoadBalancerSchemeUserKey, cc.APILoadBalancerScheme) - items.Add(OperatorLoadBalancerSchemeUserKey, cc.OperatorLoadBalancerScheme) - - return items -} - -func (cc *GCPConfig) UserTable() table.KeyValuePairs { - items := &table.KeyValuePairs{} - items.AddAll(cc.GCPCoreConfig.UserTable()) - if cc.GCPCoreConfig.IsManaged { - items.AddAll(cc.GCPManagedConfig.UserTable()) - } - - return *items -} - -func (cc *GCPConfig) UserStr() string { - return cc.UserTable().String() -} - func (cc *GCPCoreConfig) TelemetryEvent() map[string]interface{} { event := map[string]interface{}{ "provider": types.GCPProviderType, @@ -736,15 +696,51 @@ 
func (cc *GCPCoreConfig) TelemetryEvent() map[string]interface{} { func (cc *GCPManagedConfig) TelemetryEvent() map[string]interface{} { event := map[string]interface{}{} - event["instance_type"] = cc.InstanceType - if cc.AcceleratorType != nil { - event["accelerator_type._is_defined"] = true - event["accelerator_type"] = *cc.AcceleratorType - } - if cc.AcceleratorsPerInstance != nil { - event["accelerators_per_instance._is_defined"] = true - event["accelerators_per_instance"] = *cc.AcceleratorsPerInstance + onDemandInstanceTypes := strset.New() + preemptibleInstanceTypes := strset.New() + var totalMinSize, totalMaxSize int + + event["node_pools._len"] = len(cc.NodePools) + for _, np := range cc.NodePools { + nodePoolKey := func(field string) string { + lifecycle := "on_demand" + if np.Preemptible { + lifecycle = "preemptible" + } + return fmt.Sprintf("node_pools.%s-%s.%s", np.InstanceType, lifecycle, field) + } + event[nodePoolKey("_is_defined")] = true + event[nodePoolKey("name")] = np.Name + event[nodePoolKey("instance_type")] = np.InstanceType + event[nodePoolKey("min_instances")] = np.MinInstances + event[nodePoolKey("max_instances")] = np.MaxInstances + + if !np.Preemptible { + onDemandInstanceTypes.Add(np.InstanceType) + } else { + preemptibleInstanceTypes.Add(np.InstanceType) + } + if np.AcceleratorType != nil { + event[nodePoolKey("accelerator_type._is_defined")] = true + event[nodePoolKey("accelerator_type")] = *np.AcceleratorType + } + if np.AcceleratorsPerInstance != nil { + event[nodePoolKey("accelerators_per_instance._is_defined")] = true + event[nodePoolKey("accelerators_per_instance")] = *np.AcceleratorsPerInstance + } + + event[nodePoolKey("preemptible")] = np.Preemptible + + totalMinSize += int(np.MinInstances) + totalMaxSize += int(np.MaxInstances) } + + event["node_pools._total_min_size"] = totalMinSize + event["node_pools._total_max_size"] = totalMaxSize + event["node_pools._on_demand_instances"] = onDemandInstanceTypes.Slice() + 
event["node_pools._spot_instances"] = preemptibleInstanceTypes.Slice() + event["node_pools._instances"] = strset.Union(onDemandInstanceTypes, preemptibleInstanceTypes).Slice() + if cc.Network != nil { event["network._is_defined"] = true } @@ -753,11 +749,6 @@ func (cc *GCPManagedConfig) TelemetryEvent() map[string]interface{} { } event["api_load_balancer_scheme"] = cc.APILoadBalancerScheme event["operator_load_balancer_scheme"] = cc.OperatorLoadBalancerScheme - event["min_instances"] = cc.MinInstances - event["max_instances"] = cc.MaxInstances - - event["preemptible"] = cc.Preemptible - event["on_demand_backup"] = cc.OnDemandBackup return event } diff --git a/pkg/types/clusterconfig/config_key.go b/pkg/types/clusterconfig/config_key.go index 920b558b50..0b442eb776 100644 --- a/pkg/types/clusterconfig/config_key.go +++ b/pkg/types/clusterconfig/config_key.go @@ -17,94 +17,80 @@ limitations under the License. package clusterconfig const ( - ProviderKey = "provider" + NodeGroupsKey = "node_groups" + NodePoolsKey = "node_pools" InstanceTypeKey = "instance_type" AcceleratorTypeKey = "accelerator_type" AcceleratorsPerInstanceKey = "accelerators_per_instance" - NetworkKey = "network" - SubnetKey = "subnet" MinInstancesKey = "min_instances" MaxInstancesKey = "max_instances" - TagsKey = "tags" - InstanceVolumeSizeKey = "instance_volume_size" - InstanceVolumeTypeKey = "instance_volume_type" - InstanceVolumeIOPSKey = "instance_volume_iops" + PreemptibleKey = "preemptible" SpotKey = "spot" SpotConfigKey = "spot_config" - PreemptibleKey = "preemptible" InstanceDistributionKey = "instance_distribution" OnDemandBaseCapacityKey = "on_demand_base_capacity" OnDemandPercentageAboveBaseCapacityKey = "on_demand_percentage_above_base_capacity" - MaxPriceKey = "max_price" + InstanceVolumeSizeKey = "instance_volume_size" + InstanceVolumeTypeKey = "instance_volume_type" + InstanceVolumeIOPSKey = "instance_volume_iops" InstancePoolsKey = "instance_pools" - OnDemandBackupKey = 
"on_demand_backup" - ClusterNameKey = "cluster_name" - RegionKey = "region" - ZoneKey = "zone" - ProjectKey = "project" - AvailabilityZonesKey = "availability_zones" - SubnetsKey = "subnets" - AvailabilityZoneKey = "availability_zone" - SubnetIDKey = "subnet_id" - SSLCertificateARNKey = "ssl_certificate_arn" - CortexPolicyARNKey = "cortex_policy_arn" - IAMPolicyARNsKey = "iam_policy_arns" - BucketKey = "bucket" - SubnetVisibilityKey = "subnet_visibility" - NATGatewayKey = "nat_gateway" - APILoadBalancerSchemeKey = "api_load_balancer_scheme" - OperatorLoadBalancerSchemeKey = "operator_load_balancer_scheme" - VPCCIDRKey = "vpc_cidr" - TelemetryKey = "telemetry" - ImageOperatorKey = "image_operator" - ImageManagerKey = "image_manager" - ImageDownloaderKey = "image_downloader" - ImageRequestMonitorKey = "image_request_monitor" - ImageClusterAutoscalerKey = "image_cluster_autoscaler" - ImageMetricsServerKey = "image_metrics_server" - ImageInferentiaKey = "image_inferentia" - ImageNeuronRTDKey = "image_neuron_rtd" - ImageNvidiaKey = "image_nvidia" - ImageFluentBitKey = "image_fluent_bit" - ImageIstioProxyKey = "image_istio_proxy" - ImageIstioPilotKey = "image_istio_pilot" - ImageGooglePauseKey = "image_google_pause" - ImagePrometheusKey = "image_prometheus" - ImagePrometheusConfigReloaderKey = "image_prometheus_config_reloader" - ImagePrometheusOperatorKey = "image_prometheus_operator" - ImagePrometheusStatsDExporterKey = "image_prometheus_statsd_exporter" - ImagePrometheusDCGMExporterKey = "image_prometheus_dcgm_exporter" - ImagePrometheusKubeStateMetricsKey = "image_prometheus_kube_state_metrics" - ImagePrometheusNodeExporterKey = "image_prometheus_node_exporter" - ImageKubeRBACProxyKey = "image_kube_rbac_proxy" - ImageGrafanaKey = "image_grafana" - ImageEventExporterKey = "image_event_exporter" + MaxPriceKey = "max_price" + + ProviderKey = "provider" + NetworkKey = "network" + SubnetKey = "subnet" + TagsKey = "tags" + ClusterNameKey = "cluster_name" + RegionKey = 
"region" + ZoneKey = "zone" + ProjectKey = "project" + AvailabilityZonesKey = "availability_zones" + SubnetsKey = "subnets" + AvailabilityZoneKey = "availability_zone" + SubnetIDKey = "subnet_id" + SSLCertificateARNKey = "ssl_certificate_arn" + CortexPolicyARNKey = "cortex_policy_arn" + IAMPolicyARNsKey = "iam_policy_arns" + BucketKey = "bucket" + SubnetVisibilityKey = "subnet_visibility" + NATGatewayKey = "nat_gateway" + APILoadBalancerSchemeKey = "api_load_balancer_scheme" + OperatorLoadBalancerSchemeKey = "operator_load_balancer_scheme" + VPCCIDRKey = "vpc_cidr" + TelemetryKey = "telemetry" + ImageOperatorKey = "image_operator" + ImageManagerKey = "image_manager" + ImageDownloaderKey = "image_downloader" + ImageRequestMonitorKey = "image_request_monitor" + ImageClusterAutoscalerKey = "image_cluster_autoscaler" + ImageMetricsServerKey = "image_metrics_server" + ImageInferentiaKey = "image_inferentia" + ImageNeuronRTDKey = "image_neuron_rtd" + ImageNvidiaKey = "image_nvidia" + ImageFluentBitKey = "image_fluent_bit" + ImageIstioProxyKey = "image_istio_proxy" + ImageIstioPilotKey = "image_istio_pilot" + ImageGooglePauseKey = "image_google_pause" + ImagePrometheusKey = "image_prometheus" + ImagePrometheusConfigReloaderKey = "image_prometheus_config_reloader" + ImagePrometheusOperatorKey = "image_prometheus_operator" + ImagePrometheusStatsDExporterKey = "image_prometheus_statsd_exporter" + ImagePrometheusDCGMExporterKey = "image_prometheus_dcgm_exporter" + ImagePrometheusKubeStateMetricsKey = "image_prometheus_kube_state_metrics" + ImagePrometheusNodeExporterKey = "image_prometheus_node_exporter" + ImageKubeRBACProxyKey = "image_kube_rbac_proxy" + ImageGrafanaKey = "image_grafana" + ImageEventExporterKey = "image_event_exporter" // User facing string - ProviderUserKey = "provider" - APIVersionUserKey = "cluster version" - ClusterNameUserKey = "cluster name" - RegionUserKey = "aws region" - ZoneUserKey = "gcp zone" - ProjectUserKey = "gcp project" - 
AvailabilityZonesUserKey = "availability zones" - SubnetsUserKey = "subnets" - AvailabilityZoneUserKey = "availability zone" - SubnetIDUserKey = "subnet id" - SSLCertificateARNUserKey = "ssl certificate arn" - CortexPolicyARNUserKey = "cortex policy arn" - IAMPolicyARNsUserKey = "iam policy arns" - BucketUserKey = "s3 bucket" + NodeGroupsUserKey = "node groups" SpotUserKey = "use spot instances" PreemptibleUserKey = "use preemptible instances" InstanceTypeUserKey = "instance type" AcceleratorTypeUserKey = "accelerator type" AcceleratorsPerInstanceUserKey = "accelerators per instance" - NetworkUserKey = "network" - SubnetUserKey = "subnet" MinInstancesUserKey = "min instances" MaxInstancesUserKey = "max instances" - TagsUserKey = "tags" InstanceVolumeSizeUserKey = "instance volume size (Gi)" InstanceVolumeTypeUserKey = "instance volume type" InstanceVolumeIOPSUserKey = "instance volume iops" @@ -113,34 +99,51 @@ const ( OnDemandPercentageAboveBaseCapacityUserKey = "spot on demand percentage above base capacity" MaxPriceUserKey = "spot max price ($ per hour)" InstancePoolsUserKey = "spot instance pools" - OnDemandBackupUserKey = "on demand backup" - SubnetVisibilityUserKey = "subnet visibility" - NATGatewayUserKey = "nat gateway" - APILoadBalancerSchemeUserKey = "api load balancer scheme" - OperatorLoadBalancerSchemeUserKey = "operator load balancer scheme" - VPCCIDRUserKey = "vpc cidr" - TelemetryUserKey = "telemetry" - ImageOperatorUserKey = "operator image" - ImageManagerUserKey = "manager image" - ImageDownloaderUserKey = "downloader image" - ImageRequestMonitorUserKey = "request monitor image" - ImageClusterAutoscalerUserKey = "cluster autoscaler image" - ImageMetricsServerUserKey = "metrics server image" - ImageInferentiaUserKey = "inferentia image" - ImageNeuronRTDUserKey = "neuron rtd image" - ImageNvidiaUserKey = "nvidia image" - ImageFluentBitUserKey = "fluent-bit image" - ImageIstioProxyUserKey = "istio proxy image" - ImageIstioPilotUserKey = "istio pilot 
image" - ImageGooglePauseUserKey = "google pause image" - ImagePrometheusUserKey = "prometheus image" - ImagePrometheusConfigReloaderUserKey = "prometheus config reloader image" - ImagePrometheusOperatorUserKey = "prometheus operator image" - ImagePrometheusStatsDExporterUserKey = "prometheus statsd exporter image" - ImagePrometheusDCGMExporterUserKey = "prometheus dcgm exporter image" - ImagePrometheusKubeStateMetricsUserKey = "prometheus kube-state-metrics image" - ImagePrometheusNodeExporterUserKey = "prometheus node exporter image" - ImageKubeRBACProxyUserKey = "kube rbac proxy image" - ImageGrafanaUserKey = "grafana image" - ImageEventExporterUserKey = "event exporter image" + + ProviderUserKey = "provider" + APIVersionUserKey = "cluster version" + ClusterNameUserKey = "cluster name" + RegionUserKey = "aws region" + ZoneUserKey = "gcp zone" + ProjectUserKey = "gcp project" + AvailabilityZonesUserKey = "availability zones" + AvailabilityZoneUserKey = "availability zone" + SubnetsUserKey = "subnets" + SubnetIDUserKey = "subnet id" + TagsUserKey = "tags" + SSLCertificateARNUserKey = "ssl certificate arn" + CortexPolicyARNUserKey = "cortex policy arn" + IAMPolicyARNsUserKey = "iam policy arns" + BucketUserKey = "s3 bucket" + NetworkUserKey = "network" + SubnetUserKey = "subnet" + SubnetVisibilityUserKey = "subnet visibility" + NATGatewayUserKey = "nat gateway" + APILoadBalancerSchemeUserKey = "api load balancer scheme" + OperatorLoadBalancerSchemeUserKey = "operator load balancer scheme" + VPCCIDRUserKey = "vpc cidr" + TelemetryUserKey = "telemetry" + ImageOperatorUserKey = "operator image" + ImageManagerUserKey = "manager image" + ImageDownloaderUserKey = "downloader image" + ImageRequestMonitorUserKey = "request monitor image" + ImageClusterAutoscalerUserKey = "cluster autoscaler image" + ImageMetricsServerUserKey = "metrics server image" + ImageInferentiaUserKey = "inferentia image" + ImageNeuronRTDUserKey = "neuron rtd image" + ImageNvidiaUserKey = "nvidia 
image" + ImageFluentBitUserKey = "fluent-bit image" + ImageIstioProxyUserKey = "istio proxy image" + ImageIstioPilotUserKey = "istio pilot image" + ImageGooglePauseUserKey = "google pause image" + ImagePrometheusUserKey = "prometheus image" + ImagePrometheusConfigReloaderUserKey = "prometheus config reloader image" + ImagePrometheusOperatorUserKey = "prometheus operator image" + ImagePrometheusStatsDExporterUserKey = "prometheus statsd exporter image" + ImagePrometheusDCGMExporterUserKey = "prometheus dcgm exporter image" + ImagePrometheusKubeStateMetricsUserKey = "prometheus kube-state-metrics image" + ImagePrometheusNodeExporterUserKey = "prometheus node exporter image" + ImageKubeRBACProxyUserKey = "kube rbac proxy image" + ImageGrafanaUserKey = "grafana image" + ImageEventExporterUserKey = "event exporter image" ) diff --git a/pkg/types/clusterconfig/errors.go b/pkg/types/clusterconfig/errors.go index a9f39b89b5..a06c9b7142 100644 --- a/pkg/types/clusterconfig/errors.go +++ b/pkg/types/clusterconfig/errors.go @@ -28,44 +28,51 @@ import ( ) const ( - ErrInvalidRegion = "clusterconfig.invalid_region" - ErrInstanceTypeTooSmall = "clusterconfig.instance_type_too_small" - ErrMinInstancesGreaterThanMax = "clusterconfig.min_instances_greater_than_max" - ErrInstanceTypeNotSupportedInRegion = "clusterconfig.instance_type_not_supported_in_region" - ErrIncompatibleSpotInstanceTypeMemory = "clusterconfig.incompatible_spot_instance_type_memory" - ErrIncompatibleSpotInstanceTypeCPU = "clusterconfig.incompatible_spot_instance_type_cpu" - ErrIncompatibleSpotInstanceTypeGPU = "clusterconfig.incompatible_spot_instance_type_gpu" - ErrIncompatibleSpotInstanceTypeInf = "clusterconfig.incompatible_spot_instance_type_inf" - ErrSpotPriceGreaterThanTargetOnDemand = "clusterconfig.spot_price_greater_than_target_on_demand" - ErrSpotPriceGreaterThanMaxPrice = "clusterconfig.spot_price_greater_than_max_price" - ErrInstanceTypeNotSupported = "clusterconfig.instance_type_not_supported" - 
ErrARMInstancesNotSupported = "clusterconfig.arm_instances_not_supported" - ErrAtLeastOneInstanceDistribution = "clusterconfig.at_least_one_instance_distribution" - ErrNoCompatibleSpotInstanceFound = "clusterconfig.no_compatible_spot_instance_found" - ErrConfiguredWhenSpotIsNotEnabled = "clusterconfig.configured_when_spot_is_not_enabled" - ErrOnDemandBaseCapacityGreaterThanMax = "clusterconfig.on_demand_base_capacity_greater_than_max" - ErrConfigCannotBeChangedOnUpdate = "clusterconfig.config_cannot_be_changed_on_update" - ErrInvalidAvailabilityZone = "clusterconfig.invalid_availability_zone" - ErrAvailabilityZoneSpecifiedTwice = "clusterconfig.availability_zone_specified_twice" - ErrUnsupportedAvailabilityZone = "clusterconfig.unsupported_availability_zone" - ErrNotEnoughValidDefaultAvailibilityZones = "clusterconfig.not_enough_valid_default_availability_zones" - ErrNoNATGatewayWithSubnets = "clusterconfig.no_nat_gateway_with_subnets" - ErrSpecifyOneOrNone = "clusterconfig.specify_one_or_none" - ErrDependentFieldMustBeSpecified = "clusterconfig.dependent_field_must_be_specified" - ErrFieldConfigurationDependentOnCondition = "clusterconfig.field_configuration_dependent_on_condition" - ErrDidNotMatchStrictS3Regex = "clusterconfig.did_not_match_strict_s3_regex" - ErrNATRequiredWithPrivateSubnetVisibility = "clusterconfig.nat_required_with_private_subnet_visibility" - ErrS3RegionDiffersFromCluster = "clusterconfig.s3_region_differs_from_cluster" - ErrInvalidInstanceType = "clusterconfig.invalid_instance_type" - ErrIOPSNotSupported = "clusterconfig.iops_not_supported" - ErrIOPSTooLarge = "clusterconfig.iops_too_large" - ErrCantOverrideDefaultTag = "clusterconfig.cant_override_default_tag" - ErrSSLCertificateARNNotFound = "clusterconfig.ssl_certificate_arn_not_found" - ErrIAMPolicyARNNotFound = "clusterconfig.iam_policy_arn_not_found" - ErrProviderMismatch = "clusterconfig.provider_mismatch" + ErrInvalidRegion = "clusterconfig.invalid_region" + ErrNoNodeGroupSpecified = 
"clusterconfig.no_nodegroup_specified" + ErrMaxNumOfNodeGroupsReached = "clusterconfig.max_num_of_nodegroups_reached" + ErrDuplicateNodeGroupName = "clusterconfig.duplicate_nodegroup_name" + ErrInstanceTypeTooSmall = "clusterconfig.instance_type_too_small" + ErrMinInstancesGreaterThanMax = "clusterconfig.min_instances_greater_than_max" + ErrInstanceTypeNotSupportedInRegion = "clusterconfig.instance_type_not_supported_in_region" + ErrIncompatibleSpotInstanceTypeMemory = "clusterconfig.incompatible_spot_instance_type_memory" + ErrIncompatibleSpotInstanceTypeCPU = "clusterconfig.incompatible_spot_instance_type_cpu" + ErrIncompatibleSpotInstanceTypeGPU = "clusterconfig.incompatible_spot_instance_type_gpu" + ErrIncompatibleSpotInstanceTypeInf = "clusterconfig.incompatible_spot_instance_type_inf" + ErrSpotPriceGreaterThanTargetOnDemand = "clusterconfig.spot_price_greater_than_target_on_demand" + ErrSpotPriceGreaterThanMaxPrice = "clusterconfig.spot_price_greater_than_max_price" + ErrInstanceTypeNotSupported = "clusterconfig.instance_type_not_supported" + ErrARMInstancesNotSupported = "clusterconfig.arm_instances_not_supported" + ErrAtLeastOneInstanceDistribution = "clusterconfig.at_least_one_instance_distribution" + ErrNoCompatibleSpotInstanceFound = "clusterconfig.no_compatible_spot_instance_found" + ErrConfiguredWhenSpotIsNotEnabled = "clusterconfig.configured_when_spot_is_not_enabled" + ErrOnDemandBaseCapacityGreaterThanMax = "clusterconfig.on_demand_base_capacity_greater_than_max" + ErrConfigCannotBeChangedOnUpdate = "clusterconfig.config_cannot_be_changed_on_update" + ErrInvalidAvailabilityZone = "clusterconfig.invalid_availability_zone" + ErrAvailabilityZoneSpecifiedTwice = "clusterconfig.availability_zone_specified_twice" + ErrUnsupportedAvailabilityZone = "clusterconfig.unsupported_availability_zone" + ErrNotEnoughValidDefaultAvailibilityZones = "clusterconfig.not_enough_valid_default_availability_zones" + ErrNoNATGatewayWithSubnets = 
"clusterconfig.no_nat_gateway_with_subnets" + ErrSpecifyOneOrNone = "clusterconfig.specify_one_or_none" + ErrDependentFieldMustBeSpecified = "clusterconfig.dependent_field_must_be_specified" + ErrFieldConfigurationDependentOnCondition = "clusterconfig.field_configuration_dependent_on_condition" + ErrDidNotMatchStrictS3Regex = "clusterconfig.did_not_match_strict_s3_regex" + ErrNATRequiredWithPrivateSubnetVisibility = "clusterconfig.nat_required_with_private_subnet_visibility" + ErrS3RegionDiffersFromCluster = "clusterconfig.s3_region_differs_from_cluster" + ErrInvalidInstanceType = "clusterconfig.invalid_instance_type" + ErrIOPSNotSupported = "clusterconfig.iops_not_supported" + ErrIOPSTooLarge = "clusterconfig.iops_too_large" + ErrCantOverrideDefaultTag = "clusterconfig.cant_override_default_tag" + ErrSSLCertificateARNNotFound = "clusterconfig.ssl_certificate_arn_not_found" + ErrIAMPolicyARNNotFound = "clusterconfig.iam_policy_arn_not_found" + ErrProviderMismatch = "clusterconfig.provider_mismatch" + ErrGCPInvalidProjectID = "clusterconfig.gcp_invalid_project_id" ErrGCPProjectMustBeSpecified = "clusterconfig.gcp_project_must_be_specified" ErrGCPInvalidZone = "clusterconfig.gcp_invalid_zone" + ErrGCPNoNodePoolSpecified = "clusterconfig.gcp_no_nodepool_specified" + ErrGCPMaxNumOfNodePoolsReached = "clusterconfig.gcp_max_num_of_nodepools_reached" + ErrGCPDuplicateNodePoolName = "clusterconfig.gcp_duplicate_nodepool_name" ErrGCPInvalidInstanceType = "clusterconfig.gcp_invalid_instance_type" ErrGCPInvalidAcceleratorType = "clusterconfig.gcp_invalid_accelerator_type" ErrGCPIncompatibleInstanceTypeWithAccelerator = "clusterconfig.gcp_incompatible_instance_type_with_accelerator" @@ -78,6 +85,27 @@ func ErrorInvalidRegion(region string) error { }) } +func ErrorNoNodeGroupSpecified() error { + return errors.WithStack(&errors.Error{ + Kind: ErrNoNodeGroupSpecified, + Message: "no nodegroup was specified; please specify at least 1 nodegroup", + }) +} + +func 
ErrorMaxNumOfNodeGroupsReached(maxNodeGroups int64) error { + return errors.WithStack(&errors.Error{ + Kind: ErrMaxNumOfNodeGroupsReached, + Message: fmt.Sprintf("cannot have more than %d nodegroups", maxNodeGroups), + }) +} + +func ErrorDuplicateNodeGroupName(duplicateNgName string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrDuplicateNodeGroupName, + Message: fmt.Sprintf("cannot have multiple nodegroups with the same name (%s)", duplicateNgName), + }) +} + func ErrorInstanceTypeTooSmall() error { return errors.WithStack(&errors.Error{ Kind: ErrInstanceTypeTooSmall, @@ -169,10 +197,10 @@ func ErrorOnDemandBaseCapacityGreaterThanMax(onDemandBaseCapacity int64, max int }) } -func ErrorConfigCannotBeChangedOnUpdate(configKey string, prevVal interface{}) error { +func ErrorConfigCannotBeChangedOnUpdate() error { return errors.WithStack(&errors.Error{ Kind: ErrConfigCannotBeChangedOnUpdate, - Message: fmt.Sprintf("modifying %s in a running cluster is not supported, please set %s to its previous value (%s)", configKey, configKey, s.UserStr(prevVal)), + Message: fmt.Sprintf("in a running cluster, only the %s and %s fields in the %s section can be modified", MinInstancesKey, MaxInstancesKey, NodeGroupsKey), }) } @@ -356,6 +384,27 @@ func ErrorGCPInvalidZone(zone string, suggestedZones ...string) error { }) } +func ErrorGCPNoNodePoolSpecified() error { + return errors.WithStack(&errors.Error{ + Kind: ErrGCPNoNodePoolSpecified, + Message: "no nodepool was specified; please specify at least 1 nodepool", + }) +} + +func ErrorGCPMaxNumOfNodePoolsReached(maxNodePools int64) error { + return errors.WithStack(&errors.Error{ + Kind: ErrGCPMaxNumOfNodePoolsReached, + Message: fmt.Sprintf("cannot have more than %d nodepools", maxNodePools), + }) +} + +func ErrorGCPDuplicateNodePoolName(duplicateNpName string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrGCPDuplicateNodePoolName, + Message: fmt.Sprintf("cannot have multiple nodepools with the same name 
(%s)", duplicateNpName), + }) +} + func ErrorGCPInvalidInstanceType(instanceType string, suggestedInstanceTypes ...string) error { errorMessage := fmt.Sprintf("invalid instance type '%s'", instanceType) if len(suggestedInstanceTypes) == 1 { diff --git a/pkg/types/clusterstate/clusterstate.go b/pkg/types/clusterstate/clusterstate.go index 074104be9f..103f1821c9 100644 --- a/pkg/types/clusterstate/clusterstate.go +++ b/pkg/types/clusterstate/clusterstate.go @@ -30,9 +30,10 @@ import ( const ( controlPlaneTemplate = "eksctl-%s-cluster" - operatorTemplate = "eksctl-%s-nodegroup-ng-cortex-operator" - spotTemplate = "eksctl-%s-nodegroup-ng-cortex-worker-spot" - onDemandTemplate = "eksctl-%s-nodegroup-ng-cortex-worker-on-demand" + operatorTemplate = "eksctl-%s-nodegroup-cx-operator" + + spotTemplatePrefix = "eksctl-%s-nodegroup-cx-ws" + onDemandTemplatePrefix = "eksctl-%s-nodegroup-cx-wd" ) type ClusterState struct { @@ -146,6 +147,14 @@ func getStatus(statusMap map[string]string, controlPlane string, clusterName str return StatusCreateComplete, nil } + if all(allStatuses, cloudformation.StackStatusUpdateComplete) { + return StatusUpdateComplete, nil + } + + if all(allStatuses, cloudformation.StackStatusUpdateRollbackComplete) { + return StatusUpdateRollbackComplete, nil + } + if all(allStatuses, cloudformation.StackStatusDeleteComplete) { return StatusDeleteComplete, nil } @@ -162,18 +171,23 @@ func getStatus(statusMap map[string]string, controlPlane string, clusterName str return StatusCreateInProgress, nil } + if controlPlaneStatus == cloudformation.StackStatusCreateComplete && + all(nodeGroupStatuses, cloudformation.StackStatusCreateComplete, cloudformation.StackStatusUpdateComplete, cloudformation.StackStatusUpdateRollbackComplete) { + return StatusUpdateComplete, nil + } + return StatusNotFound, ErrorUnexpectedCloudFormationStatus(clusterName, region, statusMap) } func GetClusterState(awsClient *aws.Client, accessConfig *clusterconfig.AccessConfig) (*ClusterState, 
error) { controlPlaneStackName := fmt.Sprintf(controlPlaneTemplate, accessConfig.ClusterName) operatorStackName := fmt.Sprintf(operatorTemplate, accessConfig.ClusterName) - spotStackName := fmt.Sprintf(spotTemplate, accessConfig.ClusterName) - onDemandStackName := fmt.Sprintf(onDemandTemplate, accessConfig.ClusterName) + spotStackNamePrefix := fmt.Sprintf(spotTemplatePrefix, accessConfig.ClusterName) + onDemandStackNamePrefix := fmt.Sprintf(onDemandTemplatePrefix, accessConfig.ClusterName) - nodeGroupStackNamesSet := strset.New(operatorStackName, spotStackName, onDemandStackName) + nodeGroupStackPrefixesSet := strset.New(operatorStackName, spotStackNamePrefix, onDemandStackNamePrefix) - stackSummaries, err := awsClient.ListEKSStacks(controlPlaneStackName, nodeGroupStackNamesSet) + stackSummaries, err := awsClient.ListEKSStacks(controlPlaneStackName, nodeGroupStackPrefixesSet) if err != nil { return nil, errors.Wrap(err, "unable to get cluster state from cloudformation") } diff --git a/pkg/types/clusterstate/errors.go b/pkg/types/clusterstate/errors.go index 6bc22ffa43..4ad54773d8 100644 --- a/pkg/types/clusterstate/errors.go +++ b/pkg/types/clusterstate/errors.go @@ -28,6 +28,7 @@ const ( ErrClusterCreateFailed = "clusterstatus.cluster_create_failed" ErrClusterCreateFailedTimeout = "clusterstatus.cluster_create_failed_timeout" ErrClusterAlreadyCreated = "clusterstatus.cluster_already_created" + ErrClusterAlreadyUpdated = "clusterstatus.cluster_already_updated" ErrClusterDownInProgress = "clusterstatus.cluster_down_in_progress" ErrClusterAlreadyDeleted = "clusterstatus.cluster_already_deleted" ErrClusterDeleteFailed = "clusterstatus.cluster_delete_failed" @@ -69,6 +70,13 @@ func ErrorClusterAlreadyCreated(clusterName string, region string) error { }) } +func ErrorClusterAlreadyUpdated(clusterName string, region string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrClusterAlreadyUpdated, + Message: fmt.Sprintf("a cluster named \"%s\" already created and 
updated in %s", clusterName, region), + }) +} + func ErrorClusterDownInProgress(clusterName string, region string) error { return errors.WithStack(&errors.Error{ Kind: ErrClusterDownInProgress, diff --git a/pkg/types/clusterstate/status.go b/pkg/types/clusterstate/status.go index 2c896f0237..f3b53aaae6 100644 --- a/pkg/types/clusterstate/status.go +++ b/pkg/types/clusterstate/status.go @@ -19,14 +19,16 @@ package clusterstate type Status string const ( - StatusNotFound Status = "not_found" - StatusCreateInProgress Status = "create_in_progress" - StatusCreateFailed Status = "create_failed" - StatusCreateFailedTimedOut Status = "create_failed_timed_out" - StatusCreateComplete Status = "create_complete" - StatusDeleteInProgress Status = "delete_in_progress" - StatusDeleteComplete Status = "delete_complete" - StatusDeleteFailed Status = "delete_failed" + StatusNotFound Status = "not_found" + StatusCreateInProgress Status = "create_in_progress" + StatusCreateFailed Status = "create_failed" + StatusCreateFailedTimedOut Status = "create_failed_timed_out" + StatusCreateComplete Status = "create_complete" + StatusUpdateComplete Status = "update_complete" + StatusUpdateRollbackComplete Status = "update_rollback_complete" + StatusDeleteInProgress Status = "delete_in_progress" + StatusDeleteComplete Status = "delete_complete" + StatusDeleteFailed Status = "delete_failed" ) func AssertClusterStatus(clusterName string, region string, status Status, allowedStatuses ...Status) error { @@ -47,6 +49,10 @@ func AssertClusterStatus(clusterName string, region string, status Status, allow return ErrorClusterCreateFailedTimeout(clusterName, region) case StatusCreateComplete: return ErrorClusterAlreadyCreated(clusterName, region) + case StatusUpdateComplete: + return ErrorClusterAlreadyUpdated(clusterName, region) + case StatusUpdateRollbackComplete: + return ErrorClusterAlreadyUpdated(clusterName, region) case StatusDeleteInProgress: return ErrorClusterDownInProgress(clusterName, region) 
case StatusDeleteComplete: diff --git a/test/apis/pytorch/image-classifier-resnet50/expectations.yaml b/test/apis/pytorch/image-classifier-resnet50/expectations.yaml new file mode 100644 index 0000000000..5e1d38f9ae --- /dev/null +++ b/test/apis/pytorch/image-classifier-resnet50/expectations.yaml @@ -0,0 +1,5 @@ +# this file is used for testing purposes only + +response: + content_type: "json" + expected: ["tabby", "Egyptian_cat", "tiger_cat", "tiger", "plastic_bag"] diff --git a/test/e2e/README.md b/test/e2e/README.md index be8638f256..3691b520f4 100644 --- a/test/e2e/README.md +++ b/test/e2e/README.md @@ -65,6 +65,10 @@ pytest test/e2e/tests --aws-env --gcp-env It is possible to skip GPU tests by passing the `--skip-gpus` flag to the pytest command. +### Skip Inferentia Tests + +It is possible to skip Inferentia tests by passing the `--skip-infs` flag to the pytest command. + ## Configuration It is possible to configure the behaviour of the tests by defining environment variables or a `.env` file at the project @@ -72,8 +76,8 @@ directory. 
```dotenv # .env file -CORTEX_TEST_REALTIME_DEPLOY_TIMEOUT=60 -CORTEX_TEST_BATCH_DEPLOY_TIMEOUT=30 +CORTEX_TEST_REALTIME_DEPLOY_TIMEOUT=120 +CORTEX_TEST_BATCH_DEPLOY_TIMEOUT=60 CORTEX_TEST_BATCH_JOB_TIMEOUT=120 CORTEX_TEST_BATCH_S3_PATH=s3:///test/jobs ``` diff --git a/test/e2e/e2e/tests.py b/test/e2e/e2e/tests.py index ffdf089285..208d5c30d6 100644 --- a/test/e2e/e2e/tests.py +++ b/test/e2e/e2e/tests.py @@ -38,9 +38,11 @@ def delete_apis(client: cx.Client, api_names: List[str]): client.delete_api(name) -def test_realtime_api(client: cx.Client, api: str, timeout: int = None): +def test_realtime_api( + client: cx.Client, api: str, timeout: int = None, api_config_name: str = "cortex.yaml" +): api_dir = TEST_APIS_DIR / api - with open(str(api_dir / "cortex.yaml")) as f: + with open(str(api_dir / api_config_name)) as f: api_specs = yaml.safe_load(f) expectations = None @@ -79,9 +81,10 @@ def test_batch_api( deploy_timeout: int = None, job_timeout: int = None, retry_attempts: int = 0, + api_config_name: str = "cortex.yaml", ): api_dir = TEST_APIS_DIR / api - with open(str(api_dir / "cortex.yaml")) as f: + with open(str(api_dir / api_config_name)) as f: api_specs = yaml.safe_load(f) assert len(api_specs) == 1 diff --git a/test/e2e/tests/aws/test_realtime.py b/test/e2e/tests/aws/test_realtime.py index af334697f0..1b4bb7e225 100644 --- a/test/e2e/tests/aws/test_realtime.py +++ b/test/e2e/tests/aws/test_realtime.py @@ -20,6 +20,7 @@ TEST_APIS = ["pytorch/iris-classifier", "onnx/iris-classifier", "tensorflow/iris-classifier"] TEST_APIS_GPU = ["pytorch/text-generator", "tensorflow/text-generator"] +TEST_APIS_INF = ["pytorch/image-classifier-resnet50"] @pytest.mark.usefixtures("client") @@ -40,3 +41,18 @@ def test_realtime_api_gpu(config: Dict, client: cx.Client, api: str): e2e.tests.test_realtime_api( client=client, api=api, timeout=config["global"]["realtime_deploy_timeout"] ) + + +@pytest.mark.usefixtures("client") +@pytest.mark.parametrize("api", TEST_APIS_INF) +def 
test_realtime_api_inf(config: Dict, client: cx.Client, api: str): + skip_infs = config["global"].get("skip_infs", False) + if skip_infs: + pytest.skip("--skip-infs flag detected, skipping Inferentia tests") + + e2e.tests.test_realtime_api( + client=client, + api=api, + timeout=config["global"]["realtime_deploy_timeout"], + api_config_name="cortex_inf.yaml", + ) diff --git a/test/e2e/tests/conftest.py b/test/e2e/tests/conftest.py index 5bf2fa10e3..3d86e1753e 100644 --- a/test/e2e/tests/conftest.py +++ b/test/e2e/tests/conftest.py @@ -54,6 +54,11 @@ def pytest_addoption(parser): action="store_true", help="skip GPU tests", ) + parser.addoption( + "--skip-infs", + action="store_true", + help="skip Inferentia tests", + ) def pytest_configure(config): @@ -79,6 +84,7 @@ def pytest_configure(config): "batch_deploy_timeout": int(os.environ.get("CORTEX_TEST_BATCH_DEPLOY_TIMEOUT", 30)), "batch_job_timeout": int(os.environ.get("CORTEX_TEST_BATCH_JOB_TIMEOUT", 200)), "skip_gpus": config.getoption("--skip-gpus"), + "skip_infs": config.getoption("--skip-infs"), }, }