Skip to content

Added inferNullability test for other databases #954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package org.jetbrains.kotlinx.dataframe.io

import io.kotest.matchers.shouldBe
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.io.db.MsSql
import java.sql.Connection
import java.sql.ResultSet
import kotlin.reflect.typeOf

/**
 * Shared test scenario checking how the JDBC readers report column nullability.
 *
 * Creates `TestTable1`, where the nullable `name` column holds no NULL values and the nullable
 * `surname` column holds one, then asserts the column types produced by `readSqlTable`,
 * `readSqlQuery` and `readResultSet` (with the default `inferNullability` and with
 * `inferNullability = false`) and by the matching `getSchemaFor...` functions.
 *
 * Every statement opened here is closed, and the table is dropped in a `finally` block so that
 * a failed assertion does not leave `TestTable1` behind on the connection.
 *
 * @param connection open JDBC connection to the database under test; it is not closed here.
 */
internal fun inferNullability(connection: Connection) {
    // prepare tables and data
    @Language("SQL")
    val createTestTable1Query = """
        CREATE TABLE TestTable1 (
        id INT PRIMARY KEY,
        name VARCHAR(50),
        surname VARCHAR(50),
        age INT NOT NULL
        )
    """

    // Created outside the try block: if creation itself fails there is nothing to drop.
    connection.createStatement().use { st -> st.execute(createTestTable1Query) }

    try {
        connection.createStatement().use { st ->
            st.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
            st.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
            st.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
            st.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")
        }

        // start testing `readSqlTable` method

        // with default inferNullability: Boolean = true
        val tableName = "TestTable1"
        val df = DataFrame.readSqlTable(connection, tableName)
        df.schema().columns["id"]!!.type shouldBe typeOf<Int>()
        // `name` holds no NULL values, so it is expected to be non-nullable here
        df.schema().columns["name"]!!.type shouldBe typeOf<String>()
        df.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
        df.schema().columns["age"]!!.type shouldBe typeOf<Int>()

        val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
        dataSchema.columns.size shouldBe 4
        dataSchema.columns["id"]!!.type shouldBe typeOf<Int>()
        dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
        dataSchema.columns["surname"]!!.type shouldBe typeOf<String?>()
        dataSchema.columns["age"]!!.type shouldBe typeOf<Int>()

        // with inferNullability: Boolean = false
        val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false)
        df1.schema().columns["id"]!!.type shouldBe typeOf<Int>()

        // `name` is expected to stay nullable here even though it holds no NULL values
        df1.schema().columns["name"]!!.type shouldBe typeOf<String?>()
        df1.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
        df1.schema().columns["age"]!!.type shouldBe typeOf<Int>()

        // end testing `readSqlTable` method

        // start testing `readSqlQuery` method

        // with default inferNullability: Boolean = true
        @Language("SQL")
        val sqlQuery =
            """
            SELECT name, surname, age FROM TestTable1
            """.trimIndent()

        val df2 = DataFrame.readSqlQuery(connection, sqlQuery)
        df2.schema().columns["name"]!!.type shouldBe typeOf<String>()
        df2.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
        df2.schema().columns["age"]!!.type shouldBe typeOf<Int>()

        val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery)
        dataSchema2.columns.size shouldBe 3
        dataSchema2.columns["name"]!!.type shouldBe typeOf<String?>()
        dataSchema2.columns["surname"]!!.type shouldBe typeOf<String?>()
        dataSchema2.columns["age"]!!.type shouldBe typeOf<Int>()

        // with inferNullability: Boolean = false
        val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false)
        // `name` is expected to stay nullable here even though it holds no NULL values
        df3.schema().columns["name"]!!.type shouldBe typeOf<String?>()
        df3.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
        df3.schema().columns["age"]!!.type shouldBe typeOf<Int>()

        // end testing `readSqlQuery` method

        // start testing `readResultSet` method

        // A scrollable result set is requested because `rs.beforeFirst()` is called between reads.
        connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st ->
            @Language("SQL")
            val selectStatement = "SELECT * FROM TestTable1"

            st.executeQuery(selectStatement).use { rs ->
                // with default inferNullability: Boolean = true
                val df4 = DataFrame.readResultSet(rs, MsSql)
                df4.schema().columns["id"]!!.type shouldBe typeOf<Int>()
                df4.schema().columns["name"]!!.type shouldBe typeOf<String>()
                df4.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
                df4.schema().columns["age"]!!.type shouldBe typeOf<Int>()

                rs.beforeFirst()

                val dataSchema3 = DataFrame.getSchemaForResultSet(rs, MsSql)
                dataSchema3.columns.size shouldBe 4
                dataSchema3.columns["id"]!!.type shouldBe typeOf<Int>()
                dataSchema3.columns["name"]!!.type shouldBe typeOf<String?>()
                dataSchema3.columns["surname"]!!.type shouldBe typeOf<String?>()
                dataSchema3.columns["age"]!!.type shouldBe typeOf<Int>()

                // with inferNullability: Boolean = false
                rs.beforeFirst()

                val df5 = DataFrame.readResultSet(rs, MsSql, inferNullability = false)
                df5.schema().columns["id"]!!.type shouldBe typeOf<Int>()

                // `name` is expected to stay nullable here even though it holds no NULL values
                df5.schema().columns["name"]!!.type shouldBe typeOf<String?>()
                df5.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
                df5.schema().columns["age"]!!.type shouldBe typeOf<Int>()
            }
        }
        // end testing `readResultSet` method
    } finally {
        // Always clean up, so other tests sharing this connection can create TestTable1 again.
        connection.createStatement().use { st -> st.execute("DROP TABLE TestTable1") }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.api.select
import org.jetbrains.kotlinx.dataframe.io.DbConnectionConfig
import org.jetbrains.kotlinx.dataframe.io.db.H2
Expand All @@ -20,6 +19,7 @@ import org.jetbrains.kotlinx.dataframe.io.getSchemaForAllSqlTables
import org.jetbrains.kotlinx.dataframe.io.getSchemaForResultSet
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable
import org.jetbrains.kotlinx.dataframe.io.inferNullability
import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables
import org.jetbrains.kotlinx.dataframe.io.readDataFrame
import org.jetbrains.kotlinx.dataframe.io.readResultSet
Expand Down Expand Up @@ -841,128 +841,9 @@ class JdbcTest {
saleDataSchema1.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
}

// TODO: add the same test for each particular database and refactor the scenario to the common test case
// https://github.com/Kotlin/dataframe/issues/688
/**
 * Checks the column nullability reported by `readSqlTable`, `readSqlQuery`, `readResultSet` and the
 * matching `getSchemaFor...` functions for `TestTable1`, where the nullable `name` column holds no
 * NULL values and the nullable `surname` column holds one.
 *
 * NOTE(review): this text comes from a rendered diff without +/- markers. The inline scenario looks
 * like the removed side and the trailing `inferNullability(connection)` call like the added side;
 * confirm against the real file before relying on this method body.
 */
@Test
fun `infer nullability`() {
// prepare tables and data
@Language("SQL")
val createTestTable1Query = """
CREATE TABLE TestTable1 (
id INT PRIMARY KEY,
name VARCHAR(50),
surname VARCHAR(50),
age INT NOT NULL
)
"""

connection.createStatement().execute(createTestTable1Query)

connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")

// start testing `readSqlTable` method

// with default inferNullability: Boolean = true
val tableName = "TestTable1"
val df = DataFrame.readSqlTable(connection, tableName)
df.schema().columns["id"]!!.type shouldBe typeOf<Int>()
// `name` holds no NULL values, so it is expected to be non-nullable here
df.schema().columns["name"]!!.type shouldBe typeOf<String>()
df.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df.schema().columns["age"]!!.type shouldBe typeOf<Int>()

val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
dataSchema.columns.size shouldBe 4
dataSchema.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false)
df1.schema().columns["id"]!!.type shouldBe typeOf<Int>()

// `name` is expected to stay nullable here even though it holds no NULL values
df1.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df1.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df1.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// end testing `readSqlTable` method

// start testing `readSqlQuery` method

// with default inferNullability: Boolean = true
@Language("SQL")
val sqlQuery =
"""
SELECT name, surname, age FROM TestTable1
""".trimIndent()

val df2 = DataFrame.readSqlQuery(connection, sqlQuery)
df2.schema().columns["name"]!!.type shouldBe typeOf<String>()
df2.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df2.schema().columns["age"]!!.type shouldBe typeOf<Int>()

val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery)
dataSchema2.columns.size shouldBe 3
dataSchema2.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema2.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema2.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false)

