diff --git a/pkg/lib/aws/ec2.go b/pkg/lib/aws/ec2.go
index ecdaf73312..0e139d7c49 100644
--- a/pkg/lib/aws/ec2.go
+++ b/pkg/lib/aws/ec2.go
@@ -224,3 +224,118 @@ func (c *Client) ListSupportedAvailabilityZones(instanceType string, instanceTyp
 
 	return strset.Intersection(zoneSets...), nil
 }
+
+// ListElasticIPs returns the public IP of every address returned by the EC2
+// DescribeAddresses API. Entries without a public IP are skipped.
+func (c *Client) ListElasticIPs() ([]string, error) {
+	addresses, err := c.EC2().DescribeAddresses(&ec2.DescribeAddressesInput{})
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+
+	addressesList := []string{}
+	if addresses != nil {
+		for _, address := range addresses.Addresses {
+			if address != nil && address.PublicIp != nil {
+				addressesList = append(addressesList, *address.PublicIp)
+			}
+		}
+	}
+
+	return addressesList, nil
+}
+
+// ListInternetGateways pages through the EC2 DescribeInternetGateways API and
+// returns the ID of every internet gateway found.
+func (c *Client) ListInternetGateways() ([]string, error) {
+	gatewaysList := []string{}
+	err := c.EC2().DescribeInternetGatewaysPages(&ec2.DescribeInternetGatewaysInput{}, func(output *ec2.DescribeInternetGatewaysOutput, lastPage bool) bool {
+		if output == nil {
+			return false
+		}
+		for _, gateway := range output.InternetGateways {
+			if gateway != nil && gateway.InternetGatewayId != nil {
+				gatewaysList = append(gatewaysList, *gateway.InternetGatewayId)
+			}
+		}
+
+		return true
+	})
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+
+	return gatewaysList, nil
+}
+
+// DescribeNATGateways pages through the EC2 DescribeNatGateways API and returns
+// every NAT gateway found. No filter is applied to the request.
+func (c *Client) DescribeNATGateways() ([]ec2.NatGateway, error) {
+	var gateways []ec2.NatGateway
+	err := c.EC2().DescribeNatGatewaysPages(&ec2.DescribeNatGatewaysInput{}, func(output *ec2.DescribeNatGatewaysOutput, lastPage bool) bool {
+		if output == nil {
+			return false
+		}
+		for _, gateway := range output.NatGateways {
+			if gateway == nil {
+				continue
+			}
+			gateways = append(gateways, *gateway)
+		}
+
+		return true
+	})
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+
+	return gateways, nil
+}
+
+// DescribeSubnets pages through the EC2 DescribeSubnets API and returns every
+// subnet found. No filter is applied to the request.
+func (c *Client) DescribeSubnets() ([]ec2.Subnet, error) {
+	var subnets []ec2.Subnet
+	err := c.EC2().DescribeSubnetsPages(&ec2.DescribeSubnetsInput{}, func(output *ec2.DescribeSubnetsOutput, lastPage bool) bool {
+		if output == nil {
+			return false
+		}
+		for _, subnet := range output.Subnets {
+			if subnet == nil {
+				continue
+			}
+			subnets = append(subnets, *subnet)
+		}
+
+		return true
+	})
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+
+	return subnets, nil
+}
+
+// DescribeVpcs pages through the EC2 DescribeVpcs API and returns every VPC
+// found. No filter is applied to the request.
+func (c *Client) DescribeVpcs() ([]ec2.Vpc, error) {
+	var vpcs []ec2.Vpc
+	err := c.EC2().DescribeVpcsPages(&ec2.DescribeVpcsInput{}, func(output *ec2.DescribeVpcsOutput, lastPage bool) bool {
+		if output == nil {
+			return false
+		}
+		for _, vpc := range output.Vpcs {
+			if vpc == nil {
+				continue
+			}
+			vpcs = append(vpcs, *vpc)
+		}
+
+		return true
+	})
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+
+	return vpcs, nil
+}
diff --git a/pkg/lib/aws/errors.go b/pkg/lib/aws/errors.go
index b43332d082..3493c2df73 100644
--- a/pkg/lib/aws/errors.go
+++ b/pkg/lib/aws/errors.go
@@ -43,8 +43,19 @@ const (
 	ErrDashboardHeightOutOfRange    = "aws.dashboard_height_out_of_range"
 	ErrRegionNotConfigured          = "aws.region_not_configured"
 	ErrUnableToFindCredentials      = "aws.unable_to_find_credentials"
+	ErrNATGatewayLimitExceeded      = "aws.nat_gateway_limit_exceeded"
+	ErrEIPLimitExceeded             = "aws.eip_limit_exceeded"
+	ErrInternetGatewayLimitExceeded = "aws.internet_gateway_limit_exceeded"
+	ErrVPCLimitExceeded             = "aws.vpc_limit_exceeded"
 )
 
