Skip to content

Commit 4b1426b

Browse files
Add support for temporal tables, columns with defaults
* Enumerate columns with SQL defaults and skip them if they are not in the input dataframe.
* Enumerate system-generated temporal columns when on Azure SQL DB or SQL Server 2016+, and skip them.
* Enumerate graph node ID columns if the option hideGraphColumns is true (which it is by default), and skip them.
1 parent e9e0b73 commit 4b1426b

File tree

1 file changed

+52
-20
lines changed

1 file changed

+52
-20
lines changed

src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -187,27 +187,59 @@ object BulkCopyUtils extends Logging {
187187
private[spark] def getComputedCols(
188188
conn: Connection,
189189
table: String,
190+
dfColNames: List[String],
190191
hideGraphColumns: Boolean): List[String] = {
191192
// TODO can optimize this, also evaluate SQLi issues
192-
val queryStr = if (hideGraphColumns) s"""IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14)
193-
exec sp_executesql N'SELECT name
194-
FROM sys.computed_columns
195-
WHERE object_id = OBJECT_ID(''${table}'')
196-
UNION ALL
197-
SELECT C.name
198-
FROM sys.tables AS T
199-
JOIN sys.columns AS C
200-
ON T.object_id = C.object_id
201-
WHERE T.object_id = OBJECT_ID(''${table}'')
202-
AND (T.is_edge = 1 OR T.is_node = 1)
203-
AND C.is_hidden = 0
204-
AND C.graph_type = 2'
205-
ELSE
206-
SELECT name
207-
FROM sys.computed_columns
208-
WHERE object_id = OBJECT_ID('${table}')
193+
var queryStr = s"""
194+
-- First, enumerate all the computed columns for this table
195+
declare @sql nvarchar(max)
196+
SET @sql = 'SELECT name
197+
FROM sys.computed_columns
198+
WHERE object_id = OBJECT_ID(''${table}'')
199+
200+
-- then, ignore columns which have defaults associated with them, but only if they are NOT in the input dataframe
201+
'
202+
SET @sql = @sql + '
203+
UNION
204+
SELECT C.name
205+
FROM sys.tables AS T
206+
JOIN sys.columns AS C
207+
ON T.object_id = C.object_id
208+
WHERE T.object_id = OBJECT_ID(''${table}'')
209+
AND C.default_object_id != 0
210+
AND C.name NOT IN (
211+
""" + dfColNames.mkString("''", "'',''", "''") + s""")
212+
'
213+
-- then, consider the graph ID columns for graph tables if we are on SQL 2017+ or on Azure SQL DB
214+
IF ('true' = '""" + hideGraphColumns + s"""'
215+
AND (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14))
216+
SET @sql = @sql + '
217+
UNION
218+
SELECT C.name
219+
FROM sys.tables AS T
220+
JOIN sys.columns AS C
221+
ON T.object_id = C.object_id
222+
WHERE T.object_id = OBJECT_ID(''${table}'')
223+
AND (T.is_edge = 1 OR T.is_node = 1)
224+
AND C.is_hidden = 0
225+
AND C.graph_type = 2
226+
'
227+
228+
-- consider system-generated columns for temporal tables if we are on SQL 2016+ or on Azure SQL DB
229+
IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 13)
230+
SET @sql = @sql + '
231+
UNION
232+
SELECT C.name
233+
FROM sys.tables AS T
234+
JOIN sys.columns AS C
235+
ON T.object_id = C.object_id
236+
WHERE T.object_id = OBJECT_ID(''${table}'')
237+
AND T.temporal_type != 0
238+
AND C.generated_always_type != 0
239+
'
240+
241+
exec sp_executesql @sql
209242
"""
210-
else s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
211243

212244
val computedColRs = conn.createStatement.executeQuery(queryStr)
213245
val computedCols = ListBuffer[String]()
@@ -325,16 +357,16 @@ SELECT name
325357
zip df.schema.fieldNames.toList).toMap
326358
val dfCols = df.schema
327359

360+
val dfColNames = df.schema.fieldNames.toList
328361
val tableCols = getSchema(rs, JdbcDialects.get(url))
329-
val computedCols = getComputedCols(conn, dbtable, hideGraphColumns)
362+
val computedCols = getComputedCols(conn, dbtable, dfColNames, hideGraphColumns)
330363

331364
val prefix = "Spark Dataframe and SQL Server table have differing"
332365

333366
if (computedCols.length == 0) {
334367
assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
335368
s"${prefix} numbers of columns")
336369
} else if (strictSchemaCheck) {
337-
val dfColNames = df.schema.fieldNames.toList
338370
val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
339371
// if df has computed column(s), check column length using non computed column in df and table.
340372
// non computed column number in df: dfCols.length - dfComputedColCt

0 commit comments

Comments (0)