Skip to content

Support GCP instances with multiple GPUs #1789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/cmd/cluster_gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func createGKECluster(clusterConfig *clusterconfig.GCPConfig, gcpClient *gcp.Cli

if clusterConfig.AcceleratorType != nil {
accelerators = append(accelerators, &containerpb.AcceleratorConfig{
AcceleratorCount: 1,
AcceleratorCount: *clusterConfig.AcceleratorsPerInstance,
AcceleratorType: *clusterConfig.AcceleratorType,
})
nodeLabels["nvidia.com/gpu"] = "present"
Expand Down
3 changes: 3 additions & 0 deletions docs/clusters/gcp/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ max_instances: 5
# GPU to attach to your instance (optional)
# accelerator_type: nvidia-tesla-t4

# the number of GPUs to attach to each instance (optional; requires accelerator_type, and defaults to 1 when accelerator_type is set)
# accelerators_per_instance: 1

# the name of the network in which to create your cluster
# network: default

Expand Down
13 changes: 13 additions & 0 deletions pkg/lib/configreader/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,11 +1130,24 @@ func setField(val interface{}, destStruct interface{}, fieldName string) error {
debug.Ppg(destStruct)
return errors.Wrap(ErrorCannotSetStructField(), fieldName)
}

if val == nil {
// Check for nil-able types
if v.Kind() == reflect.Chan || v.Kind() == reflect.Func || v.Kind() == reflect.Interface || v.Kind() == reflect.Map || v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice {
v.Set(reflect.Zero(v.Type()))
return nil
}
debug.Ppg(val)
debug.Ppg(destStruct)
return errors.Wrap(ErrorCannotSetStructField(), fieldName)
}

if !reflect.ValueOf(val).Type().AssignableTo(v.Type()) {
debug.Ppg(val)
debug.Ppg(destStruct)
return errors.Wrap(ErrorCannotSetStructField(), fieldName)
}

v.Set(reflect.ValueOf(val))
return nil
}
Expand Down
29 changes: 29 additions & 0 deletions pkg/types/clusterconfig/cluster_config_gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type GCPConfig struct {
Zone *string `json:"zone" yaml:"zone"`
InstanceType *string `json:"instance_type" yaml:"instance_type"`
AcceleratorType *string `json:"accelerator_type" yaml:"accelerator_type"`
AcceleratorsPerInstance *int64 `json:"accelerators_per_instance" yaml:"accelerators_per_instance"`
Network *string `json:"network" yaml:"network"`
Subnet *string `json:"subnet" yaml:"subnet"`
APILoadBalancerScheme LoadBalancerScheme `json:"api_load_balancer_scheme" yaml:"api_load_balancer_scheme"`
Expand Down Expand Up @@ -138,6 +139,20 @@ var UserGCPValidation = &cr.StructValidation{
AllowExplicitNull: true,
},
},
{
StructField: "AcceleratorsPerInstance",
Int64PtrValidation: &cr.Int64PtrValidation{
AllowExplicitNull: true,
},
DefaultDependentFields: []string{"AcceleratorType"},
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
acceleratorType := vals[0].(*string)
if acceleratorType == nil {
return nil
}
return pointer.Int64(1)
},
},
{
StructField: "Network",
StringPtrValidation: &cr.StringPtrValidation{
Expand Down Expand Up @@ -326,7 +341,14 @@ func (cc *GCPConfig) Validate(GCP *gcp.Client) error {
return ErrorGCPInvalidInstanceType(*cc.InstanceType, instanceTypes...)
}

if cc.AcceleratorType == nil && cc.AcceleratorsPerInstance != nil {
return ErrorDependentFieldMustBeSpecified(AcceleratorsPerInstanceKey, AcceleratorTypeKey)
}

if cc.AcceleratorType != nil {
if cc.AcceleratorsPerInstance == nil {
return ErrorDependentFieldMustBeSpecified(AcceleratorTypeKey, AcceleratorsPerInstanceKey)
}
if validAccelerator, err := GCP.IsAcceleratorTypeAvailable(*cc.AcceleratorType, *cc.Zone); err != nil {
return err
} else if !validAccelerator {
Expand Down Expand Up @@ -517,6 +539,9 @@ func (cc *GCPConfig) UserTable() table.KeyValuePairs {
if cc.AcceleratorType != nil {
items.Add(AcceleratorTypeUserKey, *cc.AcceleratorType)
}
if cc.AcceleratorsPerInstance != nil {
items.Add(AcceleratorsPerInstanceUserKey, *cc.AcceleratorsPerInstance)
}
if cc.Network != nil {
items.Add(NetworkUserKey, *cc.Network)
}
Expand Down Expand Up @@ -554,6 +579,10 @@ func (cc *GCPConfig) TelemetryEvent() map[string]interface{} {
event["accelerator_type._is_defined"] = true
event["accelerator_type"] = *cc.AcceleratorType
}
if cc.AcceleratorsPerInstance != nil {
event["accelerators_per_instance._is_defined"] = true
event["accelerators_per_instance"] = *cc.AcceleratorsPerInstance
}
if cc.Network != nil {
event["network._is_defined"] = true
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/types/clusterconfig/config_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
ProviderKey = "provider"
InstanceTypeKey = "instance_type"
AcceleratorTypeKey = "accelerator_type"
AcceleratorsPerInstanceKey = "accelerators_per_instance"
NetworkKey = "network"
SubnetKey = "subnet"
MinInstancesKey = "min_instances"
Expand Down Expand Up @@ -83,6 +84,7 @@ const (
SpotUserKey = "use spot instances"
InstanceTypeUserKey = "instance type"
AcceleratorTypeUserKey = "accelerator type"
AcceleratorsPerInstanceUserKey = "accelerators per instance"
NetworkUserKey = "network"
SubnetUserKey = "subnet"
MinInstancesUserKey = "min instances"
Expand Down
8 changes: 8 additions & 0 deletions pkg/types/clusterconfig/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const (
ErrNotEnoughValidDefaultAvailibilityZones = "clusterconfig.not_enough_valid_default_availability_zones"
ErrNoNATGatewayWithSubnets = "clusterconfig.no_nat_gateway_with_subnets"
ErrSpecifyOneOrNone = "clusterconfig.specify_one_or_none"
ErrDependentFieldMustBeSpecified = "clusterconfig.dependent_field_must_be_specified"
ErrDidNotMatchStrictS3Regex = "clusterconfig.did_not_match_strict_s3_regex"
ErrNATRequiredWithPrivateSubnetVisibility = "clusterconfig.nat_required_with_private_subnet_visibility"
ErrS3RegionDiffersFromCluster = "clusterconfig.s3_region_differs_from_cluster"
Expand Down Expand Up @@ -241,6 +242,13 @@ func ErrorSpecifyOneOrNone(fieldName1 string, fieldName2 string, fieldNames ...s
})
}

// ErrorDependentFieldMustBeSpecified returns an error of kind
// ErrDependentFieldMustBeSpecified whose message states that dependencyField
// must be specified when configuring configuredField.
//
// Note the argument order: the field that was configured comes first, and the
// field it depends on comes second, which is the reverse of the order in which
// they appear in the rendered message.
func ErrorDependentFieldMustBeSpecified(configuredField string, dependencyField string) error {
	message := fmt.Sprintf("%s must be specified when configuring %s", dependencyField, configuredField)
	dependencyErr := &errors.Error{
		Kind:    ErrDependentFieldMustBeSpecified,
		Message: message,
	}
	return errors.WithStack(dependencyErr)
}

func ErrorDidNotMatchStrictS3Regex() error {
return errors.WithStack(&errors.Error{
Kind: ErrDidNotMatchStrictS3Regex,
Expand Down