-
-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathPostgresDatabase+SQL.swift
109 lines (92 loc) · 3.99 KB
/
PostgresDatabase+SQL.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import Logging
import PostgresNIO
import SQLKit
extension PostgresDatabase {
    /// Wraps this database in an `SQLDatabase` that uses the default JSON encoding and decoding contexts.
    ///
    /// - Parameter queryLogLevel: The level at which executed queries are logged, or `nil` to disable
    ///   query logging. Defaults to `.debug`.
    /// - Returns: An `SQLDatabase` whose queries are executed through this database.
    @inlinable
    public func sql(queryLogLevel: Logger.Level? = .debug) -> some SQLDatabase {
        return self.sql(
            encodingContext: .default,
            decodingContext: .default,
            queryLogLevel: queryLogLevel
        )
    }

    /// Wraps this database in an `SQLDatabase` that uses the given JSON contexts.
    ///
    /// - Parameters:
    ///   - encodingContext: Context used to encode bind values before they are sent.
    ///   - decodingContext: Context used to decode values from result rows.
    ///   - queryLogLevel: The level at which executed queries are logged, or `nil` to disable
    ///     query logging. Defaults to `.debug`.
    /// - Returns: An `SQLDatabase` whose queries are executed through this database.
    public func sql(
        encodingContext: PostgresEncodingContext<some PostgresJSONEncoder>,
        decodingContext: PostgresDecodingContext<some PostgresJSONDecoder>,
        queryLogLevel: Logger.Level? = .debug
    ) -> some SQLDatabase {
        let wrapper = PostgresSQLDatabase(
            database: self,
            encodingContext: encodingContext,
            decodingContext: decodingContext,
            queryLogLevel: queryLogLevel
        )
        return wrapper
    }
}
/// Adapts a `PostgresDatabase` to `SQLDatabase`.
///
/// It carries the JSON contexts used to encode binds and decode rows, plus the level at which executed
/// queries are logged (`nil` disables query logging).
private struct PostgresSQLDatabase<Underlying, JSONEncoding, JSONDecoding>
where Underlying: PostgresDatabase, JSONEncoding: PostgresJSONEncoder, JSONDecoding: PostgresJSONDecoder {
    /// The database that actually runs the queries.
    let database: Underlying

    /// Context used when encoding bind values.
    let encodingContext: PostgresEncodingContext<JSONEncoding>

    /// Context used when decoding result rows.
    let decodingContext: PostgresDecodingContext<JSONDecoding>

    /// Level at which executed queries are logged; `nil` means queries are not logged.
    let queryLogLevel: Logger.Level?
}
extension PostgresSQLDatabase: SQLDatabase, PostgresDatabase {
    /// The logger of the wrapped database; also used for query logging.
    var logger: Logger {
        self.database.logger
    }

    /// The event loop of the wrapped database.
    var eventLoop: any EventLoop {
        self.database.eventLoop
    }

    /// Always `nil`.
    ///
    /// PostgreSQL doesn't send its version in the wire protocol, so it can only be read with SQL.
    var version: (any SQLDatabaseReportedVersion)? {
        nil
    }

    /// A `PostgresDialect`, created on each access.
    var dialect: any SQLDialect {
        PostgresDialect()
    }

    /// Executes `query`, forwarding each result row to `onRow` as an `SQLRow`.
    ///
    /// Bind-encoding failures are delivered through the returned future rather than thrown.
    func execute(sql query: any SQLExpression, _ onRow: @escaping @Sendable (any SQLRow) -> ()) -> EventLoopFuture<Void> {
        self.eventLoop.makeCompletedFuture {
            try self.makeEncodedQuery(for: query)
        }.flatMap { encodedQuery in
            self.runEncodedQuery(encodedQuery, onRow)
        }
    }

    /// Executes `query`, forwarding each result row to `onRow` as an `SQLRow`.
    ///
    /// - Throws: Any bind-encoding error (thrown before a connection is requested), or any error
    ///   reported while running the query.
    func execute(
        sql query: any SQLExpression,
        _ onRow: @escaping @Sendable (any SQLRow) -> ()
    ) async throws {
        let encodedQuery = try self.makeEncodedQuery(for: query)
        try await self.runEncodedQuery(encodedQuery, onRow).get()
    }

    /// Forwards `request` to the wrapped database.
    func send(_ request: any PostgresRequest, logger: Logger) -> EventLoopFuture<Void> {
        self.database.send(request, logger: logger)
    }

    /// Forwards to the wrapped database's `withConnection`.
    func withConnection<T>(_ closure: @escaping (PostgresConnection) -> EventLoopFuture<T>) -> EventLoopFuture<T> {
        self.database.withConnection(closure)
    }

    /// Runs `closure` with an `SQLDatabase` bound to the connection handed out by `withConnection`.
    ///
    /// The session uses the same encoding/decoding contexts and query log level as this wrapper. The
    /// closure is started as a task on that connection's event loop.
    func withSession<R: Sendable>(_ closure: @escaping @Sendable (any SQLDatabase) async throws -> R) async throws -> R {
        try await self.withConnection { connection in
            connection.eventLoop.makeFutureWithTask {
                try await closure(connection.sql(
                    encodingContext: self.encodingContext,
                    decodingContext: self.decodingContext,
                    queryLogLevel: self.queryLogLevel
                ))
            }
        }.get()
    }

    /// Serializes `query`, logs it at `queryLogLevel` (when set) and encodes its binds.
    ///
    /// Shared by both `execute(sql:_:)` overloads so they serialize, log and encode identically.
    /// - Throws: Any error raised by `PostgresDataTranslation.encode(value:in:to:)` for a bind value.
    private func makeEncodedQuery(for query: any SQLExpression) throws -> PostgresQuery {
        let (sql, binds) = self.serialize(query)

        if let queryLogLevel = self.queryLogLevel {
            // NOTE: bind values are rendered verbatim via string interpolation, so any sensitive
            // values end up in the log at this level.
            self.logger.log(level: queryLogLevel, "Executing query", metadata: ["sql": .string(sql), "binds": .array(binds.map { .string("\($0)") })])
        }

        var bindings = PostgresBindings(capacity: binds.count)
        for bind in binds {
            try PostgresDataTranslation.encode(value: bind, in: self.encodingContext, to: &bindings)
        }
        return PostgresQuery(unsafeSQL: sql, binds: bindings)
    }

    /// Runs an already-encoded query on a connection from the wrapped database's `withConnection`.
    ///
    /// Each `PostgresRow` is converted to an `SQLRow` using `decodingContext` before it is passed to
    /// `onRow`. The query metadata is discarded.
    private func runEncodedQuery(
        _ encodedQuery: PostgresQuery,
        _ onRow: @escaping @Sendable (any SQLRow) -> ()
    ) -> EventLoopFuture<Void> {
        self.database.withConnection { connection in
            connection.query(
                encodedQuery,
                logger: connection.logger,
                { row in onRow(row.sql(decodingContext: self.decodingContext)) }
            )
        }.map { _ in }
    }
}