diff --git a/dataframe-jdbc/api/dataframe-jdbc.api b/dataframe-jdbc/api/dataframe-jdbc.api index 94cb132607..552c069316 100644 --- a/dataframe-jdbc/api/dataframe-jdbc.api +++ b/dataframe-jdbc/api/dataframe-jdbc.api @@ -54,26 +54,26 @@ public final class org/jetbrains/kotlinx/dataframe/io/ReadJdbcKt { public static final fun readAllSqlTables (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Ljava/util/Map; public static synthetic fun readAllSqlTables$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Ljava/util/Map; public static synthetic fun readAllSqlTables$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Ljava/util/Map; - public static final fun readDataFrame (Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun readDataFrame (Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;Z)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun readDataFrame (Ljava/sql/ResultSet;Ljava/sql/Connection;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun readDataFrame (Ljava/sql/ResultSet;Lorg/jetbrains/kotlinx/dataframe/io/db/DbType;IZ)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun readDataFrame (Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static synthetic fun readDataFrame$default (Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun readDataFrame (Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;Z)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static synthetic fun readDataFrame$default (Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun readDataFrame$default (Ljava/sql/ResultSet;Ljava/sql/Connection;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun readDataFrame$default (Ljava/sql/ResultSet;Lorg/jetbrains/kotlinx/dataframe/io/db/DbType;IZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static synthetic fun readDataFrame$default (Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static synthetic fun readDataFrame$default (Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun readResultSet (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/ResultSet;Ljava/sql/Connection;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun readResultSet (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/ResultSet;Lorg/jetbrains/kotlinx/dataframe/io/db/DbType;IZ)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun readResultSet$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/ResultSet;Ljava/sql/Connection;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun readResultSet$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/ResultSet;Lorg/jetbrains/kotlinx/dataframe/io/db/DbType;IZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun readSqlQuery (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun readSqlQuery (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static synthetic fun readSqlQuery$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static synthetic fun readSqlQuery$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun readSqlTable (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun readSqlTable (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static synthetic fun readSqlTable$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static synthetic fun readSqlTable$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun readSqlQuery (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;Z)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun readSqlQuery (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;Z)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static synthetic fun readSqlQuery$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static synthetic fun readSqlQuery$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun readSqlTable (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;Z)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun readSqlTable (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;Z)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static synthetic fun readSqlTable$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Ljava/sql/Connection;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static synthetic fun readSqlTable$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame$Companion;Lorg/jetbrains/kotlinx/dataframe/io/DbConnectionConfig;Ljava/lang/String;IZLorg/jetbrains/kotlinx/dataframe/io/db/DbType;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; } public final class org/jetbrains/kotlinx/dataframe/io/TableColumnMetadata { diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt index 369250c938..45487b5586 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt @@ -25,8 +25,8 @@ public fun extractDBTypeFromConnection(connection: Connection): DbType { // works only for H2 version 2 val modeQuery = "SELECT SETTING_VALUE FROM INFORMATION_SCHEMA.SETTINGS WHERE SETTING_NAME = 'MODE'" var mode = "" - connection.createStatement().use { st -> - st.executeQuery(modeQuery).use { rs -> + connection.prepareStatement(modeQuery).use { st -> + st.executeQuery().use { rs -> if (rs.next()) { mode = rs.getString("SETTING_VALUE") logger.debug { "Fetched H2 DB mode: $mode" } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index bb47209d10..7bd110ba73 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -51,19 +51,25 @@ private val logger = KotlinLogging.logger {} private const val DEFAULT_LIMIT = Int.MIN_VALUE /** - * Constant variable indicating the start of an SQL read query. - * The value of this variable is "SELECT". - */ -private const val START_OF_READ_SQL_QUERY = "SELECT" - -/** - * Constant representing the separator used to separate multiple SQL queries. + * A regular expression defining the valid pattern for SQL table names. + * + * This pattern enforces that table names must: + * - Contain only Unicode letters, Unicode digits, or underscores. + * - Optionally be segmented by dots to indicate schema and table separation. + * + * It ensures compatibility with most SQL database naming conventions, thus minimizing risks of invalid names + * or injection vulnerabilities. + * + * Example of valid table names: + * - `my_table` + * - `schema1.table2` * - * This separator is used when multiple SQL queries need to be executed together. - * Each query should be separated by this separator to indicate the end of one query - * and the start of the next query. + * Example of invalid table names: + * - `my-table` (contains a dash) + * - `table!name` (contains special characters) + * - `.startWithDot` (cannot start with a dot) */ -private const val MULTIPLE_SQL_QUERY_SEPARATOR = ";" +private const val TABLE_NAME_VALID_PATTERN = "^[\\p{L}\\p{N}_]+(\\.[\\p{L}\\p{N}_]+)*$" /** * Represents a column in a database table to keep all required meta-information. @@ -115,6 +121,8 @@ public data class DbConnectionConfig(val url: String, val user: String = "", val * @param [inferNullability] indicates how the column nullability should be inferred. * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, * in that case the [dbType] will be recognized from the [dbConfig]. + * @param [strictValidation] if `true`, the method validates that the provided table name is in a valid format. + * Default is `true` for strict validation. * @return the DataFrame containing the data from the SQL table. */ public fun DataFrame.Companion.readSqlTable( @@ -123,9 +131,10 @@ public fun DataFrame.Companion.readSqlTable( limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, dbType: DbType? = null, + strictValidation: Boolean = true, ): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlTable(connection, tableName, limit, inferNullability, dbType) + return readSqlTable(connection, tableName, limit, inferNullability, dbType, strictValidation) } } @@ -138,6 +147,8 @@ public fun DataFrame.Companion.readSqlTable( * @param [inferNullability] indicates how the column nullability should be inferred. * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, * in that case the [dbType] will be recognized from the [connection]. + * @param [strictValidation] if `true`, the method validates that the provided table name is in a valid format. + * Default is `true` for strict validation. * @return the DataFrame containing the data from the SQL table. * * @see DriverManager.getConnection @@ -148,7 +159,16 @@ public fun DataFrame.Companion.readSqlTable( limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, dbType: DbType? = null, + strictValidation: Boolean = true, ): AnyFrame { + if (strictValidation) { + require(isValidTableName(tableName)) { + "The provided table name '$tableName' is invalid. Please ensure it matches a valid table name in the database schema." + } + } else { + logger.warn { "Strict validation is disabled. Make sure the table name '$tableName' is correct." } + } + val url = connection.metaData.url val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) @@ -158,10 +178,10 @@ public fun DataFrame.Companion.readSqlTable( "SELECT * FROM $tableName" } - connection.createStatement().use { st -> + connection.prepareStatement(selectAllQuery).use { st -> logger.debug { "Connection with url:$url is established successfully." } - st.executeQuery(selectAllQuery).use { rs -> + st.executeQuery().use { rs -> val tableColumns = getTableColumnsMetadata(rs) return fetchAndConvertDataFromResultSet(tableColumns, rs, determinedDbType, limit, inferNullability) } @@ -180,17 +200,21 @@ public fun DataFrame.Companion.readSqlTable( * @param [inferNullability] indicates how the column nullability should be inferred. * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, * in that case the [dbType] will be recognized from the [dbConfig]. + * @param [strictValidation] if `true`, the method validates that the provided query is in a valid format. + * Default is `true` for strict validation. * @return the DataFrame containing the result of the SQL query. */ + public fun DataFrame.Companion.readSqlQuery( dbConfig: DbConnectionConfig, sqlQuery: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, dbType: DbType? = null, + strictValidation: Boolean = true, ): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlQuery(connection, sqlQuery, limit, inferNullability, dbType) + return readSqlQuery(connection, sqlQuery, limit, inferNullability, dbType, strictValidation) } } @@ -206,6 +230,8 @@ public fun DataFrame.Companion.readSqlQuery( * @param [inferNullability] indicates how the column nullability should be inferred. * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, * in that case the [dbType] will be recognized from the [connection]. + * @param [strictValidation] if `true`, the method validates that the provided query is in a valid format. + * Default is `true` for strict validation. * @return the DataFrame containing the result of the SQL query. * * @see DriverManager.getConnection @@ -216,10 +242,15 @@ public fun DataFrame.Companion.readSqlQuery( limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, dbType: DbType? = null, + strictValidation: Boolean = true, ): AnyFrame { - require(isValid(sqlQuery)) { - "SQL query should start from SELECT and contain one query for reading data without any manipulation. " + - "Also it should not contain any separators like `;`." + if (strictValidation) { + require(isValidSqlQuery(sqlQuery)) { + "SQL query should start from SELECT and contain one query for reading data without any manipulation. " + + "Also it should not contain any separators like `;`." + } + } else { + logger.warn { "Strict validation is disabled. Ensure the SQL query '$sqlQuery' is correct and safe." } } val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) @@ -228,8 +259,8 @@ public fun DataFrame.Companion.readSqlQuery( logger.debug { "Executing SQL query: $internalSqlQuery" } - connection.createStatement().use { st -> - st.executeQuery(internalSqlQuery).use { rs -> + connection.prepareStatement(internalSqlQuery).use { st -> + st.executeQuery().use { rs -> val tableColumns = getTableColumnsMetadata(rs) return fetchAndConvertDataFromResultSet(tableColumns, rs, determinedDbType, limit, inferNullability) } @@ -247,6 +278,8 @@ public fun DataFrame.Companion.readSqlQuery( * @param [inferNullability] indicates how the column nullability should be inferred. * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, * in that case the [dbType] will be recognized from the [DbConnectionConfig]. + * @param [strictValidation] if `true`, the method validates that the provided query or table name is in a valid format. + * Default is `true` for strict validation. * @return the DataFrame containing the result of the SQL query. */ public fun DbConnectionConfig.readDataFrame( @@ -254,6 +287,7 @@ public fun DbConnectionConfig.readDataFrame( limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, dbType: DbType? = null, + strictValidation: Boolean = true, ): AnyFrame = when { isSqlQuery(sqlQueryOrTableName) -> DataFrame.readSqlQuery( @@ -262,6 +296,7 @@ public fun DbConnectionConfig.readDataFrame( limit, inferNullability, dbType, + strictValidation, ) isSqlTableName(sqlQueryOrTableName) -> DataFrame.readSqlTable( @@ -270,6 +305,7 @@ public fun DbConnectionConfig.readDataFrame( limit, inferNullability, dbType, + strictValidation, ) else -> throw IllegalArgumentException( @@ -299,6 +335,8 @@ private fun isSqlTableName(sqlQueryOrTableName: String): Boolean { * @param [inferNullability] indicates how the column nullability should be inferred. * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, * in that case the [dbType] will be recognized from the [Connection]. + * @param [strictValidation] if `true`, the method validates that the provided query or table name is in a valid format. + * Default is `true` for strict validation. * @return the DataFrame containing the result of the SQL query. */ public fun Connection.readDataFrame( @@ -306,6 +344,7 @@ public fun Connection.readDataFrame( limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, dbType: DbType? = null, + strictValidation: Boolean = true, ): AnyFrame = when { isSqlQuery(sqlQueryOrTableName) -> DataFrame.readSqlQuery( @@ -314,6 +353,7 @@ public fun Connection.readDataFrame( limit, inferNullability, dbType, + strictValidation, ) isSqlTableName(sqlQueryOrTableName) -> DataFrame.readSqlTable( @@ -322,6 +362,7 @@ public fun Connection.readDataFrame( limit, inferNullability, dbType, + strictValidation, ) else -> throw IllegalArgumentException( @@ -329,12 +370,111 @@ public fun Connection.readDataFrame( ) } -/** SQL query is accepted only if it starts from SELECT */ -private fun isValid(sqlQuery: String): Boolean { +/** + * Checks if a given string contains forbidden patterns or keywords. + * Logs a clear and friendly message if any forbidden pattern is found. + */ +private fun containsForbiddenPatterns(input: String): Boolean { + // List of forbidden patterns or commands + val forbiddenPatterns = listOf( + ";", // Separator for SQL statements + "--", // Single-line comments + "/*", // Start of multi-line comments + "*/", // End of multi-line comments + "DROP", + "DELETE", + "INSERT", + "UPDATE", + "EXEC", + "EXECUTE", + "CREATE", + "ALTER", + "GRANT", + "REVOKE", + "MERGE", + ) + + for (pattern in forbiddenPatterns) { + if (input.contains(pattern)) { + logger.error { + "Validation failed: The input contains a forbidden element '$pattern'. " + + "Please review the input: '$input'." + } + return true + } + } + return false +} + +/** + * Validates if the SQL query is safe and starts with SELECT. + * Ensures proper syntax structure, checks for balanced quotes, and disallows dangerous commands or patterns. + */ +private fun isValidSqlQuery(sqlQuery: String): Boolean { val normalizedSqlQuery = sqlQuery.trim().uppercase() - return normalizedSqlQuery.startsWith(START_OF_READ_SQL_QUERY) && - !normalizedSqlQuery.contains(MULTIPLE_SQL_QUERY_SEPARATOR) + // Log the query being validated + logger.warn { "Validating SQL query: '$sqlQuery'" } + + // Ensure the query starts with "SELECT" + if (!normalizedSqlQuery.startsWith("SELECT")) { + logger.error { "Validation failed: The SQL query must start with 'SELECT'. Given query: '$sqlQuery'." } + return false + } + + // Validate against forbidden patterns + if (containsForbiddenPatterns(normalizedSqlQuery)) { + return false + } + + // Check if there are balanced quotes (single and double) + val singleQuotes = sqlQuery.count { it == '\'' } + val doubleQuotes = sqlQuery.count { it == '"' } + if (singleQuotes % 2 != 0) { + logger.error { + "Validation failed: Unbalanced single quotes in the SQL query. " + + "Please correct the query: '$sqlQuery'." + } + return false + } + if (doubleQuotes % 2 != 0) { + logger.error { + "Validation failed: Unbalanced double quotes in the SQL query. " + + "Please correct the query: '$sqlQuery'." + } + return false + } + + logger.warn { "SQL query validation succeeded for query: '$sqlQuery'." } + return true +} + +/** + * Validates if the given SQL table name is safe and logs any validation violations. + */ +private fun isValidTableName(tableName: String): Boolean { + val normalizedTableName = tableName.trim().uppercase() + + // Log the table name being validated + logger.warn { "Validating SQL table name: '$tableName'" } + + // Validate against forbidden patterns + if (containsForbiddenPatterns(normalizedTableName)) { + return false + } + + // Validate the table name structure: letters, numbers, underscores, and dots are allowed + val tableNameRegex = Regex(TABLE_NAME_VALID_PATTERN) + if (!tableNameRegex.matches(normalizedTableName)) { + logger.error { + "Validation failed: The table name contains invalid characters. " + + "Only letters, numbers, underscores, and dots are allowed. Provided name: '$tableName'." + } + return false + } + + logger.warn { "Table name validation passed for table: '$tableName'." } + return true } /** @@ -571,8 +711,8 @@ public fun DataFrame.Companion.getSchemaForSqlTable( val sqlQuery = "SELECT * FROM $tableName" val selectFirstRowQuery = determinedDbType.sqlQueryLimit(sqlQuery, limit = 1) - connection.createStatement().use { st -> - st.executeQuery(selectFirstRowQuery).use { rs -> + connection.prepareStatement(selectFirstRowQuery).use { st -> + st.executeQuery().use { rs -> val tableColumns = getTableColumnsMetadata(rs) return buildSchemaByTableColumns(tableColumns, determinedDbType) } @@ -616,8 +756,8 @@ public fun DataFrame.Companion.getSchemaForSqlQuery( ): DataFrameSchema { val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) - connection.createStatement().use { st -> - st.executeQuery(sqlQuery).use { rs -> + connection.prepareStatement(sqlQuery).use { st -> + st.executeQuery().use { rs -> val tableColumns = getTableColumnsMetadata(rs) return buildSchemaByTableColumns(tableColumns, determinedDbType) } diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt index b5d8a931dd..408ca65dc9 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt @@ -599,25 +599,197 @@ class JdbcTest { } @Test - fun `read from table with name from reserved SQL keywords`() { - // Create table Sale + fun `readFromTable should reject invalid table names to prevent SQL injections`() { + // Invalid table names that attempt SQL injection + val invalidTableNames = listOf( + "Customer; DROP TABLE Customer", // Injection using semicolon + "Sale -- Comment", // Injection using single-line comment + "/* Multi-line comment */ Customer", // Injection using multi-line comment + "Sale WHERE 1=1", // Injection using always-true condition + "Sale UNION SELECT * FROM Customer", // UNION injection + ) + + invalidTableNames.forEach { tableName -> + shouldThrow { + DataFrame.readSqlTable(connection, tableName) + } + } + } + + @Test + fun `readSqlQuery should reject malicious SQL queries to prevent SQL injections`() { + // Malicious SQL queries attempting injection @Language("SQL") - val createAlterTableQuery = """ - CREATE TABLE "ALTER" ( - id INT PRIMARY KEY, - description TEXT - ) + val injectionComment = """ + SELECT * FROM Sale WHERE amount = 100.0 -- AND id = 5 + """ + + @Language("SQL") + val injectionMultilineComment = """ + SELECT * FROM Customer /* Possible malicious comment */ WHERE id = 1 + """ + + @Language("SQL") + val injectionSemicolon = """ + SELECT * FROM Sale WHERE amount = 500.0; DROP TABLE Customer + """ + + @Language("SQL") + val injectionSQLWithSingleQuote = """ + SELECT * FROM Sale WHERE id = 1 AND amount = 100.0 OR '1'='1 + """ + + @Language("SQL") + val injectionUsingDropCommand = """ + DROP TABLE Customer; SELECT * FROM Sale + """ + + val sqlInjectionQueries = listOf( + injectionComment, + injectionMultilineComment, + injectionSemicolon, + injectionSQLWithSingleQuote, + injectionUsingDropCommand, + ) + + sqlInjectionQueries.forEach { query -> + shouldThrow { + DataFrame.readSqlQuery(connection, query) + } + } + } + + @Test + fun `readFromTable should work with non-standard table names when strictValidation is disabled`() { + // Non-standard table names that are still valid but may appear strange + val nonStandardTableNames = listOf( + "`Customer With Space`", // Table name with spaces + "`Important-Data`", // Table name with hyphens + "`[123TableName]`", // Table name that resembles a special syntax + ) + + try { + // Create these tables to ensure they exist for the test + connection.createStatement().use { stmt -> + nonStandardTableNames.forEach { tableName -> + stmt.execute("CREATE TABLE IF NOT EXISTS $tableName (id INT, name VARCHAR(255))") + } + } + + // Read from these tables with strictValidation disabled + nonStandardTableNames.forEach { tableName -> + DataFrame.readSqlTable(connection, tableName, strictValidation = false) + } + } finally { + // Clean up by deleting all created tables + connection.createStatement().use { stmt -> + nonStandardTableNames.forEach { tableName -> + stmt.execute("DROP TABLE IF EXISTS $tableName") + } + } + } + } + + @Test + fun `read from Unicode table names`() { + val unicodeTableNames = listOf( + "Таблица", // Russian Cyrillic + "表", // Chinese character + "テーブル", // Japanese Katakana + "عربي", // Arabic + "Δοκιμή", // Greek + ) + + try { + // Create tables with Unicode names + connection.createStatement().use { stmt -> + unicodeTableNames.forEach { tableName -> + stmt.execute("CREATE TABLE IF NOT EXISTS `$tableName` (id INT PRIMARY KEY, name VARCHAR(255))") + stmt.execute("INSERT INTO `$tableName` (id, name) VALUES (1, 'TestName')") + } + } + + // Read from the tables and validate correctness + unicodeTableNames.forEach { tableName -> + val df = DataFrame.readSqlTable(connection, tableName) + df.rowsCount() shouldBe 1 + df[0][1] shouldBe "TestName" + } + } finally { + // Drop the Unicode tables + connection.createStatement().use { stmt -> + unicodeTableNames.forEach { tableName -> + stmt.execute("DROP TABLE IF EXISTS `$tableName`") + } + } + } + } + + @Test + fun `readSqlQuery should execute DROP TABLE when validation is disabled`() { + // Query to create a temporary test table + @Language("SQL") + val createTableQuery = """ + CREATE TABLE IF NOT EXISTS TestTable ( + id INT PRIMARY KEY, + data VARCHAR(255) + ) """ - connection.createStatement().execute(createAlterTableQuery) + // Query to drop the test table + @Language("SQL") + val dropTableQuery = """ + SELECT * FROM TestTable; DROP TABLE TestTable; + + """ + try { + // Create the test table + connection.createStatement().use { stmt -> + stmt.execute(createTableQuery) // Create table for the test case + } + + // Execute the DROP TABLE command with validation disabled + DataFrame.readSqlQuery(connection, dropTableQuery, strictValidation = false) + + // Verify that the table has been successfully dropped + connection.createStatement().use { stmt -> + shouldThrow { + stmt.executeQuery("SELECT * FROM TestTable") + } + } + } finally { + // Cleanup: Ensure the table is removed in case of failure + connection.createStatement().use { stmt -> + stmt.execute("DROP TABLE IF EXISTS TestTable") + } + } + } + + @Test + fun `read from table with name from reserved SQL keywords`() { @Language("SQL") - val selectFromWeirdTableSQL = """ - SELECT * from "ALTER" + val createAlterTableQuery = """ + CREATE TABLE "ALTER" ( + id INT PRIMARY KEY, + description TEXT + ) """ - DataFrame.readSqlQuery(connection, selectFromWeirdTableSQL).rowsCount() shouldBe 0 - connection.createStatement().execute("DROP TABLE \"ALTER\"") + @Language("SQL") + val selectFromWeirdTableSQL = """SELECT * from "ALTER"""" + + try { + connection.createStatement().execute(createAlterTableQuery) + // with enabled strictValidation + shouldThrow { + DataFrame.readSqlQuery(connection, selectFromWeirdTableSQL) + } + // with disabled strictValidation + DataFrame.readSqlQuery(connection, selectFromWeirdTableSQL, strictValidation = false).rowsCount() shouldBe 0 + } finally { + connection.createStatement().execute("DROP TABLE IF EXISTS \"ALTER\"") + } } @Test