diff --git a/examples/movie-ratings/resources/environments.yaml b/examples/movie-ratings/resources/environments.yaml index a30f3baef9..a8604c6525 100644 --- a/examples/movie-ratings/resources/environments.yaml +++ b/examples/movie-ratings/resources/environments.yaml @@ -18,7 +18,3 @@ - kind: raw_column name: rating type: FLOAT_COLUMN - -- kind: raw_column - name: timestamp - type: INT_COLUMN diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index a88b7d7bfc..c6c75c4d38 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -100,39 +100,6 @@ def write_training_data(model_name, df, ctx, spark): return df -def expected_schema_from_context(ctx): - data_config = ctx.environment["data"] - - if data_config["type"] == "csv": - expected_field_names = data_config["schema"] - else: - expected_field_names = [f["raw_column_name"] for f in data_config["schema"]] - - schema_fields = [ - StructField( - name=fname, - dataType=CORTEX_TYPE_TO_SPARK_TYPE[ctx.columns[fname]["type"]], - nullable=not ctx.columns[fname].get("required", False), - ) - for fname in expected_field_names - ] - return StructType(schema_fields) - - -def compare_column_schemas(expected_schema, actual_schema): - # Nullables are being left out because when Spark is reading CSV files, it is setting nullable to true - # regardless of if the column has all values or not. The null checks will be done elsewhere. - # This compares only the schemas - expected_sorted_fields = sorted( - [(f.name, f.dataType) for f in expected_schema], key=lambda f: f[0] - ) - - # Sorted for determinism when testing - actual_sorted_fields = sorted([(f.name, f.dataType) for f in actual_schema], key=lambda f: f[0]) - - return expected_sorted_fields == actual_sorted_fields - - def min_check(input_col, min): return input_col >= min, input_col < min @@ -220,27 +187,48 @@ def value_check_data(ctx, df, raw_columns=None): def ingest(ctx, spark): - expected_schema = expected_schema_from_context(ctx) - if ctx.environment["data"]["type"] == "csv": df = read_csv(ctx, spark) elif ctx.environment["data"]["type"] == "parquet": df = read_parquet(ctx, spark) - if compare_column_schemas(expected_schema, df.schema) is not True: - logger.error("expected schema:") - log_df_schema(spark.createDataFrame([], expected_schema), logger.error) - logger.error("found schema:") - log_df_schema(df, logger.error) + input_type_map = {f.name: f.dataType for f in df.schema} + + for raw_column_name in ctx.raw_columns.keys(): + raw_column = ctx.raw_columns[raw_column_name] + expected_types = CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[raw_column["type"]] + actual_type = input_type_map[raw_column_name] + if actual_type not in expected_types: + logger.error("found schema:") + log_df_schema(df, logger.error) + + raise UserException( + "raw column " + raw_column_name, + "type mismatch", + "expected {} but found {}".format( + " or ".join(str(x) for x in expected_types), actual_type + ), + ) + target_type = CORTEX_TYPE_TO_SPARK_TYPE[raw_column["type"]] - raise UserException("raw data schema mismatch") + if target_type != actual_type: + df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(target_type)) - return df + return df.select(*sorted(df.columns)) def read_csv(ctx, spark): data_config = ctx.environment["data"] - schema = expected_schema_from_context(ctx) + expected_field_names = data_config["schema"] + + schema_fields = [] + for field_name in expected_field_names: + if field_name in ctx.raw_columns: + spark_type = 
CORTEX_TYPE_TO_SPARK_TYPE[ctx.raw_columns[field_name]["type"]]
+        else:
+            spark_type = StringType()
+
+        schema_fields.append(StructField(name=field_name, dataType=spark_type))

     csv_config = {
         util.snake_to_camel(param_name): val
@@ -248,22 +236,29 @@
         if val is not None
     }

-    return spark.read.csv(data_config["path"], schema=schema, mode="FAILFAST", **csv_config)
+    df = spark.read.csv(
+        data_config["path"], schema=StructType(schema_fields), mode="FAILFAST", **csv_config
+    )
+    return df.select(*ctx.raw_columns.keys())


 def read_parquet(ctx, spark):
     parquet_config = ctx.environment["data"]
     df = spark.read.parquet(parquet_config["path"])

-    parquet_columns = [c["parquet_column_name"] for c in parquet_config["schema"]]
-    missing_cols = util.subtract_lists(parquet_columns, df.columns)
+    alias_map = {
+        c["parquet_column_name"]: c["raw_column_name"]
+        for c in parquet_config["schema"]
+        if c["parquet_column_name"] in ctx.raw_columns
+    }
+
+    missing_cols = set(alias_map.keys()) - set(df.columns)
     if len(missing_cols) > 0:
