Commit 2e857db

Allow users to ingest a subset of input columns (#92)
1 parent c0ea684 commit 2e857db

File tree

3 files changed: +152 -151 lines changed


examples/movie-ratings/resources/environments.yaml

Lines changed: 0 additions & 4 deletions

@@ -18,7 +18,3 @@
 - kind: raw_column
   name: rating
   type: FLOAT_COLUMN
-
-- kind: raw_column
-  name: timestamp
-  type: INT_COLUMN
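
The config change above is the user-facing half of the feature: the movie-ratings example keeps the timestamp field in its CSV data but no longer declares it as a raw_column, so only a subset of the input columns is ingested. Below is a minimal PySpark sketch of that behavior, assuming the example's column names (user_id, movie_id, rating, timestamp) and a local CSV path rather than the real example data; undeclared columns are read with a StringType placeholder and then dropped. This is an illustration, not code from the commit.

    # Sketch only: read every CSV column, typing only the declared raw columns,
    # then keep just the declared subset.
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # All columns present in the input CSV, in file order (assumed names).
    csv_field_names = ["user_id", "movie_id", "rating", "timestamp"]

    # Only the columns declared as raw_column resources; "timestamp" is omitted.
    declared = {"user_id": IntegerType(), "movie_id": IntegerType(), "rating": FloatType()}

    schema = StructType(
        [StructField(name, declared.get(name, StringType())) for name in csv_field_names]
    )

    df = spark.read.csv("movie_ratings.csv", schema=schema, mode="FAILFAST")
    df = df.select(*declared.keys())  # ingest only the declared subset
    df.printSchema()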

pkg/workloads/spark_job/spark_util.py

Lines changed: 46 additions & 51 deletions

@@ -100,39 +100,6 @@ def write_training_data(model_name, df, ctx, spark):
     return df
 
 
-def expected_schema_from_context(ctx):
-    data_config = ctx.environment["data"]
-
-    if data_config["type"] == "csv":
-        expected_field_names = data_config["schema"]
-    else:
-        expected_field_names = [f["raw_column_name"] for f in data_config["schema"]]
-
-    schema_fields = [
-        StructField(
-            name=fname,
-            dataType=CORTEX_TYPE_TO_SPARK_TYPE[ctx.columns[fname]["type"]],
-            nullable=not ctx.columns[fname].get("required", False),
-        )
-        for fname in expected_field_names
-    ]
-    return StructType(schema_fields)
-
-
-def compare_column_schemas(expected_schema, actual_schema):
-    # Nullables are being left out because when Spark is reading CSV files, it is setting nullable to true
-    # regardless of if the column has all values or not. The null checks will be done elsewhere.
-    # This compares only the schemas
-    expected_sorted_fields = sorted(
-        [(f.name, f.dataType) for f in expected_schema], key=lambda f: f[0]
-    )
-
-    # Sorted for determinism when testing
-    actual_sorted_fields = sorted([(f.name, f.dataType) for f in actual_schema], key=lambda f: f[0])
-
-    return expected_sorted_fields == actual_sorted_fields
-
-
 def min_check(input_col, min):
     return input_col >= min, input_col < min
 
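
The two deleted helpers enforced an all-or-nothing schema match. The short sketch below (my illustration, not code from the repo) shows why comparing sorted (name, dataType) pairs rejects an input that the new per-column checks can accept, for example a file carrying one extra, undeclared column:

    # Sketch only: exact schema equality fails as soon as the input has an extra column,
    # even though every declared raw column is present with a usable type.
    from pyspark.sql.types import StructType, StructField, FloatType, LongType

    expected = StructType([StructField("rating", FloatType())])
    actual = StructType([StructField("rating", FloatType()), StructField("timestamp", LongType())])

    expected_fields = sorted([(f.name, f.dataType) for f in expected], key=lambda f: f[0])
    actual_fields = sorted([(f.name, f.dataType) for f in actual], key=lambda f: f[0])

    print(expected_fields == actual_fields)  # False: the whole ingest would be rejected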

@@ -220,50 +187,78 @@ def value_check_data(ctx, df, raw_columns=None):
 
 
 def ingest(ctx, spark):
-    expected_schema = expected_schema_from_context(ctx)
-
     if ctx.environment["data"]["type"] == "csv":
         df = read_csv(ctx, spark)
     elif ctx.environment["data"]["type"] == "parquet":
         df = read_parquet(ctx, spark)
 
-    if compare_column_schemas(expected_schema, df.schema) is not True:
-        logger.error("expected schema:")
-        log_df_schema(spark.createDataFrame([], expected_schema), logger.error)
-        logger.error("found schema:")
-        log_df_schema(df, logger.error)
+    input_type_map = {f.name: f.dataType for f in df.schema}
+
+    for raw_column_name in ctx.raw_columns.keys():
+        raw_column = ctx.raw_columns[raw_column_name]
+        expected_types = CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[raw_column["type"]]
+        actual_type = input_type_map[raw_column_name]
+        if actual_type not in expected_types:
+            logger.error("found schema:")
+            log_df_schema(df, logger.error)
+
+            raise UserException(
+                "raw column " + raw_column_name,
+                "type mismatch",
+                "expected {} but found {}".format(
+                    " or ".join(str(x) for x in expected_types), actual_type
+                ),
+            )
+        target_type = CORTEX_TYPE_TO_SPARK_TYPE[raw_column["type"]]
 
-        raise UserException("raw data schema mismatch")
+        if target_type != actual_type:
+            df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(target_type))
 
-    return df
+    return df.select(*sorted(df.columns))
 
 
 def read_csv(ctx, spark):
     data_config = ctx.environment["data"]
-    schema = expected_schema_from_context(ctx)
+    expected_field_names = data_config["schema"]
+
+    schema_fields = []
+    for field_name in expected_field_names:
+        if field_name in ctx.raw_columns:
+            spark_type = CORTEX_TYPE_TO_SPARK_TYPE[ctx.raw_columns[field_name]["type"]]
+        else:
+            spark_type = StringType()
+
+        schema_fields.append(StructField(name=field_name, dataType=spark_type))
 
     csv_config = {
         util.snake_to_camel(param_name): val
         for param_name, val in data_config.get("csv_config", {}).items()
        if val is not None
     }
 
-    return spark.read.csv(data_config["path"], schema=schema, mode="FAILFAST", **csv_config)
+    df = spark.read.csv(
+        data_config["path"], schema=StructType(schema_fields), mode="FAILFAST", **csv_config
+    )
+    return df.select(*ctx.raw_columns.keys())
 
 
 def read_parquet(ctx, spark):
     parquet_config = ctx.environment["data"]
     df = spark.read.parquet(parquet_config["path"])
 
-    parquet_columns = [c["parquet_column_name"] for c in parquet_config["schema"]]
-    missing_cols = util.subtract_lists(parquet_columns, df.columns)
+    alias_map = {
+        c["parquet_column_name"]: c["raw_column_name"]
+        for c in parquet_config["schema"]
+        if c["parquet_column_name"] in ctx.raw_columns
+    }
+
+    missing_cols = set(alias_map.keys()) - set(df.columns)
     if len(missing_cols) > 0:
-        raise UserException("parquet dataset", "missing columns: " + str(missing_cols))
+        logger.error("found schema:")
+        log_df_schema(df, logger.error)
+        raise UserException("missing column(s) in input dataset", str(missing_cols))
 
-    selectExprs = [
-        "{} as {}".format(c["parquet_column_name"], c["raw_column_name"])
-        for c in parquet_config["schema"]
-    ]
+    selectExprs = ["{} as {}".format(alias_map[alias], alias) for alias in alias_map.keys()]
 
     return df.selectExpr(*selectExprs)
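
For reference, here is a self-contained sketch of the check-then-cast loop that the new ingest() runs over the declared raw columns. The two small dictionaries below are stand-ins for the repo's CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES and CORTEX_TYPE_TO_SPARK_TYPE maps, whose exact contents are defined elsewhere; the rating column is assumed to arrive as a long and get cast to float.

    # Sketch only: validate each declared raw column's Spark type against an acceptable
    # set, cast to the canonical type if needed, then normalize column order.
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    from pyspark.sql.types import FloatType, DoubleType, IntegerType, LongType

    ACCEPTABLE_TYPES = {"FLOAT_COLUMN": [FloatType(), DoubleType(), IntegerType(), LongType()]}
    TARGET_TYPE = {"FLOAT_COLUMN": FloatType()}  # stand-in mappings, not the repo's

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(4,), (5,)], ["rating"])  # rating arrives as a long

    raw_columns = {"rating": {"type": "FLOAT_COLUMN"}}
    input_type_map = {f.name: f.dataType for f in df.schema}

    for name, raw_column in raw_columns.items():
        actual_type = input_type_map[name]
        acceptable = ACCEPTABLE_TYPES[raw_column["type"]]
        if actual_type not in acceptable:
            raise ValueError("raw column {}: type mismatch, found {}".format(name, actual_type))
        target_type = TARGET_TYPE[raw_column["type"]]
        if target_type != actual_type:
            df = df.withColumn(name, F.col(name).cast(target_type))

    df = df.select(*sorted(df.columns))  # ingest() also returns columns in sorted order
    df.printSchema()  # rating is now float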