+// IsAWSError reports whether err, once resolved by errors.CauseOrSelf, is an
+// awserr.Error (i.e. an error produced by the AWS SDK).
+func IsAWSError(err error) bool {
+	_, ok := errors.CauseOrSelf(err).(awserr.Error)
+	return ok
+}
+
 func IsNotFoundErr(err error) bool {
 	return IsErrCode(err, "NotFound")
 }
@@ -196,3 +207,43 @@ func ErrorUnableToFindCredentials() error {
 		Message: "unable to find aws credentials; instructions about configuring aws credentials can be found at https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html",
 	})
 }
+
+// ErrorNATGatewayLimitExceeded is returned when the NAT gateway quota (currentLimit)
+// leaves no room for a new NAT gateway in the given availability zones of region.
+func ErrorNATGatewayLimitExceeded(currentLimit, additionalQuotaRequired int, availabilityZones []string, region string) error {
+	url := "https://console.aws.amazon.com/servicequotas/home?#!/services/vpc/quotas"
+	return errors.WithStack(&errors.Error{
+		Kind:    ErrNATGatewayLimitExceeded,
+		Message: fmt.Sprintf("NAT gateway limit of %d exceeded in availability zones %s of region %s; remove some of the existing NAT gateways or increase your quota for NAT gateways by at least %d here: %s (if your request was recently approved, please allow ~30 minutes for AWS to reflect this change)", currentLimit, s.StrsAnd(availabilityZones), region, additionalQuotaRequired, url),
+	})
+}
+
+// ErrorEIPLimitExceeded is returned when the elastic IP quota of region (currentLimit)
+// must be raised by additionalQuotaRequired to fit the required elastic IPs.
+func ErrorEIPLimitExceeded(currentLimit, additionalQuotaRequired int, region string) error {
+	url := "https://console.aws.amazon.com/servicequotas/home?#!/services/ec2/quotas"
+	return errors.WithStack(&errors.Error{
+		Kind:    ErrEIPLimitExceeded,
+		Message: fmt.Sprintf("elastic IPs limit of %d exceeded in region %s; remove some of the existing elastic IPs or increase your quota for elastic IPs by at least %d here: %s (if your request was recently approved, please allow ~30 minutes for AWS to reflect this change)", currentLimit, region, additionalQuotaRequired, url),
+	})
+}
+
+// ErrorInternetGatewayLimitExceeded is returned when the internet gateway quota of region
+// (currentLimit) must be raised by additionalQuotaRequired to fit the required gateways.
+func ErrorInternetGatewayLimitExceeded(currentLimit, additionalQuotaRequired int, region string) error {
+	url := "https://console.aws.amazon.com/servicequotas/home?#!/services/vpc/quotas"
+	return errors.WithStack(&errors.Error{
+		Kind:    ErrInternetGatewayLimitExceeded,
+		Message: fmt.Sprintf("internet gateway limit of %d exceeded in region %s; remove some of the existing internet gateways or increase your quota for internet gateways by at least %d here: %s (if your request was recently approved, please allow ~30 minutes for AWS to reflect this change)", currentLimit, region, additionalQuotaRequired, url),
+	})
+}
+
+// ErrorVPCLimitExceeded is returned when the VPC quota of region (currentLimit)
+// must be raised by additionalQuotaRequired to fit the required VPCs.
+func ErrorVPCLimitExceeded(currentLimit, additionalQuotaRequired int, region string) error {
+	url := "https://console.aws.amazon.com/servicequotas/home?#!/services/vpc/quotas"
+	return errors.WithStack(&errors.Error{
+		Kind:    ErrVPCLimitExceeded,
+		Message: fmt.Sprintf("VPC limit of %d exceeded in region %s; remove some of the existing VPCs or increase your quota for VPCs by at least %d here: %s (if your request was recently approved, please allow ~30 minutes for AWS to reflect this change)", currentLimit, region, additionalQuotaRequired, url),
+	})
+}
diff --git a/pkg/lib/aws/servicequotas.go b/pkg/lib/aws/servicequotas.go
index b51f223d6f..df9715b15a 100644
--- a/pkg/lib/aws/servicequotas.go
+++ b/pkg/lib/aws/servicequotas.go
@@ -31,6 +31,15 @@ var _instanceCategoryRegex = regexp.MustCompile(`[a-zA-Z]+`)
 var _standardInstanceCategories = strset.New("a", "c", "d", "h", "i", "m", "r", "t", "z")
 var _knownInstanceCategories = strset.Union(_standardInstanceCategories, strset.New("p", "g", "inf", "x", "f"))
 