-        raise UserException("parquet dataset", "missing columns: " + str(missing_cols))
+        logger.error("found schema:")
+        log_df_schema(df, logger.error)
+        raise UserException("missing column(s) in input dataset", str(missing_cols))

-    selectExprs = [
-        "{} as {}".format(c["parquet_column_name"], c["raw_column_name"])
-        for c in parquet_config["schema"]
-    ]
+    selectExprs = ["{} as {}".format(alias, alias_map[alias]) for alias in alias_map.keys()]

     return df.selectExpr(*selectExprs)

diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py
index 8eaecc4e7f..5fe7840e85 100644
--- a/pkg/workloads/spark_job/test/unit/spark_util_test.py
+++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py
@@ -27,127 +27,55 @@
 pytestmark = pytest.mark.usefixtures("spark")


-def test_compare_column_schemas():
-    expected = StructType(
-        [
-            StructField("a_float", FloatType()),
-            StructField("b_long", LongType()),
-            StructField("c_str", StringType()),
-        ]
-    )
-
-    missing_col = StructType(
-        [StructField("a_float", FloatType()), StructField("b_long", LongType())]
-    )
-
-    assert spark_util.compare_column_schemas(expected, missing_col) == False
-
-    incorrect_type = StructType(
-        [
-            StructField("b_long", LongType()),
-            StructField("a_float", FloatType()),
-            StructField("c_str", LongType()),
-        ]
-    )
-
-    assert spark_util.compare_column_schemas(expected, incorrect_type) == False
-
-    actual = StructType(
-        [
-            StructField("b_long", LongType()),
-            StructField("a_float", FloatType()),
-            StructField("c_str", StringType()),
-        ]
-    )
-
-    assert spark_util.compare_column_schemas(expected, actual) == True
+def test_read_csv_valid(spark, write_csv_file, ctx_obj, get_context):
+    csv_str = "\n".join(["a,0.1,", "b,1,1", "c,1.1,4"])

+    path_to_file = write_csv_file(csv_str)

-def test_get_expected_schema_from_context_csv(ctx_obj, get_context):
     ctx_obj["environment"] = {
-        "data": {"type": "csv", "schema": ["income", "years_employed", "prior_default"]}
+        "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]}
     }
