diff --git a/docs/applications/resources/environments.md b/docs/applications/resources/environments.md
index af7cc5d9f1..8185e50f80 100644
--- a/docs/applications/resources/environments.md
+++ b/docs/applications/resources/environments.md
@@ -7,6 +7,12 @@ Transfer data at scale from data warehouses like S3 into the Cortex environment.
 ```yaml
 - kind: environment  # (required)
   name:  # environment name (required)
+  limit:
+    # specify `num_rows` or `fraction_of_rows` if using `limit`
+    num_rows:  # maximum number of rows to select from the dataset
+    fraction_of_rows:  # fraction of rows to select from the dataset
+    randomize:  # flag to indicate random selection of data (exact dataset size will not be guaranteed when this flag is true)
+    random_seed:  # seed value for randomizing
   log_level:
     tensorflow:  # TensorFlow log level (DEBUG, INFO, WARN, ERROR, or FATAL) (default: INFO)
     spark:  # Spark log level (ALL, TRACE, DEBUG, INFO, WARN, ERROR, or FATAL) (default: WARN)
diff --git a/pkg/api/userconfig/config_key.go b/pkg/api/userconfig/config_key.go
index 4fa2969065..14524fa794 100644
--- a/pkg/api/userconfig/config_key.go
+++ b/pkg/api/userconfig/config_key.go
@@ -36,6 +36,13 @@ const (
     PathKey  = "path"
     ValueKey = "value"
 
+    // environment
+    LimitKey          = "limit"
+    NumRowsKey        = "num_rows"
+    FractionOfRowsKey = "fraction_of_rows"
+    RandomizeKey      = "randomize"
+    RandomSeedKey     = "random_seed"
+
     // model
     NumEpochsKey = "num_epochs"
     NumStepsKey  = "num_steps"
diff --git a/pkg/api/userconfig/environments.go b/pkg/api/userconfig/environments.go
index d3132090d5..a5f7040814 100644
--- a/pkg/api/userconfig/environments.go
+++ b/pkg/api/userconfig/environments.go
@@ -29,6 +29,7 @@ type Environments []*Environment
 type Environment struct {
     Name     string    `json:"name" yaml:"name"`
     LogLevel *LogLevel `json:"log_level" yaml:"log_level"`
+    Limit    *Limit    `json:"limit" yaml:"limit"`
     Data     Data      `json:"-" yaml:"-"`
     FilePath string    `json:"file_path" yaml:"-"`
 }
@@ -46,6 +47,10 @@ var environmentValidation = &cr.StructValidation{
             StructField:      "LogLevel",
             StructValidation: logLevelValidation,
         },
+        &cr.StructFieldValidation{
+            StructField:      "Limit",
+            StructValidation: limitValidation,
+        },
         &cr.StructFieldValidation{
             StructField: "Data",
             Key:         "data",
@@ -55,6 +60,39 @@ var environmentValidation = &cr.StructValidation{
     },
 }
 
+type Limit struct {
+    NumRows        *int64   `json:"num_rows" yaml:"num_rows"`
+    FractionOfRows *float32 `json:"fraction_of_rows" yaml:"fraction_of_rows"`
+    Randomize      *bool    `json:"randomize" yaml:"randomize"`
+    RandomSeed     *int64   `json:"random_seed" yaml:"random_seed"`
+}
+
+var limitValidation = &cr.StructValidation{
+    StructFieldValidations: []*cr.StructFieldValidation{
+        &cr.StructFieldValidation{
+            StructField: "NumRows",
+            Int64PtrValidation: &cr.Int64PtrValidation{
+                GreaterThan: util.Int64Ptr(0),
+            },
+        },
+        &cr.StructFieldValidation{
+            StructField: "FractionOfRows",
+            Float32PtrValidation: &cr.Float32PtrValidation{
+                GreaterThan: util.Float32Ptr(0),
+                LessThan:    util.Float32Ptr(1),
+            },
+        },
+        &cr.StructFieldValidation{
+            StructField:       "Randomize",
+            BoolPtrValidation: &cr.BoolPtrValidation{},
+        },
+        &cr.StructFieldValidation{
+            StructField:        "RandomSeed",
+            Int64PtrValidation: &cr.Int64PtrValidation{},
+        },
+    },
+}
+
 type LogLevel struct {
     Tensorflow string `json:"tensorflow" yaml:"tensorflow"`
     Spark      string `json:"spark" yaml:"spark"`
@@ -304,6 +342,18 @@ func (env *Environment) Validate() error {
         return errors.Wrap(err, Identify(env))
     }
 
+    if env.Limit != nil {
+        if env.Limit.NumRows != nil && env.Limit.FractionOfRows != nil {
+            return errors.Wrap(ErrorSpecifyOnlyOne(NumRowsKey, FractionOfRowsKey), Identify(env), LimitKey)
+        }
+        if env.Limit.Randomize != nil && env.Limit.NumRows == nil && env.Limit.FractionOfRows == nil {
+            return errors.Wrap(ErrorOneOfPrerequisitesNotDefined(RandomizeKey, NumRowsKey, FractionOfRowsKey), Identify(env))
+        }
+        if env.Limit.RandomSeed != nil && env.Limit.Randomize == nil {
+            return errors.Wrap(ErrorOneOfPrerequisitesNotDefined(RandomSeedKey, RandomizeKey), Identify(env))
+        }
+    }
+
     dups := util.FindDuplicateStrs(env.Data.GetIngestedColumns())
     if len(dups) > 0 {
         return errors.New(Identify(env), DataKey, SchemaKey, "column name", s.ErrDuplicatedValue(dups[0]))
diff --git a/pkg/api/userconfig/errors.go b/pkg/api/userconfig/errors.go
index c6f359409a..3b1e382601 100644
--- a/pkg/api/userconfig/errors.go
+++ b/pkg/api/userconfig/errors.go
@@ -44,6 +44,7 @@ const (
     ErrColumnMustBeRaw
     ErrSpecifyAllOrNone
     ErrSpecifyOnlyOne
+    ErrOneOfPrerequisitesNotDefined
     ErrTemplateExtraArg
     ErrTemplateMissingArg
     ErrInvalidColumnInputType
@@ -73,6 +74,7 @@ var errorKinds = []string{
     "err_column_must_be_raw",
     "err_specify_all_or_none",
     "err_specify_only_one",
+    "err_one_of_prerequisites_not_defined",
     "err_template_extra_arg",
     "err_template_missing_arg",
     "err_invalid_column_input_type",
@@ -250,6 +252,15 @@ func ErrorSpecifyOnlyOne(vals ...string) error {
     }
 }
 
+func ErrorOneOfPrerequisitesNotDefined(argName string, prerequisites ...string) error {
+    message := fmt.Sprintf("%s specified without specifying %s", s.UserStr(argName), s.UserStrsOr(prerequisites))
+
+    return ConfigError{
+        Kind:    ErrOneOfPrerequisitesNotDefined,
+        message: message,
+    }
+}
+
 func ErrorTemplateExtraArg(template *Template, argName string) error {
     return ConfigError{
         Kind:    ErrTemplateExtraArg,
diff --git a/pkg/operator/context/environment.go b/pkg/operator/context/environment.go
index 311ce0a2a2..c130dd676a 100644
--- a/pkg/operator/context/environment.go
+++ b/pkg/operator/context/environment.go
@@ -40,6 +40,7 @@ func dataID(config *userconfig.Config, datasetVersion string) string {
     for _, rawColumnConfig := range config.RawColumns {
         rawColumnTypeMap[rawColumnConfig.GetName()] = rawColumnConfig.GetType()
     }
+    buf.WriteString(s.Obj(config.Environment.Limit))
     buf.WriteString(s.Obj(rawColumnTypeMap))
 
     data := config.Environment.Data
diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py
index 2df38d35ed..c967b418ac 100644
--- a/pkg/workloads/spark_job/spark_job.py
+++ b/pkg/workloads/spark_job/spark_job.py
@@ -108,6 +108,25 @@ def validate_dataset(ctx, raw_df, cols_to_validate):
         raise UserException("raw column validations failed")
 
 
+def limit_dataset(full_dataset_size, ingest_df, limit_config):
+    max_rows = full_dataset_size
+    if limit_config.get("num_rows") is not None:
+        max_rows = min(limit_config["num_rows"], full_dataset_size)
+        fraction = float(max_rows) / full_dataset_size
+    elif limit_config.get("fraction_of_rows") is not None:
+        fraction = limit_config["fraction_of_rows"]
+        max_rows = round(full_dataset_size * fraction)
+    if max_rows == full_dataset_size:
+        return ingest_df
+    if limit_config["randomize"]:
+        fraction = min(
+            fraction * 1.1, 1.0  # increase the odds of getting the desired target row count
+        )
+        ingest_df = ingest_df.sample(fraction=fraction, seed=limit_config["random_seed"])
+    logger.info("Selecting a subset of data of at most {} rows".format(max_rows))
+    return ingest_df.limit(max_rows)
+
+
 def write_raw_dataset(df, ctx, spark):
     logger.info("Caching {} data (version: {})".format(ctx.app["name"], ctx.dataset_version))
     acc, df = spark_util.accumulate_count(df, spark)
@@ -115,20 +134,6 @@ def write_raw_dataset(df, ctx, spark):
     return acc.value
 
 
-def drop_null_and_write(ingest_df, ctx, spark):
-    full_dataset_size = ingest_df.count()
-    logger.info("Dropping any rows that contain null values")
-    ingest_df = ingest_df.dropna()
-    written_count = write_raw_dataset(ingest_df, ctx, spark)
-    metadata = {"dataset_size": written_count}
-    aws.upload_json_to_s3(metadata, ctx.raw_dataset["metadata_key"], ctx.bucket)
-    logger.info(
-        "{} rows read, {} rows dropped, {} rows ingested".format(
-            full_dataset_size, full_dataset_size - written_count, written_count
-        )
-    )
-
-
 def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest):
     if should_ingest:
         cols_to_validate = list(ctx.rf_id_map.keys())
@@ -141,18 +146,31 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest):
     ctx.upload_resource_status_start(*col_resources_to_validate)
     try:
         if should_ingest:
+            data_config = ctx.environment["data"]
+
             logger.info("Ingesting")
-            logger.info(
-                "Ingesting {} data from {}".format(ctx.app["name"], ctx.environment["data"]["path"])
-            )
+            logger.info("Ingesting {} data from {}".format(ctx.app["name"], data_config["path"]))
             ingest_df = spark_util.ingest(ctx, spark)
 
-            if ctx.environment["data"].get("drop_null"):
-                drop_null_and_write(ingest_df, ctx, spark)
+            full_dataset_size = ingest_df.count()
+
+            if data_config.get("drop_null"):
+                logger.info("Dropping any rows that contain null values")
+                ingest_df = ingest_df.dropna()
+
+            if ctx.environment.get("limit"):
+                ingest_df = limit_dataset(full_dataset_size, ingest_df, ctx.environment["limit"])
+
+            written_count = write_raw_dataset(ingest_df, ctx, spark)
+            metadata = {"dataset_size": written_count}
+            aws.upload_json_to_s3(metadata, ctx.raw_dataset["metadata_key"], ctx.bucket)
+            if written_count != full_dataset_size:
+                logger.info(
+                    "{} rows read, {} rows dropped, {} rows ingested".format(
+                        full_dataset_size, full_dataset_size - written_count, written_count
+                    )
+                )
             else:
-                written_count = write_raw_dataset(ingest_df, ctx, spark)
-                metadata = {"dataset_size": written_count}
-                aws.upload_json_to_s3(metadata, ctx.raw_dataset["metadata_key"], ctx.bucket)
                 logger.info("{} rows ingested".format(written_count))
 
         logger.info("Reading {} data (version: {})".format(ctx.app["name"], ctx.dataset_version))
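
For reference, a minimal sketch of how the new `limit` block could be used in an environment config, based on the keys documented in `environments.md` above; the environment name and the values shown are illustrative, not taken from this change:

```yaml
- kind: environment
  name: dev  # illustrative name
  limit:
    num_rows: 100000  # cap the ingested dataset at 100,000 rows (set either num_rows or fraction_of_rows, not both)
    randomize: true   # sample rows randomly; the exact row count is not guaranteed
    random_seed: 42   # fixed seed so the random sample is reproducible across runs
  # log_level and data are configured as before (omitted here)
```

With these values, `limit_dataset` computes `fraction = num_rows / full_dataset_size`, inflates it by 10% (capped at 1.0) before calling `sample()` to improve the odds of hitting the target row count, and then truncates the result with `limit(num_rows)`. Specifying both `num_rows` and `fraction_of_rows` is rejected by the new validation in `environments.go`, and because `Limit` is now written into the `dataID` hash, changing it should produce a new raw dataset rather than reusing the cached one.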