// `name` is expected to stay nullable here even though it holds no NULL values
df3.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df3.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df3.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// end testing `readSqlQuery` method

// start testing `readResultSet` method

// a scrollable result set is requested because `rs.beforeFirst()` is called between reads
connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st ->
@Language("SQL")
val selectStatement = "SELECT * FROM TestTable1"

st.executeQuery(selectStatement).use { rs ->
// with default inferNullability: Boolean = true
val df4 = DataFrame.readResultSet(rs, H2(MySql))
df4.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df4.schema().columns["name"]!!.type shouldBe typeOf<String>()
df4.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df4.schema().columns["age"]!!.type shouldBe typeOf<Int>()

rs.beforeFirst()

val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2(MySql))
dataSchema3.columns.size shouldBe 4
dataSchema3.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema3.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema3.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema3.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
rs.beforeFirst()

val df5 = DataFrame.readResultSet(rs, H2(MySql), inferNullability = false)
df5.schema().columns["id"]!!.type shouldBe typeOf<Int>()

// `name` is expected to stay nullable here even though it holds no NULL values
df5.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df5.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df5.schema().columns["age"]!!.type shouldBe typeOf<Int>()
}
}
// end testing `readResultSet` method

connection.createStatement().execute("DROP TABLE TestTable1")
// NOTE(review): presumably the added side of the diff, replacing the inline scenario above — confirm
inferNullability(connection)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.api.select
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable
import org.jetbrains.kotlinx.dataframe.io.inferNullability
import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables
import org.jetbrains.kotlinx.dataframe.io.readSqlQuery
import org.jetbrains.kotlinx.dataframe.io.readSqlTable
Expand Down Expand Up @@ -417,4 +418,9 @@ class MariadbH2Test {
schema.columns["doublecol"]!!.type shouldBe typeOf<Double>()
schema.columns["decimalcol"]!!.type shouldBe typeOf<BigDecimal>()
}

/** Runs the shared `inferNullability` scenario against this test class's `connection`. */
@Test
fun `infer nullability`() = inferNullability(connection)
}
Loading