Skip to content

Commit bb189df

Browse files
committed
initial commit
- propogate dialect config into context - use it to escape table name
1 parent b6f864f commit bb189df

22 files changed

+145
-36
lines changed

scalasql/core/src/Context.scala

+5-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ trait Context {
2424
*/
2525
def config: Config
2626

27+
def dialectConfig: DialectConfig
28+
2729
def withFromNaming(fromNaming: Map[Context.From, String]): Context
2830
def withExprNaming(exprNaming: Map[Expr.Identity, SqlStr]): Context
2931
}
@@ -56,7 +58,8 @@ object Context {
5658
case class Impl(
5759
fromNaming: Map[From, String],
5860
exprNaming: Map[Expr.Identity, SqlStr],
59-
config: Config
61+
config: Config,
62+
dialectConfig: DialectConfig
6063
) extends Context {
6164
def withFromNaming(fromNaming: Map[From, String]): Context = copy(fromNaming = fromNaming)
6265

@@ -93,7 +96,7 @@ object Context {
9396
.map { case (e, s) => (e, sql"${SqlStr.raw(newFromNaming(t), Array(e))}.$s") }
9497
}
9598

96-
Context.Impl(newFromNaming, newExprNaming, prevContext.config)
99+
Context.Impl(newFromNaming, newExprNaming, prevContext.config, prevContext.dialectConfig)
97100
}
98101

99102
}

scalasql/core/src/DbApi.scala

