diff --git a/docs/applications/resources/models.md b/docs/applications/resources/models.md index 9f7773b9a3..b2d067f063 100644 --- a/docs/applications/resources/models.md +++ b/docs/applications/resources/models.md @@ -41,11 +41,21 @@ Train custom TensorFlow models at scale. start_delay_secs: # start evaluating after waiting for this many seconds (default: 120) throttle_secs: # do not re-evaluate unless the last evaluation was started at least this many seconds ago (default: 600) - compute: + compute: # Resources for training and evaluation steps (TensorFlow) cpu: # CPU request (default: Null) mem: # memory request (default: Null) gpu: # GPU request (default: Null) + dataset_compute: # Resources for constructing training dataset (Spark) + executors: # number of spark executors (default: 1) + driver_cpu: # CPU request for spark driver (default: 1) + driver_mem: # memory request for spark driver (default: 500Mi) + driver_mem_overhead: # off-heap (non-JVM) memory allocated to the driver (overrides mem_overhead_factor) (default: min[driver_mem * 0.4, 384Mi]) + executor_cpu: # CPU request for each spark executor (default: 1) + executor_mem: # memory request for each spark executor (default: 500Mi) + executor_mem_overhead: # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min[executor_mem * 0.4, 384Mi]) + mem_overhead_factor: # the proportion of driver_mem/executor_mem which will be additionally allocated for off-heap (non-JVM) memory (default: 0.4) + tags: # arbitrary key/value pairs to attach to the resource (optional) ... 
diff --git a/pkg/operator/api/userconfig/aggregates.go b/pkg/operator/api/userconfig/aggregates.go index 74643cccf8..21b8a0962b 100644 --- a/pkg/operator/api/userconfig/aggregates.go +++ b/pkg/operator/api/userconfig/aggregates.go @@ -51,7 +51,7 @@ var aggregateValidation = &cr.StructValidation{ }, }, inputValuesFieldValidation, - sparkComputeFieldValidation, + sparkComputeFieldValidation("Compute"), tagsFieldValidation, typeFieldValidation, }, diff --git a/pkg/operator/api/userconfig/compute.go b/pkg/operator/api/userconfig/compute.go index 0ecdae06ce..b9816e383c 100644 --- a/pkg/operator/api/userconfig/compute.go +++ b/pkg/operator/api/userconfig/compute.go @@ -45,84 +45,88 @@ type SparkCompute struct { MemOverheadFactor *float64 `json:"mem_overhead_factor" yaml:"mem_overhead_factor"` } -var sparkComputeFieldValidation = &cr.StructFieldValidation{ - StructField: "Compute", - StructValidation: &cr.StructValidation{ - StructFieldValidations: []*cr.StructFieldValidation{ - { - StructField: "Executors", - Int32Validation: &cr.Int32Validation{ - Default: 1, - GreaterThan: pointer.Int32(0), - }, +var sparkComputeStructValidation = &cr.StructValidation{ + StructFieldValidations: []*cr.StructFieldValidation{ + { + StructField: "Executors", + Int32Validation: &cr.Int32Validation{ + Default: 1, + GreaterThan: pointer.Int32(0), }, - { - StructField: "DriverCPU", - StringValidation: &cr.StringValidation{ - Default: "1", - }, - Parser: QuantityParser(&QuantityValidation{ - Min: k8sresource.MustParse("1"), - }), + }, + { + StructField: "DriverCPU", + StringValidation: &cr.StringValidation{ + Default: "1", }, - { - StructField: "ExecutorCPU", - StringValidation: &cr.StringValidation{ - Default: "1", - }, - Parser: QuantityParser(&QuantityValidation{ - Min: k8sresource.MustParse("1"), - Int: true, - }), + Parser: QuantityParser(&QuantityValidation{ + Min: k8sresource.MustParse("1"), + }), + }, + { + StructField: "ExecutorCPU", + StringValidation: &cr.StringValidation{ + Default: 
"1", }, - { - StructField: "DriverMem", - StringValidation: &cr.StringValidation{ - Default: "500Mi", - }, - Parser: QuantityParser(&QuantityValidation{ - Min: k8sresource.MustParse("500Mi"), - }), + Parser: QuantityParser(&QuantityValidation{ + Min: k8sresource.MustParse("1"), + Int: true, + }), + }, + { + StructField: "DriverMem", + StringValidation: &cr.StringValidation{ + Default: "500Mi", }, - { - StructField: "ExecutorMem", - StringValidation: &cr.StringValidation{ - Default: "500Mi", - }, - Parser: QuantityParser(&QuantityValidation{ - Min: k8sresource.MustParse("500Mi"), - }), + Parser: QuantityParser(&QuantityValidation{ + Min: k8sresource.MustParse("500Mi"), + }), + }, + { + StructField: "ExecutorMem", + StringValidation: &cr.StringValidation{ + Default: "500Mi", }, - { - StructField: "DriverMemOverhead", - StringPtrValidation: &cr.StringPtrValidation{ - Default: nil, // min(DriverMem * 0.4, 384Mi) - }, - Parser: QuantityParser(&QuantityValidation{ - Min: k8sresource.MustParse("0"), - }), + Parser: QuantityParser(&QuantityValidation{ + Min: k8sresource.MustParse("500Mi"), + }), + }, + { + StructField: "DriverMemOverhead", + StringPtrValidation: &cr.StringPtrValidation{ + Default: nil, // min(DriverMem * 0.4, 384Mi) }, - { - StructField: "ExecutorMemOverhead", - StringPtrValidation: &cr.StringPtrValidation{ - Default: nil, // min(ExecutorMem * 0.4, 384Mi) - }, - Parser: QuantityParser(&QuantityValidation{ - Min: k8sresource.MustParse("0"), - }), + Parser: QuantityParser(&QuantityValidation{ + Min: k8sresource.MustParse("0"), + }), + }, + { + StructField: "ExecutorMemOverhead", + StringPtrValidation: &cr.StringPtrValidation{ + Default: nil, // min(ExecutorMem * 0.4, 384Mi) }, - { - StructField: "MemOverheadFactor", - Float64PtrValidation: &cr.Float64PtrValidation{ - Default: nil, // set to 0.4 by Spark - GreaterThanOrEqualTo: pointer.Float64(0), - LessThan: pointer.Float64(1), - }, + Parser: QuantityParser(&QuantityValidation{ + Min: 
k8sresource.MustParse("0"), + }), + }, + { + StructField: "MemOverheadFactor", + Float64PtrValidation: &cr.Float64PtrValidation{ + Default: nil, // set to 0.4 by Spark + GreaterThanOrEqualTo: pointer.Float64(0), + LessThan: pointer.Float64(1), }, }, }, } +func sparkComputeFieldValidation(fieldName string) *cr.StructFieldValidation { + return &cr.StructFieldValidation{ + StructField: fieldName, + StructValidation: sparkComputeStructValidation, + } +} + func (sparkCompute *SparkCompute) ID() string { var buf bytes.Buffer buf.WriteString(s.Int32(sparkCompute.Executors)) diff --git a/pkg/operator/api/userconfig/models.go b/pkg/operator/api/userconfig/models.go index acc76e548d..6e30a3b79d 100644 --- a/pkg/operator/api/userconfig/models.go +++ b/pkg/operator/api/userconfig/models.go @@ -42,6 +42,7 @@ type Model struct { Training *ModelTraining `json:"training" yaml:"training"` Evaluation *ModelEvaluation `json:"evaluation" yaml:"evaluation"` Compute *TFCompute `json:"compute" yaml:"compute"` + DatasetCompute *SparkCompute `json:"dataset_compute" yaml:"dataset_compute"` Tags Tags `json:"tags" yaml:"tags"` } @@ -127,6 +128,7 @@ var modelValidation = &cr.StructValidation{ StructValidation: modelEvaluationValidation, }, tfComputeFieldValidation, + sparkComputeFieldValidation("DatasetCompute"), tagsFieldValidation, typeFieldValidation, }, diff --git a/pkg/operator/api/userconfig/raw_columns.go b/pkg/operator/api/userconfig/raw_columns.go index 79ca451718..0123589b88 100644 --- a/pkg/operator/api/userconfig/raw_columns.go +++ b/pkg/operator/api/userconfig/raw_columns.go @@ -96,7 +96,7 @@ var rawIntColumnFieldValidations = []*cr.StructFieldValidation{ AllowNull: true, }, }, - sparkComputeFieldValidation, + sparkComputeFieldValidation("Compute"), tagsFieldValidation, typeFieldValidation, } @@ -145,7 +145,7 @@ var rawFloatColumnFieldValidations = []*cr.StructFieldValidation{ AllowNull: true, }, }, - sparkComputeFieldValidation, + sparkComputeFieldValidation("Compute"), 
tagsFieldValidation, typeFieldValidation, } @@ -182,7 +182,7 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{ AllowNull: true, }, }, - sparkComputeFieldValidation, + sparkComputeFieldValidation("Compute"), tagsFieldValidation, typeFieldValidation, } diff --git a/pkg/operator/api/userconfig/transformed_columns.go b/pkg/operator/api/userconfig/transformed_columns.go index bdaf00ebab..23fd274132 100644 --- a/pkg/operator/api/userconfig/transformed_columns.go +++ b/pkg/operator/api/userconfig/transformed_columns.go @@ -51,7 +51,7 @@ var transformedColumnValidation = &cr.StructValidation{ }, }, inputValuesFieldValidation, - sparkComputeFieldValidation, + sparkComputeFieldValidation("Compute"), tagsFieldValidation, typeFieldValidation, }, diff --git a/pkg/operator/workloads/data_job.go b/pkg/operator/workloads/data_job.go index ccae3e25f9..7675512ddd 100644 --- a/pkg/operator/workloads/data_job.go +++ b/pkg/operator/workloads/data_job.go @@ -257,6 +257,7 @@ func dataWorkloadSpecs(ctx *context.Context) ([]*WorkloadSpec, error) { allComputes = append(allComputes, transformedColumn.Compute) } } + allComputes = append(allComputes, model.DatasetCompute) } resourceIDSet := strset.Union(rawColumnIDs, aggregateIDs, transformedColumnIDs, trainingDatasetIDs)