@@ -100,39 +100,6 @@ def write_training_data(model_name, df, ctx, spark):
     return df


-def expected_schema_from_context(ctx):
-    data_config = ctx.environment["data"]
-
-    if data_config["type"] == "csv":
-        expected_field_names = data_config["schema"]
-    else:
-        expected_field_names = [f["raw_column_name"] for f in data_config["schema"]]
-
-    schema_fields = [
-        StructField(
-            name=fname,
-            dataType=CORTEX_TYPE_TO_SPARK_TYPE[ctx.columns[fname]["type"]],
-            nullable=not ctx.columns[fname].get("required", False),
-        )
-        for fname in expected_field_names
-    ]
-    return StructType(schema_fields)
-
-
-def compare_column_schemas(expected_schema, actual_schema):
-    # Nullables are being left out because when Spark is reading CSV files, it is setting nullable to true
-    # regardless of if the column has all values or not. The null checks will be done elsewhere.
-    # This compares only the schemas
-    expected_sorted_fields = sorted(
-        [(f.name, f.dataType) for f in expected_schema], key=lambda f: f[0]
-    )
-
-    # Sorted for determinism when testing
-    actual_sorted_fields = sorted([(f.name, f.dataType) for f in actual_schema], key=lambda f: f[0])
-
-    return expected_sorted_fields == actual_sorted_fields
-
-
 def min_check(input_col, min):
     return input_col >= min, input_col < min

@@ -220,50 +187,78 @@ def value_check_data(ctx, df, raw_columns=None):


 def ingest(ctx, spark):
-    expected_schema = expected_schema_from_context(ctx)
-
     if ctx.environment["data"]["type"] == "csv":
         df = read_csv(ctx, spark)
     elif ctx.environment["data"]["type"] == "parquet":
         df = read_parquet(ctx, spark)

-    if compare_column_schemas(expected_schema, df.schema) is not True:
-        logger.error("expected schema:")
-        log_df_schema(spark.createDataFrame([], expected_schema), logger.error)
-        logger.error("found schema:")
-        log_df_schema(df, logger.error)
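+    # Validate each raw column's ingested Spark type against the acceptable types
+    # for its Cortex type, and cast to the canonical Spark type when it differs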
+    input_type_map = {f.name: f.dataType for f in df.schema}
+
+    for raw_column_name in ctx.raw_columns.keys():
+        raw_column = ctx.raw_columns[raw_column_name]
+        expected_types = CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[raw_column["type"]]
+        actual_type = input_type_map[raw_column_name]
+        if actual_type not in expected_types:
+            logger.error("found schema:")
+            log_df_schema(df, logger.error)
+
+            raise UserException(
+                "raw column " + raw_column_name,
+                "type mismatch",
+                "expected {} but found {}".format(
+                    " or ".join(str(x) for x in expected_types), actual_type
+                ),
+            )
+        target_type = CORTEX_TYPE_TO_SPARK_TYPE[raw_column["type"]]

-    raise UserException("raw data schema mismatch")
+        if target_type != actual_type:
+            df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(target_type))

-    return df
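+    # Sort the columns so the ingested DataFrame has a deterministic column order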
+    return df.select(*sorted(df.columns))


 def read_csv(ctx, spark):
     data_config = ctx.environment["data"]
-    schema = expected_schema_from_context(ctx)
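+    # For CSV, the schema config lists the file's column names in order; columns that
+    # are not raw columns are read as StringType and dropped by the final select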
+    expected_field_names = data_config["schema"]
+
+    schema_fields = []
+    for field_name in expected_field_names:
+        if field_name in ctx.raw_columns:
+            spark_type = CORTEX_TYPE_TO_SPARK_TYPE[ctx.raw_columns[field_name]["type"]]
+        else:
+            spark_type = StringType()
+
+        schema_fields.append(StructField(name=field_name, dataType=spark_type))

     csv_config = {
         util.snake_to_camel(param_name): val
         for param_name, val in data_config.get("csv_config", {}).items()
         if val is not None
     }

-    return spark.read.csv(data_config["path"], schema=schema, mode="FAILFAST", **csv_config)
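+    # mode="FAILFAST" raises an exception on malformed records instead of dropping or null-filling them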
+    df = spark.read.csv(
+        data_config["path"], schema=StructType(schema_fields), mode="FAILFAST", **csv_config
+    )
+    return df.select(*ctx.raw_columns.keys())


 def read_parquet(ctx, spark):
     parquet_config = ctx.environment["data"]
     df = spark.read.parquet(parquet_config["path"])

-    parquet_columns = [c["parquet_column_name"] for c in parquet_config["schema"]]
-    missing_cols = util.subtract_lists(parquet_columns, df.columns)
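+    # Build a mapping between the parquet columns named in the schema config and the raw
+    # columns they provide, skipping entries that do not correspond to a raw column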
+    alias_map = {
+        c["raw_column_name"]: c["parquet_column_name"]
+        for c in parquet_config["schema"]
+        if c["raw_column_name"] in ctx.raw_columns
+    }
+
+    missing_cols = set(alias_map.values()) - set(df.columns)

     if len(missing_cols) > 0:
-        raise UserException("parquet dataset", "missing columns: " + str(missing_cols))
+        logger.error("found schema:")
+        log_df_schema(df, logger.error)
+        raise UserException("missing column(s) in input dataset", str(missing_cols))

-    selectExprs = [
-        "{} as {}".format(c["parquet_column_name"], c["raw_column_name"])
-        for c in parquet_config["schema"]
-    ]
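+    # Select only the mapped parquet columns, renaming each to its raw column name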
+    selectExprs = ["{} as {}".format(alias_map[alias], alias) for alias in alias_map.keys()]

     return df.selectExpr(*selectExprs)
