@@ -187,27 +187,59 @@ object BulkCopyUtils extends Logging {
     private[spark] def getComputedCols(
         conn: Connection,
         table: String,
+        dfColNames: List[String],
         hideGraphColumns: Boolean): List[String] = {
         // TODO can optimize this, also evaluate SQLi issues
-        val queryStr = if (hideGraphColumns) s"""IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14)
-        exec sp_executesql N'SELECT name
-        FROM sys.computed_columns
-        WHERE object_id = OBJECT_ID(''${table}'')
-        UNION ALL
-        SELECT C.name
-        FROM sys.tables AS T
-        JOIN sys.columns AS C
-        ON T.object_id = C.object_id
-        WHERE T.object_id = OBJECT_ID(''${table}'')
-        AND (T.is_edge = 1 OR T.is_node = 1)
-        AND C.is_hidden = 0
-        AND C.graph_type = 2'
-        ELSE
-        SELECT name
-        FROM sys.computed_columns
-        WHERE object_id = OBJECT_ID('${table}')
+        var queryStr = s"""
+        -- First, enumerate all the computed columns for this table
+        declare @sql nvarchar(max)
+        SET @sql = 'SELECT name
+        FROM sys.computed_columns
+        WHERE object_id = OBJECT_ID(''${table}'')
+
+        -- then, ignore columns which have defaults associated with them, but only if they are NOT in the input dataframe
+        '
+        SET @sql = @sql + '
+        UNION
+        SELECT C.name
+        FROM sys.tables AS T
+        JOIN sys.columns AS C
+        ON T.object_id = C.object_id
+        WHERE T.object_id = OBJECT_ID(''${table}'')
+        AND C.default_object_id != 0
+        AND C.name NOT IN (
+        """ + dfColNames.mkString("''", "'',''", "''") + s""")
+        '
+        -- then, consider the graph ID columns for graph tables if we are on SQL 2017+ or on Azure SQL DB
+        IF ('true' = '""" + hideGraphColumns + s"""'
+            AND (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14))
+        SET @sql = @sql + '
+        UNION
+        SELECT C.name
+        FROM sys.tables AS T
+        JOIN sys.columns AS C
+        ON T.object_id = C.object_id
+        WHERE T.object_id = OBJECT_ID(''${table}'')
+        AND (T.is_edge = 1 OR T.is_node = 1)
+        AND C.is_hidden = 0
+        AND C.graph_type = 2
+        '
+
+        -- consider system-generated columns for temporal tables if we are on SQL 2016+ or on Azure SQL DB
+        IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 13)
+        SET @sql = @sql + '
+        UNION
+        SELECT C.name
+        FROM sys.tables AS T
+        JOIN sys.columns AS C
+        ON T.object_id = C.object_id
+        WHERE T.object_id = OBJECT_ID(''${table}'')
+        AND T.temporal_type != 0
+        AND C.generated_always_type != 0
+        '
+
+        exec sp_executesql @sql
         """
-        else s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"

         val computedColRs = conn.createStatement.executeQuery(queryStr)
         val computedCols = ListBuffer[String]()
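
Side note on the `NOT IN (...)` list built above: `dfColNames.mkString("''", "'',''", "''")` wraps every DataFrame column name in doubled single quotes so that, inside the `@sql` string literal, each name becomes an ordinary quoted SQL string. A minimal sketch of that behaviour (the object name and sample column names are illustrative, not part of this patch):

```scala
// Illustrative only: reproduces the mkString call from the hunk above.
object QuoteListDemo extends App {
  def quoteForDynamicSql(cols: List[String]): String =
    cols.mkString("''", "'',''", "''")

  // Prints: ''id'',''amount''
  // Inside the sp_executesql string literal the doubled quotes collapse,
  // so the executed query sees NOT IN ('id','amount').
  println(quoteForDynamicSql(List("id", "amount")))
}
```

As the existing TODO comment notes, the names are not escaped beyond this quoting, so the SQL-injection concern flagged there still applies.
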
@@ -325,16 +357,16 @@ SELECT name
             zip df.schema.fieldNames.toList).toMap
         val dfCols = df.schema

+        val dfColNames = df.schema.fieldNames.toList
         val tableCols = getSchema(rs, JdbcDialects.get(url))
-        val computedCols = getComputedCols(conn, dbtable, hideGraphColumns)
+        val computedCols = getComputedCols(conn, dbtable, dfColNames, hideGraphColumns)

         val prefix = "Spark Dataframe and SQL Server table have differing"

         if (computedCols.length == 0) {
             assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
                 s"${prefix} numbers of columns")
         } else if (strictSchemaCheck) {
-            val dfColNames = df.schema.fieldNames.toList
             val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
             // if df has computed column(s), check column length using non computed column in df and table.
             // non computed column number in df: dfCols.length - dfComputedColCt
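
For context, the two comments above describe the count comparison that continues in code outside this hunk: with strict schema checking, computed (and defaulted) columns are subtracted before the column counts of the DataFrame and the table are compared. A hedged, self-contained sketch of that arithmetic (all names and numbers are illustrative assumptions, not the connector's code):

```scala
// Hedged sketch of the non-computed column count check described above.
object SchemaCountSketch extends App {
  val dfColCount      = 5 // columns in the Spark DataFrame
  val tableColCount   = 7 // columns in the SQL Server table
  val dfComputedColCt = 1 // DataFrame columns that map to computed/defaulted columns
  val computedColCt   = 3 // columns returned by getComputedCols for the table

  // Assumed comparison: only the non-computed columns have to line up.
  val matches = (dfColCount - dfComputedColCt) == (tableColCount - computedColCt)
  println(s"non-computed column counts match: $matches") // prints true (4 == 4)
}
```
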