+16-11
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,22 @@ trait DbApi extends AutoCloseable {
123123

124124
object DbApi {
125125

126-
def unpackQueryable[R, Q](query: Q, qr: Queryable[Q, R], config: Config) = {
127-
val ctx = Context.Impl(Map(), Map(), config)
126+
def unpackQueryable[R, Q](
127+
query: Q,
128+
qr: Queryable[Q, R],
129+
config: Config,
130+
dialectConfig: DialectConfig
131+
) = {
132+
val ctx = Context.Impl(Map(), Map(), config, dialectConfig)
128133
val flattened = SqlStr.flatten(qr.renderSql(query, ctx))
129134
flattened
130135
}
131136

132-
def renderSql[Q, R](query: Q, config: Config, castParams: Boolean = false)(
137+
def renderSql[Q, R](query: Q, config: Config, dialectConfig: DialectConfig)(
133138
implicit qr: Queryable[Q, R]
134139
): String = {
135-
val flattened = unpackQueryable(query, qr, config)
136-
flattened.renderSql(castParams)
140+
val flattened = unpackQueryable(query, qr, config, dialectConfig)
141+
flattened.renderSql(dialectConfig.castParams)
137142
}
138143

139144
/**
@@ -196,7 +201,7 @@ object DbApi {
196201
lineNum: sourcecode.Line
197202
): R = {
198203

199-
val flattened = unpackQueryable(query, qr, config)
204+
val flattened = unpackQueryable(query, qr, config, dialect)
200205
if (qr.isGetGeneratedKeys(query).nonEmpty)
201206
updateGetGeneratedKeysSql(flattened)(qr.isGetGeneratedKeys(query).get, fileName, lineNum)
202207
.asInstanceOf[R]
@@ -225,7 +230,7 @@ object DbApi {
225230
fileName: sourcecode.FileName,
226231
lineNum: sourcecode.Line
227232
): Generator[R] = {
228-
val flattened = unpackQueryable(query, qr, config)
233+
val flattened = unpackQueryable(query, qr, config, dialect)
229234
streamFlattened0(
230235
r => {
231236
qr.asInstanceOf[Queryable[Q, R]].construct(query, r) match {
@@ -276,7 +281,7 @@ object DbApi {
276281
): Int = {
277282
val flattened = SqlStr.flatten(sql)
278283
runRawUpdate0(
279-
flattened.renderSql(DialectConfig.castParams(dialect)),
284+
flattened.renderSql(dialect.castParams),
280285
flattenParamPuts(flattened),
281286
fetchSize,
282287
queryTimeoutSeconds,
@@ -296,7 +301,7 @@ object DbApi {
296301
): IndexedSeq[R] = {
297302
val flattened = SqlStr.flatten(sql)
298303
runRawUpdateGetGeneratedKeys0(
299-
flattened.renderSql(DialectConfig.castParams(dialect)),
304+
flattened.renderSql(dialect.castParams),
300305
flattenParamPuts(flattened),
301306
fetchSize,
302307
queryTimeoutSeconds,
@@ -382,7 +387,7 @@ object DbApi {
382387
lineNum: sourcecode.Line
383388
) = streamRaw0(
384389
construct,
385-
flattened.renderSql(DialectConfig.castParams(dialect)),
390+
flattened.renderSql(dialect.castParams),
386391
flattenParamPuts(flattened),
387392
fetchSize,
388393
queryTimeoutSeconds,
@@ -508,7 +513,7 @@ object DbApi {
508513
def renderSql[Q, R](query: Q, castParams: Boolean = false)(
509514
implicit qr: Queryable[Q, R]
510515
): String = {
511-
DbApi.renderSql(query, config, castParams)
516+
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
512517
}
513518

514519
val savepointStack = collection.mutable.ArrayDeque.empty[java.sql.Savepoint]

scalasql/core/src/DbClient.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ object DbClient {
4444
def renderSql[Q, R](query: Q, castParams: Boolean = false)(
4545
implicit qr: Queryable[Q, R]
4646
): String = {
47-
DbApi.renderSql(query, config, castParams)
47+
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
4848
}
4949

5050
def transaction[T](block: DbApi.Txn => T): T = {
@@ -74,7 +74,7 @@ object DbClient {
7474
def renderSql[Q, R](query: Q, castParams: Boolean = false)(
7575
implicit qr: Queryable[Q, R]
7676
): String = {
77-
DbApi.renderSql(query, config, castParams)
77+
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
7878
}
7979

8080
private def withConnection[T](f: DbClient.Connection => T): T = {

scalasql/core/src/DialectConfig.scala

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package scalasql.core
22

3-
trait DialectConfig {
4-
protected def dialectCastParams: Boolean
5-
}
3+
trait DialectConfig { that =>
4+
def castParams: Boolean
5+
def escape(str: String): String
6+
7+
def withCastParams(params: Boolean) = new DialectConfig {
8+
def castParams: Boolean = params
9+
10+
def escape(str: String): String = that.escape(str)
611

7-
object DialectConfig {
8-
def castParams(d: DialectConfig) = d.dialectCastParams
12+
}
913
}

scalasql/query/src/Delete.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ object Delete {
2424
class Renderer(table: TableRef, expr: Expr[Boolean], prevContext: Context) {
2525
implicit val implicitCtx: Context = Context.compute(prevContext, Nil, Some(table))
2626
lazy val tableNameStr =
27-
SqlStr.raw(Table.resolve(table.value))
27+
SqlStr.raw(Table.fullIdentifier(table.value))
2828

2929
def render() = sql"DELETE FROM $tableNameStr WHERE $expr"
3030
}

scalasql/query/src/From.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TableRef(val value: Table.Base) extends From {
1515
def fromExprAliases(prevContext: Context): Seq[(Expr.Identity, SqlStr)] = Nil
1616

1717
def renderSql(name: SqlStr, prevContext: Context, liveExprs: LiveExprs) = {
18-
val resolvedTable = Table.resolve(value)(prevContext)
18+
val resolvedTable = Table.fullIdentifier(value)(prevContext)
1919
SqlStr.raw(resolvedTable + sql" " + name)
2020
}
2121
}

scalasql/query/src/InsertColumns.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ object InsertColumns {
2424
protected def expr: V[Column] = WithSqlExpr.get(insert)
2525

2626
private[scalasql] override def renderSql(ctx: Context) =
27-
new Renderer(columns, ctx, valuesLists, Table.resolve(table.value)(ctx)).render()
27+
new Renderer(columns, ctx, valuesLists, Table.fullIdentifier(table.value)(ctx)).render()
2828

2929
override protected def queryConstruct(args: Queryable.ResultSetIterator): Int =
3030
args.get(IntType)

scalasql/query/src/InsertSelect.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ object InsertSelect {
2020
def table = insert.table
2121

2222
private[scalasql] override def renderSql(ctx: Context) =
23-
new Renderer(select, select.qr.walkExprs(columns), ctx, Table.resolve(table.value)(ctx))
23+
new Renderer(
24+
select,
25+
select.qr.walkExprs(columns),
26+
ctx,
27+
Table.fullIdentifier(table.value)(ctx)
28+
)
2429
.render()
2530

2631
override protected def queryConstruct(args: Queryable.ResultSetIterator): Int =

scalasql/query/src/InsertValues.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ object InsertValues {
2424

2525
override private[scalasql] def renderSql(ctx: Context): SqlStr = {
2626
new Renderer(
27-
Table.resolve(insert.table.value)(ctx),
27+
Table.fullIdentifier(insert.table.value)(ctx),
2828
Table.labels(insert.table.value),
2929
values,
3030
qr,

scalasql/query/src/Table.scala

+17-4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ abstract class Table[V[_[_]]]()(implicit name: sourcecode.Name, metadata0: Table
1414

1515
protected[scalasql] def schemaName = ""
1616

17+
protected[scalasql] def escape: Boolean = false
18+
1719
protected implicit def tableSelf: Table[V] = this
1820

1921
protected def tableMetadata: Table.Metadata[V] = metadata0
@@ -50,11 +52,21 @@ object Table {
5052
def name(t: Table.Base) = t.tableName
5153
def labels(t: Table.Base) = t.tableLabels
5254
def columnNameOverride[V[_[_]]](t: Table.Base)(s: String) = t.tableColumnNameOverride(s)
53-
def resolve(t: Table.Base)(implicit context: Context) = {
54-
val mappedTableName = context.config.tableNameMapper(t.tableName)
55+
def identifier(t: Table.Base)(implicit context: Context): String = {
56+
context.config.tableNameMapper.andThen { str =>
57+
if (t.escape) {
58+
context.dialectConfig.escape(str)
59+
} else {
60+
str
61+
}
62+
}(t.tableName)
63+
}
64+
def fullIdentifier(
65+
t: Table.Base
66+
)(implicit context: Context): String = {
5567
t.schemaName match {
56-
case "" => mappedTableName
57-
case str => s"$str." + mappedTableName
68+
case "" => identifier(t)
69+
case str => s"$str." + identifier(t)
5870
}
5971
}
6072
trait Base {
@@ -66,6 +78,7 @@ object Table {
6678
protected[scalasql] def tableName: String
6779
protected[scalasql] def schemaName: String
6880
protected[scalasql] def tableLabels: Seq[String]
81+
protected[scalasql] def escape: Boolean
6982

7083
/**
7184
* Customizations to the column names of this table before processing,

scalasql/query/src/Update.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ object Update {
9494
implicit lazy val implicitCtx: Context = Context.compute(prevContext, froms, Some(table))
9595

9696
lazy val tableName =
97-
SqlStr.raw(Table.resolve(table.value))
97+
SqlStr.raw(Table.fullIdentifier(table.value))
9898

9999
lazy val updateList = set0.map { case assign =>
100100
val kStr = SqlStr.raw(prevContext.config.columnNameMapper(assign.column.name))

scalasql/src/dialects/H2Dialect.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ import java.sql.PreparedStatement
2727

2828
trait H2Dialect extends Dialect {
2929

30-
protected def dialectCastParams = true
30+
def castParams = true
31+
32+
def escape(str: String) = s"\"${str.toUpperCase()}\""
3133

3234
override implicit def EnumType[T <: Enumeration#Value](
3335
implicit constructor: String => T

scalasql/src/dialects/MySqlDialect.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ import scala.reflect.ClassTag
4343
import scalasql.query.Select
4444

4545
trait MySqlDialect extends Dialect {
46-
protected def dialectCastParams = false
46+
def castParams = false
47+
48+
def escape(str: String) = s"`$str`"
4749

4850
override implicit def ByteType: TypeMapper[Byte] = new MySqlByteType
4951
class MySqlByteType extends ByteType { override def castTypeString = "SIGNED" }

scalasql/src/dialects/PostgresDialect.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ import scalasql.operations.{ConcatOps, HyperbolicMathOps, MathOps, PadOps, TrimO
1717

1818
trait PostgresDialect extends Dialect with ReturningDialect with OnConflictOps {
1919

20-
protected def dialectCastParams = false
20+
def castParams = false
21+
22+
def escape(str: String) = s"\"$str\""
2123

2224
override implicit def ByteType: TypeMapper[Byte] = new PostgresByteType
2325
class PostgresByteType extends ByteType { override def castTypeString = "INTEGER" }

scalasql/src/dialects/SqliteDialect.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ import scalasql.operations.TrimOps
1818
import java.time.{Instant, LocalDate, LocalDateTime}
1919

2020
trait SqliteDialect extends Dialect with ReturningDialect with OnConflictOps {
21-
protected def dialectCastParams = false
21+
def castParams = false
22+
23+
def escape(str: String) = s"\"$str\""
2224

2325
override implicit def LocalDateTimeType: TypeMapper[LocalDateTime] = new SqliteLocalDateTimeType
2426
class SqliteLocalDateTimeType extends LocalDateTimeType {

scalasql/test/resources/h2-customer-schema.sql

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ DROP TABLE IF EXISTS nested CASCADE;
1111
DROP TABLE IF EXISTS enclosing CASCADE;
1212
DROP TABLE IF EXISTS invoice CASCADE;
1313
DROP SCHEMA IF EXISTS otherschema CASCADE;
14+
DROP TABLE IF EXISTS "SELECT" CASCADE;
1415

1516
CREATE TABLE buyer (
1617
id INTEGER AUTO_INCREMENT PRIMARY KEY,
@@ -98,4 +99,9 @@ CREATE TABLE otherschema.invoice(
9899
id INTEGER AUTO_INCREMENT PRIMARY KEY,
99100
total DECIMAL(20, 2),
100101
vendor_name VARCHAR(256)
101-
);
102+
);
103+
104+
CREATE TABLE "SELECT"(
105+
id INTEGER,
106+
name VARCHAR(256)
107+
)

scalasql/test/resources/mysql-customer-schema.sql

+7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ DROP TABLE IF EXISTS `non_round_trip_types` CASCADE;
1010
DROP TABLE IF EXISTS `opt_cols` CASCADE;
1111
DROP TABLE IF EXISTS `nested` CASCADE;
1212
DROP TABLE IF EXISTS `enclosing` CASCADE;
13+
DROP TABLE IF EXISTS `select` CASCADE;
14+
1315
SET FOREIGN_KEY_CHECKS = 1;
1416

1517
CREATE TABLE buyer (
@@ -90,3 +92,8 @@ CREATE TABLE enclosing(
9092
foo_id INTEGER,
9193
my_boolean BOOLEAN
9294
);
95+
96+
CREATE TABLE `select`(
97+
id INTEGER,
98+
name VARCHAR(256)
99+
);

scalasql/test/resources/postgres-customer-schema.sql

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ DROP TABLE IF EXISTS enclosing CASCADE;
1212
DROP TABLE IF EXISTS invoice CASCADE;
1313
DROP TYPE IF EXISTS my_enum CASCADE;
1414
DROP SCHEMA IF EXISTS otherschema CASCADE;
15+
DROP TABLE IF EXISTS "select" CASCADE;
1516

1617
CREATE TABLE buyer (
1718
id SERIAL PRIMARY KEY,
@@ -103,3 +104,8 @@ CREATE TABLE otherschema.invoice(
103104
total DECIMAL(20, 2),
104105
vendor_name VARCHAR(256)
105106
);
107+
108+
CREATE TABLE "select"(
109+
id INTEGER,
110+
name VARCHAR(256)
111+
);

scalasql/test/resources/sqlite-customer-schema.sql

+6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ DROP TABLE IF EXISTS non_round_trip_types;
99
DROP TABLE IF EXISTS nested;
1010
DROP TABLE IF EXISTS enclosing;
1111
DROP TABLE IF EXISTS opt_cols;
12+
DROP TABLE IF EXISTS "select";
1213

1314
CREATE TABLE buyer (
1415
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -91,3 +92,8 @@ CREATE TABLE enclosing(
9192
foo_id INTEGER,
9293
my_boolean BOOLEAN
9394
);
95+
96+
CREATE TABLE "select"(
97+
id INTEGER,
98+
name VARCHAR(256)
99+
)

0 commit comments

Comments
 (0)