+
     ctx_obj["raw_columns"] = {
-        "income": {"name": "income", "type": "FLOAT_COLUMN", "required": True, "id": "-"},
-        "years_employed": {
-            "name": "years_employed",
-            "type": "INT_COLUMN",
-            "required": False,
-            "id": "-",
-        },
-        "prior_default": {
-            "name": "prior_default",
-            "type": "STRING_COLUMN",
-            "required": True,
-            "id": "-",
-        },
+        "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, 
"id": "-"}, + "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, + "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "-"}, } - ctx = get_context(ctx_obj) - - expected_output = StructType( - [ - StructField("years_employed", LongType(), True), - StructField("income", FloatType(), False), - StructField("prior_default", StringType(), False), - ] - ) - - actual = spark_util.expected_schema_from_context(ctx) - assert spark_util.compare_column_schemas(actual, expected_output) == True - + assert spark_util.read_csv(get_context(ctx_obj), spark).count() == 3 -def test_get_expected_schema_from_context_parquet(ctx_obj, get_context): ctx_obj["environment"] = { "data": { - "type": "parquet", - "schema": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, - ], + "type": "csv", + "path": path_to_file, + "schema": ["a_str", "b_float", "c_long", "d_long"], } } - ctx_obj["raw_columns"] = { - "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, - "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "-"}, - "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, - } - - ctx = get_context(ctx_obj) - - expected_output = StructType( - [ - StructField("c_long", LongType(), True), - StructField("b_float", FloatType(), False), - StructField("a_str", StringType(), False), - ] - ) - actual = spark_util.expected_schema_from_context(ctx) - assert spark_util.compare_column_schemas(actual, expected_output) == True + assert spark_util.read_csv(get_context(ctx_obj), spark).count() == 3 -def test_read_csv_valid(spark, write_csv_file, ctx_obj, get_context): +def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): csv_str = "\n".join(["a,0.1,", "b,1,1", "c,1.1,4"]) path_to_file = write_csv_file(csv_str) ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]} + "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_long", "c_long"]} } ctx_obj["raw_columns"] = { "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, - "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, + "b_long": {"name": "b_long", "type": "INT_COLUMN", "required": True, "id": "-"}, "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "-"}, } - assert spark_util.read_csv(get_context(ctx_obj), spark).count() == 3 + with pytest.raises(Py4JJavaError): + spark_util.ingest(get_context(ctx_obj), spark).collect() -def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): - csv_str = "\n".join(["a,0.1,", "b,1,1", "c,1.1,4"]) +def test_read_csv_missing_column(spark, write_csv_file, ctx_obj, get_context): + csv_str = "\n".join(["a,0.1,", "b,1,1"]) path_to_file = write_csv_file(csv_str) @@ -161,7 +89,7 @@ def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "-"}, } - with pytest.raises(Py4JJavaError): + with pytest.raises(Py4JJavaError) as exec_info: spark_util.ingest(get_context(ctx_obj), spark).collect() @@ -200,7 +128,8 @@ def test_read_csv_valid_options(spark, write_csv_file, ctx_obj, get_context): "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "-"}, } - 
actual_results = spark_util.read_csv(get_context(ctx_obj), spark).collect() + result_df = spark_util.read_csv(get_context(ctx_obj), spark) + actual_results = result_df.select(*sorted(result_df.columns)).collect() assert len(actual_results) == 3 assert actual_results[0] == Row(a_str=" a ", b_float=float(1), c_long=None) @@ -380,6 +309,48 @@ def test_value_check_data_invalid_out_of_range(spark, ctx_obj, get_context): def test_ingest_parquet_valid(spark, write_parquet_file, ctx_obj, get_context): data = [("a", 0.1, None), ("b", 1.0, None), ("c", 1.1, 4)] + schema = StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", DoubleType()), + StructField("c_long", IntegerType()), + ] + ) + + path_to_file = write_parquet_file(spark, data, schema) + + ctx_obj["environment"] = { + "data": { + "type": "parquet", + "path": path_to_file, + "schema": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + ], + } + } + + ctx_obj["raw_columns"] = { + "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "2"}, + "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "3"}, + } + + df = spark_util.ingest(get_context(ctx_obj), spark) + + assert df.count() == 3 + + assert sorted([(s.name, s.dataType) for s in df.schema], key=lambda x: x[0]) == [ + ("a_str", StringType()), + ("b_float", FloatType()), + ("c_long", LongType()), + ] + + +def test_ingest_parquet_extra_cols(spark, write_parquet_file, ctx_obj, get_context): + data = [("a", 0.1, None), ("b", 1.0, None), ("c", 1.1, 4)] + schema = StructType( [ StructField("a_str", StringType()), @@ -398,6 +369,7 @@ def test_ingest_parquet_valid(spark, write_parquet_file, ctx_obj, get_context): {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "d_long", "raw_column_name": "d_long"}, ], } } @@ -411,6 +383,42 @@ def test_ingest_parquet_valid(spark, write_parquet_file, ctx_obj, get_context): assert spark_util.ingest(get_context(ctx_obj), spark).count() == 3 +def test_ingest_parquet_missing_cols(spark, write_parquet_file, ctx_obj, get_context): + data = [("a", 0.1, None), ("b", 1.0, None), ("c", 1.1, 4)] + + schema = StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", FloatType()), + StructField("d_long", LongType()), + ] + ) + + path_to_file = write_parquet_file(spark, data, schema) + + ctx_obj["environment"] = { + "data": { + "type": "parquet", + "path": path_to_file, + "schema": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + ], + } + } + + ctx_obj["raw_columns"] = { + "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "2"}, + "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "3"}, + } + + with pytest.raises(UserException) as exec_info: + spark_util.ingest(get_context(ctx_obj), spark).collect() + assert "c_long" in str(exec_info) and "missing column" in str(exec_info) + + def 
test_ingest_parquet_type_mismatch(spark, write_parquet_file, ctx_obj, get_context): data = [("a", 0.1, None), ("b", 1.0, None), ("c", 1.1, 4.0)] @@ -442,8 +450,9 @@ def test_ingest_parquet_type_mismatch(spark, write_parquet_file, ctx_obj, get_co "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "3"}, } - with pytest.raises(UserException): + with pytest.raises(UserException) as exec_info: spark_util.ingest(get_context(ctx_obj), spark).collect() + assert "c_long" in str(exec_info) and "type mismatch" in str(exec_info) def test_ingest_parquet_failed_requirements( @@ -562,7 +571,8 @@ def test_run_builtin_aggregators_error(spark, ctx_obj, get_context): data = [Row(a=None), Row(a=1), Row(a=2), Row(a=3)] df = spark.createDataFrame(data, StructType([StructField("a", LongType())])) - with pytest.raises(Exception): + + with pytest.raises(Exception) as exec_info: spark_util.run_builtin_aggregators(aggregate_list, df, ctx, spark) ctx.store_aggregate_result.assert_not_called()