Add ability to sample dataset #12

Merged
5 commits merged on Feb 19, 2019
6 changes: 6 additions & 0 deletions docs/applications/resources/environments.md
@@ -7,6 +7,12 @@ Transfer data at scale from data warehouses like S3 into the Cortex environment.
```yaml
- kind: environment  # (required)
  name: <string>  # environment name (required)
  limit:
    # specify either `num_rows` or `fraction_of_rows` (not both) when using `limit`
    num_rows: <int>  # maximum number of rows to select from the dataset
    fraction_of_rows: <float>  # fraction of rows to select from the dataset
    randomize: <bool>  # whether to select rows at random (the exact number of selected rows is not guaranteed when true)
    random_seed: <int>  # seed value used when `randomize` is true
  log_level:
    tensorflow: <string>  # TensorFlow log level (DEBUG, INFO, WARN, ERROR, or FATAL) (default: INFO)
    spark: <string>  # Spark log level (ALL, TRACE, DEBUG, INFO, WARN, ERROR, or FATAL) (default: WARN)
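For context, the Spark job in this PR reads the `limit` block as a parsed dictionary (`ctx.environment["limit"]` in `spark_job.py`). As a hypothetical illustration, a setting that randomly samples roughly 10% of the dataset would arrive downstream looking something like this (values are made up for the example):

```python
# Hypothetical parsed form of the `limit` block above, as consumed by the
# Spark job via ctx.environment["limit"] (illustrative values only).
limit_config = {
    "num_rows": None,         # unset; mutually exclusive with fraction_of_rows
    "fraction_of_rows": 0.1,  # sample roughly 10% of the rows
    "randomize": True,        # choose rows pseudo-randomly
    "random_seed": 42,        # only meaningful when randomize is true
}
```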
7 changes: 7 additions & 0 deletions pkg/api/userconfig/config_key.go
@@ -36,6 +36,13 @@ const (
	PathKey = "path"
	ValueKey = "value"

	// environment
	LimitKey          = "limit"
	NumRowsKey        = "num_rows"
	FractionOfRowsKey = "fraction_of_rows"
	RandomizeKey      = "randomize"
	RandomSeedKey     = "random_seed"

	// model
	NumEpochsKey = "num_epochs"
	NumStepsKey = "num_steps"
50 changes: 50 additions & 0 deletions pkg/api/userconfig/environments.go
@@ -29,6 +29,7 @@ type Environments []*Environment
type Environment struct {
	Name     string    `json:"name" yaml:"name"`
	LogLevel *LogLevel `json:"log_level" yaml:"log_level"`
	Limit    *Limit    `json:"limit" yaml:"limit"`
	Data     Data      `json:"-" yaml:"-"`
	FilePath string    `json:"file_path" yaml:"-"`
}
@@ -46,6 +47,10 @@ var environmentValidation = &cr.StructValidation{
StructField: "LogLevel",
StructValidation: logLevelValidation,
},
&cr.StructFieldValidation{
StructField: "Limit",
StructValidation: limitValidation,
},
&cr.StructFieldValidation{
StructField: "Data",
Key: "data",
@@ -55,6 +60,39 @@ var environmentValidation = &cr.StructValidation{
	},
}

type Limit struct {
	NumRows        *int64   `json:"num_rows" yaml:"num_rows"`
	FractionOfRows *float32 `json:"fraction_of_rows" yaml:"fraction_of_rows"`
	Randomize      *bool    `json:"randomize" yaml:"randomize"`
	RandomSeed     *int64   `json:"random_seed" yaml:"random_seed"`
}

var limitValidation = &cr.StructValidation{
	StructFieldValidations: []*cr.StructFieldValidation{
		&cr.StructFieldValidation{
			StructField: "NumRows",
			Int64PtrValidation: &cr.Int64PtrValidation{
				GreaterThan: util.Int64Ptr(0),
			},
		},
		&cr.StructFieldValidation{
			StructField: "FractionOfRows",
			Float32PtrValidation: &cr.Float32PtrValidation{
				GreaterThan: util.Float32Ptr(0),
				LessThan:    util.Float32Ptr(1),
			},
		},
		&cr.StructFieldValidation{
			StructField:       "Randomize",
			BoolPtrValidation: &cr.BoolPtrValidation{},
		},
		&cr.StructFieldValidation{
			StructField:        "RandomSeed",
			Int64PtrValidation: &cr.Int64PtrValidation{},
		},
	},
}

type LogLevel struct {
	Tensorflow string `json:"tensorflow" yaml:"tensorflow"`
	Spark      string `json:"spark" yaml:"spark"`
@@ -304,6 +342,18 @@ func (env *Environment) Validate() error {
		return errors.Wrap(err, Identify(env))
	}

	if env.Limit != nil {
		if env.Limit.NumRows != nil && env.Limit.FractionOfRows != nil {
			return errors.Wrap(ErrorSpecifyOnlyOne(NumRowsKey, FractionOfRowsKey), Identify(env), LimitKey)
		}
		if env.Limit.Randomize != nil && env.Limit.NumRows == nil && env.Limit.FractionOfRows == nil {
			return errors.Wrap(ErrorOneOfPrerequisitesNotDefined(RandomizeKey, NumRowsKey, FractionOfRowsKey), Identify(env))
		}
		if env.Limit.RandomSeed != nil && env.Limit.Randomize == nil {
			return errors.Wrap(ErrorOneOfPrerequisitesNotDefined(RandomSeedKey, RandomizeKey), Identify(env))
		}
	}

	dups := util.FindDuplicateStrs(env.Data.GetIngestedColumns())
	if len(dups) > 0 {
		return errors.New(Identify(env), DataKey, SchemaKey, "column name", s.ErrDuplicatedValue(dups[0]))
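To make the cross-field rules above concrete, here is a sketch of how a few hypothetical `limit` settings (written in the dict form the Spark job consumes) would fare against this validation:

```python
# Illustrative configurations only; the comments describe which check in
# Environment.Validate() above would accept or reject each one.
ok_fraction = {"fraction_of_rows": 0.1, "randomize": True, "random_seed": 42}  # valid
ok_num_rows = {"num_rows": 1000}                                               # valid
bad_both = {"num_rows": 1000, "fraction_of_rows": 0.1}  # rejected: specify only one of the two
bad_randomize = {"randomize": True}                     # rejected: needs num_rows or fraction_of_rows
bad_seed = {"num_rows": 1000, "random_seed": 42}        # rejected: random_seed requires randomize
```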
11 changes: 11 additions & 0 deletions pkg/api/userconfig/errors.go
@@ -44,6 +44,7 @@ const (
	ErrColumnMustBeRaw
	ErrSpecifyAllOrNone
	ErrSpecifyOnlyOne
	ErrOneOfPrerequisitesNotDefined
	ErrTemplateExtraArg
	ErrTemplateMissingArg
	ErrInvalidColumnInputType
@@ -73,6 +74,7 @@ var errorKinds = []string{
"err_column_must_be_raw",
"err_specify_all_or_none",
"err_specify_only_one",
"err_one_of_prerequisites_not_defined",
"err_template_extra_arg",
"err_template_missing_arg",
"err_invalid_column_input_type",
@@ -250,6 +252,15 @@ func ErrorSpecifyOnlyOne(vals ...string) error {
	}
}

func ErrorOneOfPrerequisitesNotDefined(argName string, prerequisites ...string) error {
	message := fmt.Sprintf("%s specified without specifying %s", s.UserStr(argName), s.UserStrsOr(prerequisites))

	return ConfigError{
		Kind:    ErrOneOfPrerequisitesNotDefined,
		message: message,
	}
}

func ErrorTemplateExtraArg(template *Template, argName string) error {
	return ConfigError{
		Kind: ErrTemplateExtraArg,
1 change: 1 addition & 0 deletions pkg/operator/context/environment.go
@@ -40,6 +40,7 @@ func dataID(config *userconfig.Config, datasetVersion string) string {
	for _, rawColumnConfig := range config.RawColumns {
		rawColumnTypeMap[rawColumnConfig.GetName()] = rawColumnConfig.GetType()
	}
	buf.WriteString(s.Obj(config.Environment.Limit))
	buf.WriteString(s.Obj(rawColumnTypeMap))

	data := config.Environment.Data
62 changes: 40 additions & 22 deletions pkg/workloads/spark_job/spark_job.py
@@ -108,27 +108,32 @@ def validate_dataset(ctx, raw_df, cols_to_validate):
        raise UserException("raw column validations failed")


def limit_dataset(full_dataset_size, ingest_df, limit_config):
    max_rows = full_dataset_size
    if limit_config.get("num_rows") is not None:
        max_rows = min(limit_config["num_rows"], full_dataset_size)
        fraction = float(max_rows) / full_dataset_size
    elif limit_config.get("fraction_of_rows") is not None:
        fraction = limit_config["fraction_of_rows"]
        max_rows = round(full_dataset_size * fraction)
    if max_rows == full_dataset_size:
        return ingest_df
    if limit_config["randomize"]:
        fraction = min(
            fraction * 1.1, 1.0  # increase the odds of getting the desired target row count
        )
        ingest_df = ingest_df.sample(fraction=fraction, seed=limit_config["random_seed"])
    logger.info("Selecting a subset of data of at most {} rows".format(max_rows))
    return ingest_df.limit(max_rows)
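As a rough illustration of what `limit_dataset` does at runtime, here is a self-contained sketch (assuming a local `pyspark` installation; the DataFrame, config values, and variable names are hypothetical and not part of this PR):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("limit-sketch").getOrCreate()

df = spark.range(1000)          # stand-in for the ingested dataset
full_dataset_size = df.count()  # 1000

limit_config = {"num_rows": 100, "randomize": True, "random_seed": 42}

# Mirror limit_dataset(): derive a sampling fraction from num_rows, pad it by 10%
# so the random sample is unlikely to come up short, then cap the result with limit().
max_rows = min(limit_config["num_rows"], full_dataset_size)
fraction = min(max_rows / full_dataset_size * 1.1, 1.0)
sampled = df.sample(fraction=fraction, seed=limit_config["random_seed"]).limit(max_rows)

print(sampled.count())  # at most 100 pseudo-randomly selected rows

spark.stop()
```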


def write_raw_dataset(df, ctx, spark):
    logger.info("Caching {} data (version: {})".format(ctx.app["name"], ctx.dataset_version))
    acc, df = spark_util.accumulate_count(df, spark)
    df.write.mode("overwrite").parquet(aws.s3a_path(ctx.bucket, ctx.raw_dataset["key"]))
    return acc.value


def drop_null_and_write(ingest_df, ctx, spark):
    full_dataset_size = ingest_df.count()
    logger.info("Dropping any rows that contain null values")
    ingest_df = ingest_df.dropna()
    written_count = write_raw_dataset(ingest_df, ctx, spark)
    metadata = {"dataset_size": written_count}
    aws.upload_json_to_s3(metadata, ctx.raw_dataset["metadata_key"], ctx.bucket)
    logger.info(
        "{} rows read, {} rows dropped, {} rows ingested".format(
            full_dataset_size, full_dataset_size - written_count, written_count
        )
    )


def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest):
    if should_ingest:
        cols_to_validate = list(ctx.rf_id_map.keys())
@@ -141,18 +146,31 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest):
    ctx.upload_resource_status_start(*col_resources_to_validate)
    try:
        if should_ingest:
            data_config = ctx.environment["data"]

            logger.info("Ingesting")
            logger.info(
                "Ingesting {} data from {}".format(ctx.app["name"], ctx.environment["data"]["path"])
            )
            logger.info("Ingesting {} data from {}".format(ctx.app["name"], data_config["path"]))
            ingest_df = spark_util.ingest(ctx, spark)

            if ctx.environment["data"].get("drop_null"):
                drop_null_and_write(ingest_df, ctx, spark)
            full_dataset_size = ingest_df.count()

            if data_config.get("drop_null"):
                logger.info("Dropping any rows that contain null values")
                ingest_df = ingest_df.dropna()

            if ctx.environment.get("limit"):
                ingest_df = limit_dataset(full_dataset_size, ingest_df, ctx.environment["limit"])

            written_count = write_raw_dataset(ingest_df, ctx, spark)
            metadata = {"dataset_size": written_count}
            aws.upload_json_to_s3(metadata, ctx.raw_dataset["metadata_key"], ctx.bucket)
            if written_count != full_dataset_size:
                logger.info(
                    "{} rows read, {} rows dropped, {} rows ingested".format(
                        full_dataset_size, full_dataset_size - written_count, written_count
                    )
                )
            else:
                written_count = write_raw_dataset(ingest_df, ctx, spark)
                metadata = {"dataset_size": written_count}
                aws.upload_json_to_s3(metadata, ctx.raw_dataset["metadata_key"], ctx.bucket)
                logger.info("{} rows ingested".format(written_count))

        logger.info("Reading {} data (version: {})".format(ctx.app["name"], ctx.dataset_version))