+// Service Quotas codes of the network quotas checked by VerifyNetworkQuotas.
+// The elastic IPs quota is looked up under the "ec2" service, the others under "vpc".
+const (
+	_elasticIPsQuotaCode      = "L-0263D0A3"
+	_internetGatewayQuotaCode = "L-A4707A72"
+	_natGatewayQuotaCode      = "L-FE5A380F"
+	_vpcQuotaCode             = "L-F678F1CE"
+)
+
 func (c *Client) VerifyInstanceQuota(instanceType string, requiredOnDemandInstances int64, requiredSpotInstances int64) error {
 	if requiredOnDemandInstances == 0 && requiredSpotInstances == 0 {
 		return nil
@@ -103,3 +112,165 @@ func (c *Client) VerifyInstanceQuota(instanceType string, requiredOnDemandInstan
 
 	return nil
 }
+
+// VerifyNetworkQuotas checks that the account has enough remaining quota for
+// the network resources that are about to be created: internet gateways, NAT
+// gateways (one per availability zone if highlyAvailableNATGateway is set,
+// otherwise a single one), the elastic IPs of these NAT gateways, and VPCs.
+//
+// The quotas are checked in that order and a limit-exceeded error is returned
+// for the first one that is exhausted. Errors of the underlying AWS API calls
+// are propagated to the caller.
+func (c *Client) VerifyNetworkQuotas(requiredInternetGateways int, natGatewayRequired bool, highlyAvailableNATGateway bool, requiredVPCs int, availabilityZones strset.Set) error {
+	// a quota that the Service Quotas API does not return keeps a limit of 0
+	quotaCodeToValueMap := map[string]int{
+		_elasticIPsQuotaCode:      0,
+		_internetGatewayQuotaCode: 0,
+		_natGatewayQuotaCode:      0,
+		_vpcQuotaCode:             0,
+	}
+
+	// the elastic IPs quota belongs to the "ec2" service, the other ones to "vpc"
+	if err := c.fillQuotaValues("ec2", quotaCodeToValueMap, _elasticIPsQuotaCode); err != nil {
+		return err
+	}
+	if err := c.fillQuotaValues("vpc", quotaCodeToValueMap, _internetGatewayQuotaCode, _natGatewayQuotaCode, _vpcQuotaCode); err != nil {
+		return err
+	}
+
+	// check internet GW quota
+	if requiredInternetGateways > 0 {
+		internetGatewaysInUse, err := c.ListInternetGateways()
+		if err != nil {
+			return err
+		}
+		internetGatewayLimit := quotaCodeToValueMap[_internetGatewayQuotaCode]
+		if deficit := len(internetGatewaysInUse) + requiredInternetGateways - internetGatewayLimit; deficit > 0 {
+			return ErrorInternetGatewayLimitExceeded(internetGatewayLimit, deficit, c.Region)
+		}
+	}
+
+	if natGatewayRequired {
+		// get NAT GW in use per selected AZ
+		natGateways, err := c.DescribeNATGateways()
+		if err != nil {
+			return err
+		}
+		subnets, err := c.DescribeSubnets()
+		if err != nil {
+			return err
+		}
+		// index the subnets of the selected AZs by ID, so that each NAT gateway is resolved with one lookup
+		subnetIDToAZ := make(map[string]string, len(subnets))
+		for _, subnet := range subnets {
+			if subnet.SubnetId == nil || subnet.AvailabilityZone == nil {
+				continue
+			}
+			if availabilityZones.Has(*subnet.AvailabilityZone) {
+				subnetIDToAZ[*subnet.SubnetId] = *subnet.AvailabilityZone
+			}
+		}
+		azToGatewaysInUse := map[string]int{}
+		for _, natGateway := range natGateways {
+			if natGateway.SubnetId == nil {
+				continue
+			}
+			// deleted NAT gateways are still listed for a while; like failed ones, they don't count against the quota
+			if natGateway.State != nil && (*natGateway.State == "deleted" || *natGateway.State == "failed") {
+				continue
+			}
+			if az, ok := subnetIDToAZ[*natGateway.SubnetId]; ok {
+				azToGatewaysInUse[az]++
+			}
+		}
+
+		// check NAT GW quota
+		natGatewayLimit := quotaCodeToValueMap[_natGatewayQuotaCode]
+		azsWithQuotaDeficit := []string{}
+		for az, numActiveGatewaysOnAZ := range azToGatewaysInUse {
+			// +1 comes from the NAT gateway we require per AZ
+			if numActiveGatewaysOnAZ+1 > natGatewayLimit {
+				azsWithQuotaDeficit = append(azsWithQuotaDeficit, az)
+			}
+		}
+		// a highly available NAT gateway needs room in every AZ; a single NAT gateway only
+		// needs room in one of them, which can't be ruled out when no AZ has been selected
+		var natGatewayLimitExceeded bool
+		if highlyAvailableNATGateway {
+			natGatewayLimitExceeded = len(azsWithQuotaDeficit) > 0
+		} else {
+			natGatewayLimitExceeded = len(availabilityZones) > 0 && len(azsWithQuotaDeficit) == len(availabilityZones)
+		}
+		if natGatewayLimitExceeded {
+			return ErrorNATGatewayLimitExceeded(natGatewayLimit, 1, azsWithQuotaDeficit, c.Region)
+		}
+
+		// check EIP quota (one elastic IP per required NAT gateway)
+		elasticIPsInUse, err := c.ListElasticIPs()
+		if err != nil {
+			return err
+		}
+		requiredElasticIPs := 1
+		if highlyAvailableNATGateway {
+			requiredElasticIPs = len(availabilityZones)
+		}
+		elasticIPLimit := quotaCodeToValueMap[_elasticIPsQuotaCode]
+		if deficit := len(elasticIPsInUse) + requiredElasticIPs - elasticIPLimit; deficit > 0 {
+			return ErrorEIPLimitExceeded(elasticIPLimit, deficit, c.Region)
+		}
+	}
+
+	// check VPC quota
+	if requiredVPCs > 0 {
+		vpcs, err := c.DescribeVpcs()
+		if err != nil {
+			return err
+		}
+		vpcLimit := quotaCodeToValueMap[_vpcQuotaCode]
+		if deficit := len(vpcs) + requiredVPCs - vpcLimit; deficit > 0 {
+			return ErrorVPCLimitExceeded(vpcLimit, deficit, c.Region)
+		}
+	}
+
+	return nil
+}
+
+// fillQuotaValues pages through the quotas of the given service and stores the
+// value of each quota listed in quotaCodes in quotaCodeToValueMap. Paging stops
+// as soon as all of the requested quotas have been found.
+func (c *Client) fillQuotaValues(serviceCode string, quotaCodeToValueMap map[string]int, quotaCodes ...string) error {
+	pendingQuotaCodes := make(map[string]struct{}, len(quotaCodes))
+	for _, quotaCode := range quotaCodes {
+		pendingQuotaCodes[quotaCode] = struct{}{}
+	}
+
+	err := c.ServiceQuotas().ListServiceQuotasPages(
+		&servicequotas.ListServiceQuotasInput{
+			ServiceCode: aws.String(serviceCode),
+		},
+		func(page *servicequotas.ListServiceQuotasOutput, lastPage bool) bool {
+			if page == nil {
+				return false
+			}
+			for _, quota := range page.Quotas {
+				if quota == nil || quota.QuotaCode == nil || quota.Value == nil {
+					continue
+				}
+				if _, ok := pendingQuotaCodes[*quota.QuotaCode]; !ok {
+					continue
+				}
+				quotaCodeToValueMap[*quota.QuotaCode] = int(*quota.Value)
+				delete(pendingQuotaCodes, *quota.QuotaCode)
+				if len(pendingQuotaCodes) == 0 {
+					return false
+				}
+			}
+			return true
+		},
+	)
+	if err != nil {
+		return errors.WithStack(err)
+	}
+
+	return nil
+}
diff --git a/pkg/types/clusterconfig/cluster_config_aws.go b/pkg/types/clusterconfig/cluster_config_aws.go
index ad4013ed11..98f3460173 100644
--- a/pkg/types/clusterconfig/cluster_config_aws.go
+++ b/pkg/types/clusterconfig/cluster_config_aws.go
@@ -25,7 +25,6 @@ import (
 	"regexp"
 	"strings"
 
-	"github.com/aws/aws-sdk-go/aws/awserr"
 	"github.com/aws/aws-sdk-go/service/iam"
 	"github.com/cortexlabs/cortex/pkg/consts"
 	"github.com/cortexlabs/cortex/pkg/lib/aws"
@@ -766,8 +765,8 @@ func (cc *Config) Validate(awsClient *aws.Client) error {
 
 	if err := awsClient.VerifyInstanceQuota(primaryInstanceType, cc.MaxPossibleOnDemandInstances(), cc.MaxPossibleSpotInstances()); err != nil {
 		// Skip AWS errors, since some regions (e.g. eu-north-1) do not support this API
-		if _, ok := errors.CauseOrSelf(err).(awserr.Error); !ok {
-			return errors.Wrap(err, InstanceTypeKey)
+		if !aws.IsAWSError(err) {
+			return err
 		}
 	}
 
@@ -793,6 +792,17 @@ func (cc *Config) Validate(awsClient *aws.Client) error {
 		}
 	}
 
+	var requiredVPCs int
+	if len(cc.Subnets) == 0 {
+		requiredVPCs = 1
+	}
+	if err := awsClient.VerifyNetworkQuotas(1, cc.NATGateway != NoneNATGateway, cc.NATGateway == HighlyAvailableNATGateway, requiredVPCs, strset.FromSlice(cc.AvailabilityZones)); err != nil {
+		// Skip AWS errors, since some regions (e.g. eu-north-1) do not support this API
+		if !aws.IsAWSError(err) {
+			return err
+		}
+	}
+
 	if cc.Spot != nil && *cc.Spot {
 		cc.FillEmptySpotFields()
 