diff --git a/dataframe-jdbc/build.gradle.kts b/dataframe-jdbc/build.gradle.kts index 6e7e8dd42c..a9e5ba43c9 100644 --- a/dataframe-jdbc/build.gradle.kts +++ b/dataframe-jdbc/build.gradle.kts @@ -28,6 +28,7 @@ dependencies { testImplementation(libs.mssql) testImplementation(libs.junit) testImplementation(libs.sl4j) + testImplementation(libs.jts) testImplementation(libs.kotestAssertions) { exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8") } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt index cecfc64116..c2e98a949a 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt @@ -10,33 +10,67 @@ import kotlin.reflect.KType /** * Represents the H2 database type. * - * This class provides methods to convert data from a ResultSet to the appropriate type for H2, + * This class provides methods to convert data from a ResultSet to the appropriate type for H2 * and to generate the corresponding column schema. * - * NOTE: All date and timestamp related types are converted to String to avoid java.sql.* types. + * NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types. */ -public object H2 : DbType("h2") { +public class H2(public val dialect: DbType = MySql) : DbType("h2") { + init { + require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" } + } + + /** + * It contains constants related to different database modes. + * + * The mode value is used in the [extractDBTypeFromConnection] function to determine the corresponding `DbType` for the H2 database connection URL. + * For example, if the URL contains the mode value "MySQL", the H2 instance with the MySQL database type is returned. + * Otherwise, the `DbType` is determined based on the URL without the mode value. + * + * @see [extractDBTypeFromConnection] + * @see [createH2Instance] + */ + public companion object { + /** It represents the mode value "MySQL" for the H2 database. */ + public const val MODE_MYSQL: String = "MySQL" + + /** It represents the mode value "PostgreSQL" for the H2 database. */ + public const val MODE_POSTGRESQL: String = "PostgreSQL" + + /** It represents the mode value "MSSQLServer" for the H2 database. */ + public const val MODE_MSSQLSERVER: String = "MSSQLServer" + + /** It represents the mode value "MariaDB" for the H2 database. */ + public const val MODE_MARIADB: String = "MariaDB" + } + override val driverClassName: String get() = "org.h2.Driver" override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? { - return null + return dialect.convertSqlTypeToColumnSchemaValue(tableColumnMetadata) } override fun isSystemTable(tableMetadata: TableMetadata): Boolean { - return tableMetadata.name.lowercase(Locale.getDefault()).contains("sys_") || - tableMetadata.schemaName?.lowercase(Locale.getDefault())?.contains("information_schema") ?: false + val locale = Locale.getDefault() + fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true + val schemaName = tableMetadata.schemaName + + // could be extended for other symptoms of the system tables for H2 + val isH2SystemTable = schemaName.containsWithLowercase("information_schema") + + return isH2SystemTable || dialect.isSystemTable(tableMetadata) } override fun buildTableMetadata(tables: ResultSet): TableMetadata { - return TableMetadata( - tables.getString("TABLE_NAME"), - tables.getString("TABLE_SCHEM"), - tables.getString("TABLE_CAT") - ) + return dialect.buildTableMetadata(tables) } override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? { - return null + return dialect.convertSqlTypeToKType(tableColumnMetadata) + } + + public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String { + return dialect.sqlQueryLimit(sqlQuery, limit) } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt index 05aed59a78..db2c1b853e 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt @@ -4,9 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata import org.jetbrains.kotlinx.dataframe.io.TableMetadata import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet -import java.util.* +import java.util.Locale import kotlin.reflect.KType -import kotlin.reflect.full.createType /** * Represents the MSSQL database type. diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt index 793b41a93e..2463bd9fe8 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt @@ -1,18 +1,75 @@ package org.jetbrains.kotlinx.dataframe.io.db +import io.github.oshai.kotlinlogging.KotlinLogging +import java.sql.Connection import java.sql.SQLException +import java.util.Locale + +private val logger = KotlinLogging.logger {} + +/** + * Extracts the database type from the given connection. + * + * @param [connection] the database connection. + * @return the corresponding [DbType]. + * @throws [IllegalStateException] if URL information is missing in connection meta-data. + * @throws [IllegalArgumentException] if the URL specifies an unsupported database type. + * @throws [SQLException] if the URL is null. + */ +public fun extractDBTypeFromConnection(connection: Connection): DbType { + val url = connection.metaData?.url ?: throw IllegalStateException("URL information is missing in connection meta data!") + logger.info { "Processing DB type extraction for connection url: $url" } + + return if (url.contains(H2().dbTypeInJdbcUrl)) { + // works only for H2 version 2 + val modeQuery = "SELECT SETTING_VALUE FROM INFORMATION_SCHEMA.SETTINGS WHERE SETTING_NAME = 'MODE'" + var mode = "" + connection.createStatement().use { st -> + st.executeQuery( + modeQuery + ).use { rs -> + if (rs.next()) { + mode = rs.getString("SETTING_VALUE") + logger.debug { "Fetched H2 DB mode: $mode" } + } else { + throw IllegalStateException("The information about H2 mode is not found in the H2 meta-data!") + } + } + } + + // H2 doesn't support MariaDB and SQLite + when (mode.lowercase(Locale.getDefault())) { + H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql) + H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql) + H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql) + H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb) + else -> { + val message = "Unsupported database type in the url: $url. " + + "Only MySQL, MariaDB, MSSQL and PostgreSQL are supported!" + logger.error { message } + + throw IllegalArgumentException(message) + } + } + } else { + val dbType = extractDBTypeFromUrl(url) + logger.info { "Identified DB type as $dbType from url: $url" } + dbType + } +} /** * Extracts the database type from the given JDBC URL. * * @param [url] the JDBC URL. * @return the corresponding [DbType]. - * @throws RuntimeException if the url is null. + * @throws [RuntimeException] if the url is null. */ public fun extractDBTypeFromUrl(url: String?): DbType { if (url != null) { + val helperH2Instance = H2() return when { - H2.dbTypeInJdbcUrl in url -> H2 + helperH2Instance.dbTypeInJdbcUrl in url -> createH2Instance(url) MariaDb.dbTypeInJdbcUrl in url -> MariaDb MySql.dbTypeInJdbcUrl in url -> MySql Sqlite.dbTypeInJdbcUrl in url -> Sqlite @@ -28,6 +85,37 @@ public fun extractDBTypeFromUrl(url: String?): DbType { } } +/** + * Creates an instance of DbType based on the provided JDBC URL. + * + * @param [url] The JDBC URL representing the database connection. + * @return The corresponding [DbType] instance. + * @throws [IllegalArgumentException] if the provided URL does not contain a valid mode. + */ +private fun createH2Instance(url: String): DbType { + val modePattern = "MODE=(.*?);".toRegex() + val matchResult = modePattern.find(url) + + val mode: String = if (matchResult != null && matchResult.groupValues.size == 2) { + matchResult.groupValues[1] + } else { + throw IllegalArgumentException("The provided URL `$url` does not contain a valid mode.") + } + + // H2 doesn't support MariaDB and SQLite + return when (mode.lowercase(Locale.getDefault())) { + H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql) + H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql) + H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql) + H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb) + + else -> throw IllegalArgumentException( + "Unsupported database mode: $mode. " + + "Only MySQL, MariaDB, MSSQL, PostgreSQL modes are supported!" + ) + } +} + /** * Retrieves the driver class name from the given JDBC URL. * diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index b9778395c8..46061d0b3d 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -7,8 +7,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.api.Infer import org.jetbrains.kotlinx.dataframe.api.toDataFrame import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl -import org.jetbrains.kotlinx.dataframe.io.db.DbType -import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromUrl +import org.jetbrains.kotlinx.dataframe.io.db.* import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema import java.math.BigDecimal @@ -138,7 +137,7 @@ public fun DataFrame.Companion.readSqlTable( inferNullability: Boolean = true, ): AnyFrame { val url = connection.metaData.url - val dbType = extractDBTypeFromUrl(url) + val dbType = extractDBTypeFromConnection(connection) val selectAllQuery = if (limit > 0) dbType.sqlQueryLimit("SELECT * FROM $tableName", limit) else "SELECT * FROM $tableName" @@ -203,8 +202,7 @@ public fun DataFrame.Companion.readSqlQuery( "Also it should not contain any separators like `;`." } - val url = connection.metaData.url - val dbType = extractDBTypeFromUrl(url) + val dbType = extractDBTypeFromConnection(connection) val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery @@ -283,8 +281,7 @@ public fun DataFrame.Companion.readResultSet( limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, ): AnyFrame { - val url = connection.metaData.url - val dbType = extractDBTypeFromUrl(url) + val dbType = extractDBTypeFromConnection(connection) return readResultSet(resultSet, dbType, limit, inferNullability) } @@ -329,8 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables( inferNullability: Boolean = true, ): Map { val metaData = connection.metaData - val url = connection.metaData.url - val dbType = extractDBTypeFromUrl(url) + val dbType = extractDBTypeFromConnection(connection) // exclude a system and other tables without data, but it looks like it is supported badly for many databases val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE")) @@ -390,8 +386,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable( connection: Connection, tableName: String ): DataFrameSchema { - val url = connection.metaData.url - val dbType = extractDBTypeFromUrl(url) + val dbType = extractDBTypeFromConnection(connection) val sqlQuery = "SELECT * FROM $tableName" val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1) @@ -432,8 +427,7 @@ public fun DataFrame.Companion.getSchemaForSqlQuery( * @see DriverManager.getConnection */ public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema { - val url = connection.metaData.url - val dbType = extractDBTypeFromUrl(url) + val dbType = extractDBTypeFromConnection(connection) connection.createStatement().use { st -> st.executeQuery(sqlQuery).use { rs -> @@ -468,8 +462,7 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp * @return the schema of the [ResultSet] as a [DataFrameSchema] object. */ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema { - val url = connection.metaData.url - val dbType = extractDBTypeFromUrl(url) + val dbType = extractDBTypeFromConnection(connection) val tableColumns = getTableColumnsMetadata(resultSet) return buildSchemaByTableColumns(tableColumns, dbType) @@ -495,8 +488,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfig */ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map { val metaData = connection.metaData - val url = connection.metaData.url - val dbType = extractDBTypeFromUrl(url) + val dbType = extractDBTypeFromConnection(connection) val tableTypes = arrayOf("TABLE") // exclude a system and other tables without data diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt similarity index 94% rename from dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt rename to dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt index 4b74017894..def477d2df 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt @@ -1,13 +1,28 @@ -package org.jetbrains.kotlinx.dataframe.io +package org.jetbrains.kotlinx.dataframe.io.h2 import io.kotest.assertions.throwables.shouldThrow +import io.kotest.assertions.throwables.shouldThrowExactly import io.kotest.matchers.shouldBe import org.h2.jdbc.JdbcSQLSyntaxErrorException import org.intellij.lang.annotations.Language import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema -import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.api.add +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.api.schema +import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration import org.jetbrains.kotlinx.dataframe.io.db.H2 +import org.jetbrains.kotlinx.dataframe.io.db.MySql +import org.jetbrains.kotlinx.dataframe.io.getSchemaForAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.getSchemaForResultSet +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readResultSet +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Test @@ -18,7 +33,7 @@ import java.sql.ResultSet import java.sql.SQLException import kotlin.reflect.typeOf -private const val URL = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_UPPER=false" +private const val URL = "jdbc:h2:mem:test5;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_UPPER=false" @DataSchema interface Customer { @@ -350,7 +365,7 @@ class JdbcTest { val selectStatement = "SELECT * FROM Customer" st.executeQuery(selectStatement).use { rs -> - val df = DataFrame.readResultSet(rs, H2).cast() + val df = DataFrame.readResultSet(rs, H2(MySql)).cast() df.rowsCount() shouldBe 4 df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 @@ -358,7 +373,7 @@ class JdbcTest { rs.beforeFirst() - val df1 = DataFrame.readResultSet(rs, H2, 1).cast() + val df1 = DataFrame.readResultSet(rs, H2(MySql), 1).cast() df1.rowsCount() shouldBe 1 df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 @@ -366,7 +381,7 @@ class JdbcTest { rs.beforeFirst() - val dataSchema = DataFrame.getSchemaForResultSet(rs, H2) + val dataSchema = DataFrame.getSchemaForResultSet(rs, H2(MySql)) dataSchema.columns.size shouldBe 3 dataSchema.columns["name"]!!.type shouldBe typeOf() @@ -406,7 +421,7 @@ class JdbcTest { for (i in 1..10) { rs.beforeFirst() - val df1 = DataFrame.readResultSet(rs, H2, 2).cast() + val df1 = DataFrame.readResultSet(rs, H2(MySql), 2).cast() df1.rowsCount() shouldBe 2 df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 @@ -767,7 +782,7 @@ class JdbcTest { st.executeQuery(selectStatement).use { rs -> // ith default inferNullability: Boolean = true - val df4 = DataFrame.readResultSet(rs, H2) + val df4 = DataFrame.readResultSet(rs, H2(MySql)) df4.schema().columns["id"]!!.type shouldBe typeOf() df4.schema().columns["name"]!!.type shouldBe typeOf() df4.schema().columns["surname"]!!.type shouldBe typeOf() @@ -775,7 +790,7 @@ class JdbcTest { rs.beforeFirst() - val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2) + val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2(MySql)) dataSchema3.columns.size shouldBe 4 dataSchema3.columns["id"]!!.type shouldBe typeOf() dataSchema3.columns["name"]!!.type shouldBe typeOf() @@ -785,7 +800,7 @@ class JdbcTest { // with inferNullability: Boolean = false rs.beforeFirst() - val df5 = DataFrame.readResultSet(rs, H2, inferNullability = false) + val df5 = DataFrame.readResultSet(rs, H2(MySql), inferNullability = false) df5.schema().columns["id"]!!.type shouldBe typeOf() df5.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls df5.schema().columns["surname"]!!.type shouldBe typeOf() @@ -796,4 +811,12 @@ class JdbcTest { connection.createStatement().execute("DROP TABLE TestTable1") } + + @Test + fun `check require throws exception when specifying H2 database with H2 dialect`() { + val exception = shouldThrowExactly { + H2(H2()) + } + exception.message shouldBe "H2 database could not be specified with H2 dialect!" + } } diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mariadbH2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mariadbH2Test.kt new file mode 100644 index 0000000000..32ae571b23 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mariadbH2Test.kt @@ -0,0 +1,407 @@ +package org.jetbrains.kotlinx.dataframe.io.h2 + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.add +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Test +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.SQLException +import kotlin.reflect.typeOf + +private const val URL = "jdbc:h2:mem:test1;DB_CLOSE_DELAY=-1;MODE=MariaDB;DATABASE_TO_LOWER=TRUE" + +@DataSchema +interface Table1MariaDb { + val id: Int + val bitCol: Boolean + val tinyintcol: Int + val smallintcol: Short? + val mediumintcol: Int + val mediumintunsignedcol: Int + val integercol: Int + val intCol: Int + val integerunsignedcol: Long + val bigintcol: Long + val floatcol: Float + val doublecol: Double + val decimalcol: BigDecimal + val dateCol: String + val datetimeCol: String + val timestampCol: String + val timeCol: String + val yearCol: String + val varcharCol: String + val charCol: String + val binaryCol: ByteArray + val varbinaryCol: ByteArray + val tinyblobCol: ByteArray + val blobCol: ByteArray + val mediumblobCol: ByteArray + val longblobCol: ByteArray + val textCol: String + val mediumtextCol: String + val longtextCol: String + val enumCol: String + val jsonCol: String +} + +@DataSchema +interface Table2MariaDb { + val id: Int + val bitCol: Boolean? + val tinyintCol: Int? + val smallintCol: Int? + val mediumintCol: Int? + val mediumintUnsignedCol: Int? + val integercol: Int? + val intCol: Int? + val integerUnsignedCol: Long? + val bigintCol: Long? + val floatCol: Float? + val doubleCol: Double? + val decimalCol: Double? + val dateCol: String? + val datetimeCol: String? + val timestampCol: String? + val timeCol: String? + val yearCol: String? + val varcharCol: String? + val charCol: String? + val binaryCol: ByteArray? + val varbinaryCol: ByteArray? + val tinyblobCol: ByteArray? + val blobCol: ByteArray? + val mediumblobCol: ByteArray? + val longblobCol: ByteArray? + val textCol: String? + val mediumtextCol: String? + val longtextCol: String? + val enumCol: String? + val jsonCol: String? +} + +@DataSchema +interface Table3MariaDb { + val id: Int + val enumCol: String + val setCol: Char? +} + +private const val JSON_STRING = + "{\"details\": {\"foodType\": \"Pizza\", \"menu\": \"https://www.loumalnatis.com/our-menu\"}, \n" + + " \t\"favorites\": [{\"description\": \"Pepperoni deep dish\", \"price\": 18.75}, \n" + + "{\"description\": \"The Lou\", \"price\": 24.75}]}" + +class MariadbH2Test { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL) + + @Language("SQL") + val createTableQuery = """ + CREATE TABLE IF NOT EXISTS table1 ( + id INT AUTO_INCREMENT PRIMARY KEY, + bitCol BIT NOT NULL, + tinyintCol TINYINT NOT NULL, + smallintCol SMALLINT, + mediumintCol MEDIUMINT NOT NULL, + mediumintUnsignedCol MEDIUMINT UNSIGNED NOT NULL, + integerCol INTEGER NOT NULL, + intCol INT NOT NULL, + integerUnsignedCol INTEGER UNSIGNED NOT NULL, + bigintCol BIGINT NOT NULL, + floatCol FLOAT NOT NULL, + doubleCol DOUBLE NOT NULL, + decimalCol DECIMAL NOT NULL, + dateCol DATE NOT NULL, + datetimeCol DATETIME NOT NULL, + timestampCol TIMESTAMP NOT NULL, + timeCol TIME NOT NULL, + yearCol YEAR NOT NULL, + varcharCol VARCHAR(255) NOT NULL, + charCol CHAR(10) NOT NULL, + binaryCol BINARY(64) NOT NULL, + varbinaryCol VARBINARY(128) NOT NULL, + tinyblobCol TINYBLOB NOT NULL, + blobCol BLOB NOT NULL, + mediumblobCol MEDIUMBLOB NOT NULL , + longblobCol LONGBLOB NOT NULL, + textCol TEXT NOT NULL, + mediumtextCol MEDIUMTEXT NOT NULL, + longtextCol LONGTEXT NOT NULL, + enumCol ENUM('Value1', 'Value2', 'Value3') NOT NULL, + jsonCol JSON NOT NULL + ) + """ + connection.createStatement().execute( + createTableQuery.trimIndent() + ) + + @Language("SQL") + val createTableQuery2 = """ + CREATE TABLE IF NOT EXISTS table2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + bitCol BIT, + tinyintCol TINYINT, + smallintCol SMALLINT, + mediumintCol MEDIUMINT, + mediumintUnsignedCol MEDIUMINT UNSIGNED, + integerCol INTEGER, + intCol INT, + integerUnsignedCol INTEGER UNSIGNED, + bigintCol BIGINT, + floatCol FLOAT, + doubleCol DOUBLE, + decimalCol DECIMAL, + dateCol DATE, + datetimeCol DATETIME, + timestampCol TIMESTAMP, + timeCol TIME, + yearCol YEAR, + varcharCol VARCHAR(255), + charCol CHAR(10), + binaryCol BINARY(64), + varbinaryCol VARBINARY(128), + tinyblobCol TINYBLOB, + blobCol BLOB, + mediumblobCol MEDIUMBLOB, + longblobCol LONGBLOB, + textCol TEXT, + mediumtextCol MEDIUMTEXT, + longtextCol LONGTEXT, + enumCol ENUM('Value1', 'Value2', 'Value3') + ) + """ + connection.createStatement().execute( + createTableQuery2.trimIndent() + ) + + @Language("SQL") + val insertData1 = """ + INSERT INTO table1 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, jsonCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """.trimIndent() + + @Language("SQL") + val insertData2 = """ + INSERT INTO table2 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """.trimIndent() + + connection.prepareStatement(insertData1).use { st -> + // Insert data into table1 + for (i in 1..3) { + st.setBoolean(1, true) + st.setByte(2, i.toByte()) + st.setShort(3, (i * 10).toShort()) + st.setInt(4, i * 100) + st.setInt(5, i * 100) + st.setInt(6, i * 100) + st.setInt(7, i * 100) + st.setInt(8, i * 100) + st.setInt(9, i * 100) + st.setFloat(10, i * 10.0f) + st.setDouble(11, i * 10.0) + st.setBigDecimal(12, BigDecimal(i * 10)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, "textValue$i") + st.setString(27, "mediumtextValue$i") + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.setString(30, JSON_STRING) + + st.executeUpdate() + } + } + + connection.prepareStatement(insertData2).use { st -> + // Insert data into table2 + for (i in 1..3) { + st.setBoolean(1, false) + st.setByte(2, (i * 2).toByte()) + st.setShort(3, (i * 20).toShort()) + st.setInt(4, i * 200) + st.setInt(5, i * 200) + st.setInt(6, i * 200) + st.setInt(7, i * 200) + st.setInt(8, i * 200) + st.setInt(9, i * 200) + st.setFloat(10, i * 20.0f) + st.setDouble(11, i * 20.0) + st.setBigDecimal(12, BigDecimal(i * 20)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, null) + st.setString(27, null) + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.executeUpdate() + } + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `basic test for reading sql tables`() { + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + val result = df1.filter { it[Table1MariaDb::id] == 1 } + result[0][26] shouldBe "textValue1" + + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["textcol"]!!.type shouldBe typeOf() + + val df2 = DataFrame.readSqlTable(connection, "table2").cast() + val result2 = df2.filter { it[Table2MariaDb::id] == 1 } + result2[0][26] shouldBe null + + val schema2 = DataFrame.getSchemaForSqlTable(connection, "table2") + schema2.columns["id"]!!.type shouldBe typeOf() + schema2.columns["textcol"]!!.type shouldBe typeOf() + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + t1.id, + t1.enumCol + FROM table1 t1 + JOIN table2 t2 ON t1.id = t2.id + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + val result = df.filter { it[Table3MariaDb::id] == 1 } + result[0][1] shouldBe "Value1" + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["enumcol"]!!.type shouldBe typeOf() + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection, limit = 1000).values.toList() + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 3 + table1Df.filter { it[Table1MariaDb::integercol] > 100 }.rowsCount() shouldBe 2 + table1Df[0][11] shouldBe 10.0 + table1Df[0][26] shouldBe "textValue1" + + val table2Df = dataframes[1].cast() + + table2Df.rowsCount() shouldBe 3 + table2Df.filter { it[Table2MariaDb::integercol] != null && it[Table2MariaDb::integercol]!! > 400 } + .rowsCount() shouldBe 1 + table2Df[0][11] shouldBe 20.0 + table2Df[0][26] shouldBe null + } + + @Test + fun `reading numeric types`() { + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + + val result = df1.select("tinyintcol") + .add("tinyintcol2") { it[Table1MariaDb::tinyintcol] } + + result[0][1] shouldBe 1 + + val result2 = df1.select("mediumintcol") + .add("mediumintcol2") { it[Table1MariaDb::mediumintcol] } + + result2[0][1] shouldBe 100 + + val result3 = df1.select("mediumintunsignedcol") + .add("mediumintunsignedcol2") { it[Table1MariaDb::mediumintunsignedcol] } + + result3[0][1] shouldBe 100 + + val result5 = df1.select("bigintcol") + .add("bigintcol2") { it[Table1MariaDb::bigintcol] } + + result5[0][1] shouldBe 100 + + val result7 = df1.select("doublecol") + .add("doublecol2") { it[Table1MariaDb::doublecol] } + + result7[0][1] shouldBe 10.0 + + val result8 = df1.select("decimalcol") + .add("decimalcol2") { it[Table1MariaDb::decimalcol] } + + result8[0][1] shouldBe BigDecimal("10") + + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + + schema.columns["tinyintcol"]!!.type shouldBe typeOf() + schema.columns["smallintcol"]!!.type shouldBe typeOf() + schema.columns["mediumintcol"]!!.type shouldBe typeOf() + schema.columns["mediumintunsignedcol"]!!.type shouldBe typeOf() + schema.columns["bigintcol"]!!.type shouldBe typeOf() + schema.columns["floatcol"]!!.type shouldBe typeOf() + schema.columns["doublecol"]!!.type shouldBe typeOf() + schema.columns["decimalcol"]!!.type shouldBe typeOf() + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mssqlH2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mssqlH2Test.kt new file mode 100644 index 0000000000..4cff761740 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mssqlH2Test.kt @@ -0,0 +1,359 @@ +package org.jetbrains.kotlinx.dataframe.io.h2 + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.api.schema +import org.jetbrains.kotlinx.dataframe.io.db.MsSql +import org.jetbrains.kotlinx.dataframe.io.getSchemaForResultSet +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readResultSet +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Test +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.ResultSet +import java.sql.SQLException +import java.util.Date +import java.util.UUID +import kotlin.reflect.typeOf + +private const val URL = "jdbc:h2:mem:testmssql;DB_CLOSE_DELAY=-1;MODE=MSSQLServer;DATABASE_TO_UPPER=FALSE;CASE_INSENSITIVE_IDENTIFIERS=TRUE" + +@DataSchema +interface Table1MSSSQL { + val id: Int + val bigintColumn: Long + val binaryColumn: ByteArray + val bitColumn: Boolean + val charColumn: Char + val dateColumn: Date + val datetime3Column: java.sql.Timestamp + val datetime2Column: java.sql.Timestamp + val decimalColumn: BigDecimal + val floatColumn: Double + val imageColumn: ByteArray? + val intColumn: Int + val moneyColumn: BigDecimal + val ncharColumn: Char + val ntextColumn: String + val numericColumn: BigDecimal + val nvarcharColumn: String + val nvarcharMaxColumn: String + val realColumn: Float + val smalldatetimeColumn: java.sql.Timestamp + val smallintColumn: Int + val smallmoneyColumn: BigDecimal + val timeColumn: java.sql.Time + val timestampColumn: java.sql.Timestamp + val tinyintColumn: Int + val uniqueidentifierColumn: Char + val varbinaryColumn: ByteArray + val varbinaryMaxColumn: ByteArray + val varcharColumn: String + val varcharMaxColumn: String + val geometryColumn: String + val geographyColumn: String +} + +class MSSQLH2Test { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL) + + @Language("SQL") + val createTableQuery = """ + CREATE TABLE Table1 ( + id INT NOT NULL IDENTITY PRIMARY KEY, + bigintColumn BIGINT, + binaryColumn BINARY(50), + bitColumn BIT, + charColumn CHAR(10), + dateColumn DATE, + datetime3Column DATETIME2(3), + datetime2Column DATETIME2, + decimalColumn DECIMAL(10,2), + floatColumn FLOAT, + imageColumn IMAGE, + intColumn INT, + moneyColumn MONEY, + ncharColumn NCHAR(10), + ntextColumn NTEXT, + numericColumn NUMERIC(10,2), + nvarcharColumn NVARCHAR(50), + nvarcharMaxColumn NVARCHAR(MAX), + realColumn REAL, + smalldatetimeColumn SMALLDATETIME, + smallintColumn SMALLINT, + smallmoneyColumn SMALLMONEY, + textColumn TEXT, + timeColumn TIME, + timestampColumn DATETIME2, + tinyintColumn TINYINT, + uniqueidentifierColumn UNIQUEIDENTIFIER, + varbinaryColumn VARBINARY(50), + varbinaryMaxColumn VARBINARY(MAX), + varcharColumn VARCHAR(50), + varcharMaxColumn VARCHAR(MAX) + ); + """ + + connection.createStatement().execute( + createTableQuery.trimIndent() + ) + + @Language("SQL") + val insertData1 = """ + INSERT INTO Table1 ( + bigintColumn, binaryColumn, bitColumn, charColumn, dateColumn, datetime3Column, datetime2Column, + decimalColumn, floatColumn, imageColumn, intColumn, moneyColumn, ncharColumn, + ntextColumn, numericColumn, nvarcharColumn, nvarcharMaxColumn, realColumn, smalldatetimeColumn, + smallintColumn, smallmoneyColumn, textColumn, timeColumn, timestampColumn, tinyintColumn, + uniqueidentifierColumn, varbinaryColumn, varbinaryMaxColumn, varcharColumn, varcharMaxColumn + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """.trimIndent() + + connection.prepareStatement(insertData1).use { st -> + for (i in 1..5) { + st.setLong(1, 123456789012345L) // bigintColumn + st.setBytes(2, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // binaryColumn + st.setBoolean(3, true) // bitColumn + st.setString(4, "Sample") // charColumn + st.setDate(5, java.sql.Date(System.currentTimeMillis())) // dateColumn + st.setTimestamp(6, java.sql.Timestamp(System.currentTimeMillis())) // datetime3Column + st.setTimestamp(7, java.sql.Timestamp(System.currentTimeMillis())) // datetime2Column + st.setBigDecimal(8, BigDecimal("12345.67")) // decimalColumn + st.setFloat(9, 123.45f) // floatColumn + st.setNull(10, java.sql.Types.NULL) // imageColumn (assuming nullable) + st.setInt(11, 123456) // intColumn + st.setBigDecimal(12, BigDecimal("123.45")) // moneyColumn + st.setString(13, "Sample") // ncharColumn + st.setString(14, "Sample$i text") // ntextColumn + st.setBigDecimal(15, BigDecimal("1234.56")) // numericColumn + st.setString(16, "Sample") // nvarcharColumn + st.setString(17, "Sample$i text") // nvarcharMaxColumn + st.setFloat(18, 123.45f) // realColumn + st.setTimestamp(19, java.sql.Timestamp(System.currentTimeMillis())) // smalldatetimeColumn + st.setInt(20, 123) // smallintColumn + st.setBigDecimal(21, BigDecimal("123.45")) // smallmoneyColumn + st.setString(22, "Sample$i text") // textColumn + st.setTime(23, java.sql.Time(System.currentTimeMillis())) // timeColumn + st.setTimestamp(24, java.sql.Timestamp(System.currentTimeMillis())) // timestampColumn + st.setInt(25, 123) // tinyintColumn + //st.setObject(27, null) // udtColumn (assuming nullable) + st.setObject(26, UUID.randomUUID()) // uniqueidentifierColumn + st.setBytes(27, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // varbinaryColumn + st.setBytes(28, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // varbinaryMaxColumn + st.setString(29, "Sample$i") // varcharColumn + st.setString(30, "Sample$i text") // varcharMaxColumn + st.executeUpdate() + } + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `basic test for reading sql tables`() { + val df1 = DataFrame.readSqlTable(connection, "table1", limit = 5).cast() + + val result = df1.filter { it[Table1MSSSQL::id] == 1 } + result[0][30] shouldBe "Sample1 text" + result[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L + result[0][Table1MSSSQL::bitColumn] shouldBe true + result[0][Table1MSSSQL::intColumn] shouldBe 123456 + result[0][Table1MSSSQL::ntextColumn] shouldBe "Sample1 text" + + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["bigintColumn"]!!.type shouldBe typeOf() + schema.columns["binaryColumn"]!!.type shouldBe typeOf() + schema.columns["bitColumn"]!!.type shouldBe typeOf() + schema.columns["charColumn"]!!.type shouldBe typeOf() + schema.columns["dateColumn"]!!.type shouldBe typeOf() + schema.columns["datetime3Column"]!!.type shouldBe typeOf() + schema.columns["datetime2Column"]!!.type shouldBe typeOf() + schema.columns["decimalColumn"]!!.type shouldBe typeOf() + schema.columns["intColumn"]!!.type shouldBe typeOf() + schema.columns["moneyColumn"]!!.type shouldBe typeOf() + schema.columns["ncharColumn"]!!.type shouldBe typeOf() + schema.columns["ntextColumn"]!!.type shouldBe typeOf() + schema.columns["numericColumn"]!!.type shouldBe typeOf() + schema.columns["nvarcharColumn"]!!.type shouldBe typeOf() + schema.columns["nvarcharMaxColumn"]!!.type shouldBe typeOf() + schema.columns["realColumn"]!!.type shouldBe typeOf() + schema.columns["smalldatetimeColumn"]!!.type shouldBe typeOf() + schema.columns["smallintColumn"]!!.type shouldBe typeOf() + schema.columns["smallmoneyColumn"]!!.type shouldBe typeOf() + schema.columns["timeColumn"]!!.type shouldBe typeOf() + schema.columns["timestampColumn"]!!.type shouldBe typeOf() + schema.columns["tinyintColumn"]!!.type shouldBe typeOf() + schema.columns["varbinaryColumn"]!!.type shouldBe typeOf() + schema.columns["varbinaryMaxColumn"]!!.type shouldBe typeOf() + schema.columns["varcharColumn"]!!.type shouldBe typeOf() + schema.columns["varcharMaxColumn"]!!.type shouldBe typeOf() + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + Table1.id, + Table1.bigintColumn + FROM Table1 + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery, limit = 3).cast() + val result = df.filter { it[Table1MSSSQL::id] == 1 } + result[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["bigintColumn"]!!.type shouldBe typeOf() + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection, limit = 4).values.toList() + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 4 + table1Df.filter { it[Table1MSSSQL::id] > 2 }.rowsCount() shouldBe 2 + table1Df[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L + } + + @Test + fun `infer nullability`() { + // prepare tables and data + @Language("SQL") + val createTestTable1Query = """ + CREATE TABLE TestTable1 ( + id INT PRIMARY KEY, + name VARCHAR(50), + surname VARCHAR(50), + age INT NOT NULL + ) + """ + + connection.createStatement().execute(createTestTable1Query) + + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)") + + // start testing `readSqlTable` method + + // with default inferNullability: Boolean = true + val tableName = "TestTable1" + val df = DataFrame.readSqlTable(connection, tableName) + df.schema().columns["id"]!!.type shouldBe typeOf() + df.schema().columns["name"]!!.type shouldBe typeOf() + df.schema().columns["surname"]!!.type shouldBe typeOf() + df.schema().columns["age"]!!.type shouldBe typeOf() + + val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) + dataSchema.columns.size shouldBe 4 + dataSchema.columns["id"]!!.type shouldBe typeOf() + dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema.columns["surname"]!!.type shouldBe typeOf() + dataSchema.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false) + df1.schema().columns["id"]!!.type shouldBe typeOf() + df1.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df1.schema().columns["surname"]!!.type shouldBe typeOf() + df1.schema().columns["age"]!!.type shouldBe typeOf() + + // end testing `readSqlTable` method + + // start testing `readSQLQuery` method + + // ith default inferNullability: Boolean = true + @Language("SQL") + val sqlQuery = """ + SELECT name, surname, age FROM TestTable1 + """.trimIndent() + + val df2 = DataFrame.readSqlQuery(connection, sqlQuery) + df2.schema().columns["name"]!!.type shouldBe typeOf() + df2.schema().columns["surname"]!!.type shouldBe typeOf() + df2.schema().columns["age"]!!.type shouldBe typeOf() + + val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + dataSchema2.columns.size shouldBe 3 + dataSchema2.columns["name"]!!.type shouldBe typeOf() + dataSchema2.columns["surname"]!!.type shouldBe typeOf() + dataSchema2.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false) + df3.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df3.schema().columns["surname"]!!.type shouldBe typeOf() + df3.schema().columns["age"]!!.type shouldBe typeOf() + + // end testing `readSQLQuery` method + + // start testing `readResultSet` method + + connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st -> + @Language("SQL") + val selectStatement = "SELECT * FROM TestTable1" + + st.executeQuery(selectStatement).use { rs -> + // ith default inferNullability: Boolean = true + val df4 = DataFrame.readResultSet(rs, MsSql) + df4.schema().columns["id"]!!.type shouldBe typeOf() + df4.schema().columns["name"]!!.type shouldBe typeOf() + df4.schema().columns["surname"]!!.type shouldBe typeOf() + df4.schema().columns["age"]!!.type shouldBe typeOf() + + rs.beforeFirst() + + val dataSchema3 = DataFrame.getSchemaForResultSet(rs, MsSql) + dataSchema3.columns.size shouldBe 4 + dataSchema3.columns["id"]!!.type shouldBe typeOf() + dataSchema3.columns["name"]!!.type shouldBe typeOf() + dataSchema3.columns["surname"]!!.type shouldBe typeOf() + dataSchema3.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + rs.beforeFirst() + + val df5 = DataFrame.readResultSet(rs, MsSql, inferNullability = false) + df5.schema().columns["id"]!!.type shouldBe typeOf() + df5.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df5.schema().columns["surname"]!!.type shouldBe typeOf() + df5.schema().columns["age"]!!.type shouldBe typeOf() + } + } + // end testing `readResultSet` method + + connection.createStatement().execute("DROP TABLE TestTable1") + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mysqlH2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mysqlH2Test.kt new file mode 100644 index 0000000000..b37892503d --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/mysqlH2Test.kt @@ -0,0 +1,410 @@ +package org.jetbrains.kotlinx.dataframe.io.h2 + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.add +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Test +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.SQLException +import kotlin.reflect.typeOf + +// NOTE: the names of testing databases should be different to avoid collisions and should not contain the system names itself +private const val URL = "jdbc:h2:mem:test2;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_LOWER=TRUE" + +@DataSchema +interface Table1MySql { + val id: Int + val bitCol: Boolean + val tinyintcol: Int + val smallintcol: Int + val mediumintcol: Int + val mediumintunsignedcol: Int + val integercol: Int + val intcol: Int + val integerunsignedcol: Long + val bigintcol: Long + val floatcol: Float + val doublecol: Double + val decimalcol: BigDecimal + val datecol: String + val datetimecol: String + val timestampcol: String + val timecol: String + val yearcol: String + val varcharcol: String + val charcol: String + val binarycol: ByteArray + val varbinarycol: ByteArray + val tinyblobcol: ByteArray + val blobcol: ByteArray + val mediumblobcol: ByteArray + val longblobcol: ByteArray + val textcol: String + val mediumtextcol: String + val longtextcol: String + val enumcol: String + val setcol: Char +} + +@DataSchema +interface Table2MySql { + val id: Int + val bitcol: Boolean? + val tinyintcol: Int? + val smallintcol: Int? + val mediumintcol: Int? + val mediumintUnsignedcol: Int? + val integercol: Int? + val intcol: Int? + val integerUnsignedcol: Long? + val bigintcol: Long? + val floatcol: Float? + val doublecol: Double? + val decimalcol: Double? + val datecol: String? + val datetimecol: String? + val timestampcol: String? + val timecol: String? + val yearcol: String? + val varcharcol: String? + val charcol: String? + val binarycol: ByteArray? + val varbinarycol: ByteArray? + val tinyblobcol: ByteArray? + val blobcol: ByteArray? + val mediumblobcol: ByteArray? + val longblobcol: ByteArray? + val textcol: String? + val mediumtextcol: String? + val longtextcol: String? + val enumcol: String? + val setcol: Char? + val jsoncol: String? +} + +@DataSchema +interface Table3MySql { + val id: Int + val enumcol: String +} + +class MySqlH2Test { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL) + + @Language("SQL") + val createTableQuery = """ + CREATE TABLE IF NOT EXISTS table1 ( + id INT AUTO_INCREMENT PRIMARY KEY, + bitCol BIT NOT NULL, + tinyintCol TINYINT NOT NULL, + smallintCol SMALLINT NOT NULL, + mediumintCol MEDIUMINT NOT NULL, + mediumintUnsignedCol MEDIUMINT UNSIGNED NOT NULL, + integerCol INTEGER NOT NULL, + intCol INT NOT NULL, + integerUnsignedCol INTEGER UNSIGNED NOT NULL, + bigintCol BIGINT NOT NULL, + floatCol FLOAT NOT NULL, + doubleCol DOUBLE NOT NULL, + decimalCol DECIMAL NOT NULL, + dateCol DATE NOT NULL, + datetimeCol DATETIME NOT NULL, + timestampCol TIMESTAMP NOT NULL, + timeCol TIME NOT NULL, + yearCol YEAR NOT NULL, + varcharCol VARCHAR(255) NOT NULL, + charCol CHAR(10) NOT NULL, + binaryCol BINARY(64) NOT NULL, + varbinaryCol VARBINARY(128) NOT NULL, + tinyblobCol TINYBLOB NOT NULL, + blobCol BLOB NOT NULL, + mediumblobCol MEDIUMBLOB NOT NULL , + longblobCol LONGBLOB NOT NULL, + textCol TEXT NOT NULL, + mediumtextCol MEDIUMTEXT NOT NULL, + longtextCol LONGTEXT NOT NULL, + enumCol ENUM('Value1', 'Value2', 'Value3') NOT NULL, + data JSON + ) + """ + + connection.createStatement().execute( + createTableQuery.trimIndent() + ) + + @Language("SQL") + val createTableQuery2 = """ + CREATE TABLE IF NOT EXISTS table2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + bitCol BIT, + tinyintCol TINYINT, + smallintCol SMALLINT, + mediumintCol MEDIUMINT, + mediumintUnsignedCol MEDIUMINT UNSIGNED, + integerCol INTEGER, + intCol INT, + integerUnsignedCol INTEGER UNSIGNED, + bigintCol BIGINT, + floatCol FLOAT, + doubleCol DOUBLE, + decimalCol DECIMAL, + dateCol DATE, + datetimeCol DATETIME, + timestampCol TIMESTAMP, + timeCol TIME, + yearCol YEAR, + varcharCol VARCHAR(255), + charCol CHAR(10), + binaryCol BINARY(64), + varbinaryCol VARBINARY(128), + tinyblobCol TINYBLOB, + blobCol BLOB, + mediumblobCol MEDIUMBLOB, + longblobCol LONGBLOB, + textCol TEXT, + mediumtextCol MEDIUMTEXT, + longtextCol LONGTEXT, + enumCol ENUM('Value1', 'Value2', 'Value3'), + data JSON + ) + """ + + connection.createStatement().execute( + createTableQuery2.trimIndent() + ) + + @Language("SQL") + val insertData1 = """ + INSERT INTO table1 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, data + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """.trimIndent() + + @Language("SQL") + val insertData2 = """ + INSERT INTO table2 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, data + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """.trimIndent() + + connection.prepareStatement(insertData1).use { st -> + // Insert data into table1 + for (i in 1..3) { + st.setBoolean(1, true) + st.setByte(2, i.toByte()) + st.setShort(3, (i * 10).toShort()) + st.setInt(4, i * 100) + st.setInt(5, i * 100) + st.setInt(6, i * 100) + st.setInt(7, i * 100) + st.setInt(8, i * 100) + st.setInt(9, i * 100) + st.setFloat(10, i * 10.0f) + st.setDouble(11, i * 10.0) + st.setBigDecimal(12, BigDecimal(i * 10)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, "textValue$i") + st.setString(27, "mediumtextValue$i") + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.setString(30, "{\"key\": \"value\"}") + st.executeUpdate() + } + } + + connection.prepareStatement(insertData2).use { st -> + // Insert data into table2 + for (i in 1..3) { + st.setBoolean(1, false) + st.setByte(2, (i * 2).toByte()) + st.setShort(3, (i * 20).toShort()) + st.setInt(4, i * 200) + st.setInt(5, i * 200) + st.setInt(6, i * 200) + st.setInt(7, i * 200) + st.setInt(8, i * 200) + st.setInt(9, i * 200) + st.setFloat(10, i * 20.0f) + st.setDouble(11, i * 20.0) + st.setBigDecimal(12, BigDecimal(i * 20)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, null) + st.setString(27, null) + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.setString(30, "{\"key\": \"value\"}") + st.executeUpdate() + } + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `basic test for reading sql tables`() { + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + val result = df1.filter { it[Table1MySql::id] == 1 } + result[0][26] shouldBe "textValue1" + + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["textcol"]!!.type shouldBe typeOf() + + val df2 = DataFrame.readSqlTable(connection, "table2").cast() + val result2 = df2.filter { it[Table2MySql::id] == 1 } + result2[0][26] shouldBe null + + val schema2 = DataFrame.getSchemaForSqlTable(connection, "table2") + schema2.columns["id"]!!.type shouldBe typeOf() + schema2.columns["textcol"]!!.type shouldBe typeOf() + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + t1.id, + t1.enumCol, + FROM table1 t1 + JOIN table2 t2 ON t1.id = t2.id + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + val result = df.filter { it[Table3MySql::id] == 1 } + result[0][1] shouldBe "Value1" + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["enumcol"]!!.type shouldBe typeOf() + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection).values.toList() + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 3 + table1Df.filter { it[Table1MySql::integercol] > 100 }.rowsCount() shouldBe 2 + table1Df[0][11] shouldBe 10.0 + table1Df[0][26] shouldBe "textValue1" + + val table2Df = dataframes[1].cast() + + table2Df.rowsCount() shouldBe 3 + table2Df.filter { it[Table2MySql::integercol] != null && it[Table2MySql::integercol]!! > 400 } + .rowsCount() shouldBe 1 + table2Df[0][11] shouldBe 20.0 + table2Df[0][26] shouldBe null + } + + @Test + fun `reading numeric types`() { + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + + val result = df1.select("tinyintcol").add("tinyintcol2") { it[Table1MySql::tinyintcol] } + + result[0][1] shouldBe 1.toByte() + + val result1 = df1.select("smallintcol") + .add("smallintcol2") { it[Table1MySql::smallintcol] } + + result1[0][1] shouldBe 10.toShort() + + val result2 = df1.select("mediumintcol") + .add("mediumintcol2") { it[Table1MySql::mediumintcol] } + + result2[0][1] shouldBe 100 + + val result3 = df1.select("mediumintunsignedcol") + .add("mediumintunsignedcol2") { it[Table1MySql::mediumintunsignedcol] } + + result3[0][1] shouldBe 100 + + val result5 = df1.select("bigintcol") + .add("bigintcol2") { it[Table1MySql::bigintcol] } + + result5[0][1] shouldBe 100 + + val result7 = df1.select("doublecol") + .add("doublecol2") { it[Table1MySql::doublecol] } + + result7[0][1] shouldBe 10.0 + + val result8 = df1.select("decimalcol") + .add("decimalcol2") { it[Table1MySql::decimalcol] } + + result8[0][1] shouldBe BigDecimal("10") + + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + + schema.columns["tinyintcol"]!!.type shouldBe typeOf() + schema.columns["smallintcol"]!!.type shouldBe typeOf() + schema.columns["mediumintcol"]!!.type shouldBe typeOf() + schema.columns["mediumintunsignedcol"]!!.type shouldBe typeOf() + schema.columns["bigintcol"]!!.type shouldBe typeOf() + schema.columns["floatcol"]!!.type shouldBe typeOf() + schema.columns["doublecol"]!!.type shouldBe typeOf() + schema.columns["decimalcol"]!!.type shouldBe typeOf() + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/postgresH2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/postgresH2Test.kt new file mode 100644 index 0000000000..0dae880ae9 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/postgresH2Test.kt @@ -0,0 +1,289 @@ +package org.jetbrains.kotlinx.dataframe.io.h2 + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.add +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Test +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.SQLException +import java.util.UUID +import kotlin.reflect.typeOf + +private const val URL = "jdbc:h2:mem:test3;DB_CLOSE_DELAY=-1;MODE=PostgreSQL;DATABASE_TO_LOWER=TRUE;DEFAULT_NULL_ORDERING=HIGH" + +@DataSchema +interface Table1 { + val id: Int + val bigintcol: Long + val smallintcol: Int + val bigserialcol: Long + val booleancol: Boolean + val byteacol: ByteArray + val charactercol: String + val characterncol: String + val charcol: String + val datecol: java.sql.Date + val doublecol: Double + val integercol: Int? + val jsoncol: String + val jsonbcol: String +} + +@DataSchema +interface Table2 { + val id: Int + val moneycol: String + val numericcol: BigDecimal + val realcol: Float + val smallintcol: Int + val serialcol: Int + val textcol: String? + val timecol: String + val timewithzonecol: String + val timestampcol: String + val timestampwithzonecol: String + val uuidcol: String +} + +@DataSchema +interface ViewTable { + val id: Int + val bigintcol: Long + val textCol: String? +} + +class PostgresH2Test { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL) + + @Language("SQL") + val createTableStatement = """ + CREATE TABLE IF NOT EXISTS table1 ( + id serial PRIMARY KEY, + bigintCol bigint not null, + smallintCol smallint not null, + bigserialCol bigserial not null, + booleanCol boolean not null, + byteaCol bytea not null, + characterCol character not null, + characterNCol character(10) not null, + charCol char not null, + dateCol date not null, + doubleCol double precision not null, + integerCol integer + ) + """ + connection.createStatement().execute( + createTableStatement.trimIndent() + ) + + @Language("SQL") + val createTableQuery = """ + CREATE TABLE IF NOT EXISTS table2 ( + id serial PRIMARY KEY, + moneyCol money not null, + numericCol numeric not null, + realCol real not null, + smallintCol smallint not null, + serialCol serial not null, + textCol text, + timeCol time not null, + timeWithZoneCol time with time zone not null, + timestampCol timestamp not null, + timestampWithZoneCol timestamp with time zone not null, + uuidCol uuid not null + ) + """ + connection.createStatement().execute( + createTableQuery.trimIndent() + ) + + @Language("SQL") + val insertData1 = """ + INSERT INTO table1 ( + bigintCol, smallintCol, bigserialCol, booleanCol, + byteaCol, characterCol, characterNCol, charCol, + dateCol, doubleCol, + integerCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """ + + @Language("SQL") + val insertData2 = """ + INSERT INTO table2 ( + moneyCol, numericCol, + realCol, smallintCol, + serialCol, textCol, timeCol, + timeWithZoneCol, timestampCol, timestampWithZoneCol, + uuidCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """ + + connection.prepareStatement(insertData1).use { st -> + // Insert data into table1 + for (i in 1..3) { + st.setLong(1, i * 1000L) + st.setShort(2, 11.toShort()) + st.setLong(3, 1000000000L + i) + st.setBoolean(4, i % 2 == 1) + st.setBytes(5, byteArrayOf(1, 2, 3)) + st.setString(6, "A") + st.setString(7, "Hello") + st.setString(8, "A") + st.setDate(9, java.sql.Date.valueOf("2023-08-01")) + st.setDouble(10, 12.34) + st.setInt(11, 12345 * i) + st.executeUpdate() + } + } + + connection.prepareStatement(insertData2).use { st -> + // Insert data into table2 + for (i in 1..3) { + st.setBigDecimal(1, BigDecimal("123.45")) + st.setBigDecimal(2, BigDecimal("12.34")) + st.setFloat(3, 12.34f) + st.setInt(4, 1000 + i) + st.setInt(5, 1000000 + i) + st.setString(6, null) + st.setTime(7, java.sql.Time.valueOf("12:34:56")) + st.setTimestamp(8, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(9, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(10, java.sql.Timestamp(System.currentTimeMillis())) + st.setObject(11, UUID.randomUUID(), java.sql.Types.OTHER) + st.executeUpdate() + } + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `read from tables`() { + val tableName1 = "table1" + val df1 = DataFrame.readSqlTable(connection, tableName1).cast() + val result = df1.filter { it[Table1::id] == 1 } + + result[0][0] shouldBe 1 + result[0][8] shouldBe "A" + + val schema = DataFrame.getSchemaForSqlTable(connection, tableName1) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["integercol"]!!.type shouldBe typeOf() + schema.columns["smallintcol"]!!.type shouldBe typeOf() + + val tableName2 = "table2" + val df2 = DataFrame.readSqlTable(connection, tableName2).cast() + val result2 = df2.filter { it[Table2::id] == 1 } + result2[0][4] shouldBe 1001 + + val schema2 = DataFrame.getSchemaForSqlTable(connection, tableName2) + schema2.columns["id"]!!.type shouldBe typeOf() + schema2.columns["textcol"]!!.type shouldBe typeOf() + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + t1.id, + t1.bigintCol, + t2.textCol + FROM table1 t1 + JOIN table2 t2 ON t1.id = t2.id + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + val result = df.filter { it[ViewTable::id] == 1 } + result[0][2] shouldBe null + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["bigintcol"]!!.type shouldBe typeOf() + schema.columns["textcol"]!!.type shouldBe typeOf() + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection).values.toList() + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 3 + table1Df.filter { it[Table1::integercol] != null && it[Table1::integercol]!! > 12345 }.rowsCount() shouldBe 2 + table1Df[0][1] shouldBe 1000L + table1Df[0][2] shouldBe 11 + + val table2Df = dataframes[1].cast() + + table2Df.rowsCount() shouldBe 3 + table2Df.filter { it[Table2::realcol] == 12.34f } + .rowsCount() shouldBe 3 + table2Df[0][4] shouldBe 1001 + } + + @Test + fun `read columns of different types to check type mapping`() { + val tableName1 = "table1" + val df1 = DataFrame.readSqlTable(connection, tableName1).cast() + val result = df1.select("smallintcol").add("smallintcol2") {it[Table1::smallintcol]} + result[0][1] shouldBe 11 + + val result1 = df1.select("bigserialcol").add("bigserialcol2") {it[Table1::bigserialcol]} + result1[0][1] shouldBe 1000000001L + + val result2 = df1.select("doublecol").add("doublecol2") {it[Table1::doublecol]} + result2[0][1] shouldBe 12.34 + + val tableName2 = "table2" + val df2 = DataFrame.readSqlTable(connection, tableName2).cast() + + val result4 = df2.select("numericcol").add("numericcol2") {it[Table2::numericcol]} + result4[0][1] shouldBe BigDecimal("12.34") + + val result5 = df2.select("realcol").add("realcol2") {it[Table2::realcol]} + result5[0][1] shouldBe 12.34f + + val result8 = df2.select("serialcol").add("serialcol2") {it[Table2::serialcol]} + result8[0][1] shouldBe 1000001 + + val schema = DataFrame.getSchemaForSqlTable(connection, tableName1) + schema.columns["smallintcol"]!!.type shouldBe typeOf() + schema.columns["bigserialcol"]!!.type shouldBe typeOf() + schema.columns["doublecol"]!!.type shouldBe typeOf() + + val schema1 = DataFrame.getSchemaForSqlTable(connection, tableName2) + schema1.columns["numericcol"]!!.type shouldBe typeOf() + schema1.columns["realcol"]!!.type shouldBe typeOf() + schema1.columns["serialcol"]!!.type shouldBe typeOf() + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/imdbTest.kt similarity index 93% rename from dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt rename to dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/imdbTest.kt index acff13e08e..517e2089ac 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/imdbTest.kt @@ -1,10 +1,14 @@ -package org.jetbrains.kotlinx.dataframe.io +package org.jetbrains.kotlinx.dataframe.io.local import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable import org.junit.Ignore import org.junit.Test import java.sql.DriverManager diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mariadbTest.kt similarity index 98% rename from dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt rename to dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mariadbTest.kt index a6720ffa80..1c0f957211 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mariadbTest.kt @@ -1,4 +1,4 @@ -package org.jetbrains.kotlinx.dataframe.io +package org.jetbrains.kotlinx.dataframe.io.local import io.kotest.matchers.shouldBe import org.intellij.lang.annotations.Language @@ -8,6 +8,11 @@ import org.jetbrains.kotlinx.dataframe.api.add import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Ignore diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mssqlTest.kt similarity index 95% rename from dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt rename to dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mssqlTest.kt index fc0fb8c0de..859fa6f745 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mssqlTest.kt @@ -1,12 +1,20 @@ -package org.jetbrains.kotlinx.dataframe.io +package org.jetbrains.kotlinx.dataframe.io.local import io.kotest.matchers.shouldBe import org.intellij.lang.annotations.Language import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema -import org.jetbrains.kotlinx.dataframe.api.* -import org.jetbrains.kotlinx.dataframe.io.JdbcTest.Companion -import org.jetbrains.kotlinx.dataframe.io.db.H2 +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.api.schema +import org.jetbrains.kotlinx.dataframe.io.db.MsSql +import org.jetbrains.kotlinx.dataframe.io.getSchemaForResultSet +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readResultSet +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Ignore @@ -16,7 +24,8 @@ import java.sql.Connection import java.sql.DriverManager import java.sql.ResultSet import java.sql.SQLException -import java.util.* +import java.util.Date +import java.util.UUID import kotlin.reflect.typeOf private const val URL = "jdbc:sqlserver://localhost:1433;encrypt=true;trustServerCertificate=true" @@ -369,7 +378,7 @@ class MSSQLTest { st.executeQuery(selectStatement).use { rs -> // ith default inferNullability: Boolean = true - val df4 = DataFrame.readResultSet(rs, H2) + val df4 = DataFrame.readResultSet(rs, MsSql) df4.schema().columns["id"]!!.type shouldBe typeOf() df4.schema().columns["name"]!!.type shouldBe typeOf() df4.schema().columns["surname"]!!.type shouldBe typeOf() @@ -377,7 +386,7 @@ class MSSQLTest { rs.beforeFirst() - val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2) + val dataSchema3 = DataFrame.getSchemaForResultSet(rs, MsSql) dataSchema3.columns.size shouldBe 4 dataSchema3.columns["id"]!!.type shouldBe typeOf() dataSchema3.columns["name"]!!.type shouldBe typeOf() @@ -387,7 +396,7 @@ class MSSQLTest { // with inferNullability: Boolean = false rs.beforeFirst() - val df5 = DataFrame.readResultSet(rs, H2, inferNullability = false) + val df5 = DataFrame.readResultSet(rs, MsSql, inferNullability = false) df5.schema().columns["id"]!!.type shouldBe typeOf() df5.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls df5.schema().columns["surname"]!!.type shouldBe typeOf() diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mysqlTest.kt similarity index 98% rename from dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt rename to dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mysqlTest.kt index 892ee080bc..43b86fcfdd 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/mysqlTest.kt @@ -1,4 +1,4 @@ -package org.jetbrains.kotlinx.dataframe.io +package org.jetbrains.kotlinx.dataframe.io.local import io.kotest.matchers.shouldBe import org.intellij.lang.annotations.Language @@ -8,6 +8,11 @@ import org.jetbrains.kotlinx.dataframe.api.add import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Ignore diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/postgresTest.kt similarity index 97% rename from dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt rename to dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/postgresTest.kt index 40bc898a3c..8a7ab899a9 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/postgresTest.kt @@ -1,4 +1,4 @@ -package org.jetbrains.kotlinx.dataframe.io +package org.jetbrains.kotlinx.dataframe.io.local import io.kotest.matchers.shouldBe import org.intellij.lang.annotations.Language @@ -8,6 +8,11 @@ import org.jetbrains.kotlinx.dataframe.api.add import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Ignore diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt index d5c43e1cb4..1bcd7dfe83 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt @@ -8,14 +8,13 @@ import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter import org.junit.AfterClass import org.junit.BeforeClass -import org.junit.Ignore import org.junit.Test import java.sql.Connection import java.sql.DriverManager import java.sql.SQLException import kotlin.reflect.typeOf -private const val DATABASE_URL = "jdbc:sqlite:" +private const val DATABASE_URL = "jdbc:sqlite::memory:" @DataSchema interface CustomerSQLite { @@ -48,7 +47,6 @@ interface CustomerOrderSQLite { val orderDetails: ByteArray? } -@Ignore class SqliteTest { companion object { private lateinit var connection: Connection diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index e459f1cf94..ab1b17eac9 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -35,6 +35,7 @@ mssql = "12.6.1.jre11" mysql = "8.3.0" postgresql = "42.7.2" sqlite = "3.45.1.0" +jtsCore = "1.18.1" kotlinDatetime = "0.5.0" openapi = "2.1.20" kotlinLogging = "6.0.3" @@ -80,6 +81,7 @@ mssql = { group = "com.microsoft.sqlserver", name = "mssql-jdbc", version.ref = mysql = { group = "com.mysql", name = "mysql-connector-j", version.ref = "mysql" } postgresql = { group = "org.postgresql", name = "postgresql", version.ref = "postgresql" } sqlite = { group = "org.xerial", name = "sqlite-jdbc", version.ref = "sqlite" } +jts = { group = "org.locationtech.jts", name = "jts-core", version.ref = "jtsCore" } poi-ooxml = { group = "org.apache.poi", name = "poi-ooxml", version.ref = "poi" } kotlin-datetimeJvm = { group = "org.jetbrains.kotlinx", name = "kotlinx-datetime-jvm", version.ref = "kotlinDatetime" } diff --git a/plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt b/plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt index dd1aec2d1a..0cae08eef1 100644 --- a/plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt +++ b/plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt @@ -170,6 +170,7 @@ class DataSchemaGenerator( val url = importStatement.dataSource.pathRepresentation // Force classloading + // TODO: probably will not work for the H2 Class.forName(driverClassNameFromUrl(url)) var userName = importStatement.jdbcOptions.user