diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index ab5946164d..b9778395c8 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -290,33 +290,35 @@ public fun DataFrame.Companion.readResultSet( } /** - * Reads all tables from the given database using the provided database configuration and limit. + * Reads all non-system tables from a database and returns them + * as a map of SQL tables and corresponding dataframes using the provided database configuration and limit. * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [limit] the maximum number of rows to read from each table. * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs. * @param [inferNullability] indicates how the column nullability should be inferred. - * @return a list of [AnyFrame] objects representing the non-system tables from the database. + * @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database. */ public fun DataFrame.Companion.readAllSqlTables( dbConfig: DatabaseConfiguration, catalogue: String? = null, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, -): List { +): Map { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> return readAllSqlTables(connection, catalogue, limit, inferNullability) } } /** - * Reads all non-system tables from a database and returns them as a list of data frames. + * Reads all non-system tables from a database and returns them + * as a map of SQL tables and corresponding dataframes. * * @param [connection] the database connection to read tables from. * @param [limit] the maximum number of rows to read from each table. * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs. * @param [inferNullability] indicates how the column nullability should be inferred. - * @return a list of [AnyFrame] objects representing the non-system tables from the database. + * @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database. * * @see DriverManager.getConnection */ @@ -325,20 +327,20 @@ public fun DataFrame.Companion.readAllSqlTables( catalogue: String? = null, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, -): List { +): Map { val metaData = connection.metaData val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - // exclude a system and other tables without data, but it looks like it supported badly for many databases + // exclude a system and other tables without data, but it looks like it is supported badly for many databases val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE")) - val dataFrames = mutableListOf() + val dataFrames = mutableMapOf() while (tables.next()) { val table = dbType.buildTableMetadata(tables) if (!dbType.isSystemTable(table)) { - // we filter her second time because of specific logic with SQLite and possible issues with future databases + // we filter here a second time because of specific logic with SQLite and possible issues with future databases val tableName = when { catalogue != null && table.schemaName != null -> "$catalogue.${table.schemaName}.${table.name}" catalogue != null && table.schemaName == null -> "$catalogue.${table.name}" @@ -351,7 +353,7 @@ public fun DataFrame.Companion.readAllSqlTables( logger.debug { "Reading table: $tableName" } val dataFrame = readSqlTable(connection, tableName, limit, inferNullability) - dataFrames += dataFrame + dataFrames += tableName to dataFrame logger.debug { "Finished reading table: $tableName" } } } @@ -474,24 +476,24 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, conne } /** - * Retrieves the schema of all non-system tables in the database using the provided database configuration. + * Retrieves the schemas of all non-system tables in the database using the provided database configuration. * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. - * @return a list of [DataFrameSchema] objects representing the schema of each non-system table. + * @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table. */ -public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List { +public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> return getSchemaForAllSqlTables(connection) } } /** - * Retrieves the schema of all non-system tables in the database using the provided database connection. + * Retrieves the schemas of all non-system tables in the database using the provided database connection. * * @param [connection] the database connection. - * @return a list of [DataFrameSchema] objects representing the schema of each non-system table. + * @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table. */ -public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): List { +public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map { val metaData = connection.metaData val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) @@ -500,14 +502,15 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): // exclude a system and other tables without data val tables = metaData.getTables(null, null, null, tableTypes) - val dataFrameSchemas = mutableListOf() + val dataFrameSchemas = mutableMapOf() while (tables.next()) { val jdbcTable = dbType.buildTableMetadata(tables) if (!dbType.isSystemTable(jdbcTable)) { - // we filter her second time because of specific logic with SQLite and possible issues with future databases - val dataFrameSchema = getSchemaForSqlTable(connection, jdbcTable.name) - dataFrameSchemas += dataFrameSchema + // we filter her a second time because of specific logic with SQLite and possible issues with future databases + val tableName = jdbcTable.name + val dataFrameSchema = getSchemaForSqlTable(connection, tableName) + dataFrameSchemas += tableName to dataFrameSchema } } diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt index 4b57818184..4b74017894 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt @@ -597,7 +597,11 @@ class JdbcTest { @Test fun `read from all tables`() { - val dataframes = DataFrame.readAllSqlTables(connection) + val dataFrameMap = DataFrame.readAllSqlTables(connection) + dataFrameMap.containsKey("Customer") shouldBe true + dataFrameMap.containsKey("Sale") shouldBe true + + val dataframes = dataFrameMap.values.toList() val customerDf = dataframes[0].cast() @@ -611,7 +615,7 @@ class JdbcTest { saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 - val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1) + val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1).values.toList() val customerDf1 = dataframes1[0].cast() @@ -625,7 +629,11 @@ class JdbcTest { saleDf1.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 - val dataSchemas = DataFrame.getSchemaForAllSqlTables(connection) + val dataFrameSchemaMap = DataFrame.getSchemaForAllSqlTables(connection) + dataFrameSchemaMap.containsKey("Customer") shouldBe true + dataFrameSchemaMap.containsKey("Sale") shouldBe true + + val dataSchemas = dataFrameSchemaMap.values.toList() val customerDataSchema = dataSchemas[0] customerDataSchema.columns.size shouldBe 3 @@ -637,7 +645,7 @@ class JdbcTest { saleDataSchema.columns["amount"]!!.type shouldBe typeOf() val dbConfig = DatabaseConfiguration(url = URL) - val dataframes2 = DataFrame.readAllSqlTables(dbConfig) + val dataframes2 = DataFrame.readAllSqlTables(dbConfig).values.toList() val customerDf2 = dataframes2[0].cast() @@ -651,7 +659,7 @@ class JdbcTest { saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 - val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1) + val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1).values.toList() val customerDf3 = dataframes3[0].cast() @@ -665,7 +673,7 @@ class JdbcTest { saleDf3.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 - val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig) + val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig).values.toList() val customerDataSchema1 = dataSchemas1[0] customerDataSchema1.columns.size shouldBe 3 diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt index 16ae816d15..a6720ffa80 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt @@ -370,7 +370,7 @@ class MariadbTest { @Test fun `read from all tables`() { - val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000) + val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000).values.toList() val table1Df = dataframes[0].cast() diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt index 7132b29290..fc0fb8c0de 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -277,7 +277,7 @@ class MSSQLTest { @Test fun `read from all tables`() { - val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4) + val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4).values.toList() val table1Df = dataframes[0].cast() diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt index 651e2292a8..892ee080bc 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt @@ -370,7 +370,7 @@ class MySqlTest { @Test fun `read from all tables`() { - val dataframes = DataFrame.readAllSqlTables(connection) + val dataframes = DataFrame.readAllSqlTables(connection).values.toList() val table1Df = dataframes[0].cast() diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt index a4b2ee8a3c..40bc898a3c 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt @@ -298,7 +298,7 @@ class PostgresTest { @Test fun `read from all tables`() { - val dataframes = DataFrame.readAllSqlTables(connection) + val dataframes = DataFrame.readAllSqlTables(connection).values.toList() val table1Df = dataframes[0].cast() diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt index ebe550d750..d5c43e1cb4 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt @@ -193,7 +193,7 @@ class SqliteTest { @Test fun `read from all tables`() { - val dataframes = DataFrame.readAllSqlTables(connection) + val dataframes = DataFrame.readAllSqlTables(connection).values.toList() val customerDf = dataframes[0].cast() diff --git a/docs/StardustDocs/topics/readSqlDatabases.md b/docs/StardustDocs/topics/readSqlDatabases.md index dcb890ba15..8dd330d46e 100644 --- a/docs/StardustDocs/topics/readSqlDatabases.md +++ b/docs/StardustDocs/topics/readSqlDatabases.md @@ -59,7 +59,7 @@ In the second, be sure that you can establish a connection to the database. For this, usually, you need to have three things: a URL to a database, a username and a password. -Call one of the following functions to obtain data from a database and transform it to the dataframe. +Call one of the following functions to collect data from a database and transform it to the dataframe. For example, if you have a local PostgreSQL database named as `testDatabase` with table `Customer`, you could read first 100 rows and print the data just copying the code below: @@ -105,7 +105,7 @@ Next, import `Kotlin DataFrame` library in the cell below. **NOTE:** The order of cell execution is important, the dataframe library is waiting for a JDBC driver to force classloading. -Find full example Notebook [here](https://github.com/zaleslaw/KotlinDataFrame-SQL-Examples/blob/master/notebooks/imdb.ipynb). +Find a full example Notebook [here](https://github.com/zaleslaw/KotlinDataFrame-SQL-Examples/blob/master/notebooks/imdb.ipynb). ## Reading Specific Tables @@ -315,9 +315,9 @@ connection.close() These functions read all data from all tables in the connected database. Variants with a limit parameter restrict how many rows will be read from each table. -**readAllSqlTables(connection: Connection): List\** +**readAllSqlTables(connection: Connection): Map\** -Retrieves data from all the non-system tables in the SQL database and returns them as a list of AnyFrame objects. +Retrieves data from all the non-system tables in the SQL database and returns them as a map of table names to AnyFrame objects. The `dbConfig: DatabaseConfiguration` parameter represents the configuration for a database connection, created under the hood and managed by the library. Typically, it requires a URL, username and password. @@ -330,7 +330,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS val dataframes = DataFrame.readAllSqlTables(dbConfig) ``` -**readAllSqlTables(connection: Connection, limit: Int): List\** +**readAllSqlTables(connection: Connection, limit: Int): Map\** A variant of the previous function, but with an added `limit: Int` parameter that allows setting the maximum number of records to be read from each table. @@ -493,10 +493,10 @@ connection.close() These functions return a list of all [`DataFrameSchema`](schema.md) from all the non-system tables in the SQL database. They can be called with either a database configuration or a connection. -**getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List\** +**getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map\** This function retrieves the schema of all tables from an SQL database -and returns them as a list of [`DataFrameSchema`](schema.md). +and returns them as a map of table names to [`DataFrameSchema`](schema.md) objects. The `dbConfig: DatabaseConfiguration` parameter represents the configuration for a database connection, created under the hood and managed by the library. Typically, it requires a URL, username and password. @@ -509,7 +509,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS val schemas = DataFrame.getSchemaForAllSqlTables(dbConfig) ``` -**getSchemaForAllSqlTables(connection: Connection): List\** +**getSchemaForAllSqlTables(connection: Connection): Map\** This function retrieves the schema of all tables using a JDBC connection: `Connection` object and returns them as a list of [`DataFrameSchema`](schema.md).