Skip to content

Commit f6c7211

Browse files
authored
Aggregate functions support (#262)
While this landed in StructuredQueries proper, we need to expose the functionality to GRDB.
1 parent b682548 commit f6c7211

File tree

5 files changed

+245
-45
lines changed

5 files changed

+245
-45
lines changed

Package.resolved

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ let package = Package(
3636
.package(url: "https://github.com/pointfreeco/swift-snapshot-testing", from: "1.18.4"),
3737
.package(
3838
url: "https://github.com/pointfreeco/swift-structured-queries",
39-
from: "0.22.3",
39+
from: "0.24.0",
4040
traits: [
4141
.trait(name: "StructuredQueriesTagged", condition: .when(traits: ["SQLiteDataTagged"]))
4242
]

[email protected]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ let package = Package(
2828
.package(url: "https://github.com/pointfreeco/swift-dependencies", from: "1.9.0"),
2929
.package(url: "https://github.com/pointfreeco/swift-sharing", from: "2.3.0"),
3030
.package(url: "https://github.com/pointfreeco/swift-snapshot-testing", from: "1.18.4"),
31-
.package(url: "https://github.com/pointfreeco/swift-structured-queries", from: "0.22.3"),
31+
.package(url: "https://github.com/pointfreeco/swift-structured-queries", from: "0.24.0"),
3232
.package(url: "https://github.com/pointfreeco/xctest-dynamic-overlay", from: "1.5.0"),
3333
],
3434
targets: [

Sources/SQLiteData/StructuredQueries+GRDB/CustomFunctions.swift

Lines changed: 163 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,20 @@ import Foundation
22
import GRDBSQLite
33

44
extension Database {
5-
/// Adds a user-defined `@DatabaseFunction` to a connection.
5+
/// Adds a user-defined scalar `@DatabaseFunction` to a connection.
66
///
7-
/// - Parameter function: A database function to add.
7+
/// - Parameter function: A scalar database function to add.
88
public func add(function: some ScalarDatabaseFunction) {
99
sqlite3_create_function_v2(
1010
sqliteConnection,
1111
function.name,
1212
function.argumentCount,
1313
function.textEncoding,
14-
Unmanaged.passRetained(ScalarDatabaseFunctionBox(function)).toOpaque(),
14+
Unmanaged.passRetained(ScalarDatabaseFunctionDefinition(function)).toOpaque(),
1515
{ context, argumentCount, arguments in
1616
do {
1717
var decoder = SQLiteFunctionDecoder(argumentCount: argumentCount, arguments: arguments)
18-
try Unmanaged<ScalarDatabaseFunctionBox>
18+
try Unmanaged<ScalarDatabaseFunctionDefinition>
1919
.fromOpaque(sqlite3_user_data(context))
2020
.takeUnretainedValue()
2121
.function
@@ -27,17 +27,56 @@ extension Database {
2727
},
2828
nil,
2929
nil,
30-
{ box in
31-
guard let box else { return }
32-
Unmanaged<ScalarDatabaseFunctionBox>.fromOpaque(box).release()
30+
{ context in
31+
guard let context else { return }
32+
Unmanaged<ScalarDatabaseFunctionDefinition>.fromOpaque(context).release()
33+
}
34+
)
35+
}
36+
37+
/// Adds a user-defined aggregate `@DatabaseFunction` to a connection.
38+
///
39+
/// - Parameter function: An aggregate database function to add.
40+
public func add(function: some AggregateDatabaseFunction) {
41+
let body = Unmanaged.passRetained(AggregateDatabaseFunctionDefinition(function)).toOpaque()
42+
sqlite3_create_function_v2(
43+
sqliteConnection,
44+
function.name,
45+
function.argumentCount,
46+
function.textEncoding,
47+
body,
48+
nil,
49+
{ context, argumentCount, arguments in
50+
var decoder = SQLiteFunctionDecoder(argumentCount: argumentCount, arguments: arguments)
51+
let function = AggregateDatabaseFunctionContext[context].takeUnretainedValue()
52+
do {
53+
try function.iterator.step(&decoder)
54+
} catch {
55+
sqlite3_result_error(context, error.localizedDescription, -1)
56+
}
57+
},
58+
{ context in
59+
let unmanagedFunction = AggregateDatabaseFunctionContext[context]
60+
let function = unmanagedFunction.takeUnretainedValue()
61+
unmanagedFunction.release()
62+
function.iterator.finish()
63+
do {
64+
try function.iterator.result.result(db: context)
65+
} catch {
66+
sqlite3_result_error(context, error.localizedDescription, -1)
67+
}
68+
},
69+
{ context in
70+
guard let context else { return }
71+
Unmanaged<AggregateDatabaseFunctionContext>.fromOpaque(context).release()
3372
}
3473
)
3574
}
3675

3776
/// Deletes a user-defined `@DatabaseFunction` from a connection.
3877
///
3978
/// - Parameter function: A database function to delete.
40-
public func remove(function: some ScalarDatabaseFunction) {
79+
public func remove(function: some DatabaseFunction) {
4180
sqlite3_create_function_v2(
4281
sqliteConnection,
4382
function.name,
@@ -52,7 +91,7 @@ extension Database {
5291
}
5392
}
5493

55-
extension ScalarDatabaseFunction {
94+
extension DatabaseFunction {
5695
fileprivate var argumentCount: Int32 {
5796
Int32(argumentCount ?? -1)
5897
}
@@ -62,41 +101,132 @@ extension ScalarDatabaseFunction {
62101
}
63102
}
64103

65-
private final class ScalarDatabaseFunctionBox {
104+
private final class ScalarDatabaseFunctionDefinition {
66105
let function: any ScalarDatabaseFunction
67106
init(_ function: some ScalarDatabaseFunction) {
68107
self.function = function
69108
}
70109
}
71110

72-
extension [QueryBinding] {
73-
fileprivate init(argumentCount: Int32, arguments: UnsafeMutablePointer<OpaquePointer?>?) {
74-
self = (0..<argumentCount).map { offset in
75-
let value = arguments?[Int(offset)]
76-
switch sqlite3_value_type(value) {
77-
case SQLITE_BLOB:
78-
if let blob = sqlite3_value_blob(value) {
79-
let count = Int(sqlite3_value_bytes(value))
80-
let buffer = UnsafeRawBufferPointer(start: blob, count: count)
81-
return .blob([UInt8](buffer))
82-
} else {
83-
return .blob([])
111+
private final class AggregateDatabaseFunctionDefinition {
112+
let function: any AggregateDatabaseFunction
113+
init(_ function: some AggregateDatabaseFunction) {
114+
self.function = function
115+
}
116+
}
117+
118+
private final class AggregateDatabaseFunctionContext {
119+
static subscript(context: OpaquePointer?) -> Unmanaged<AggregateDatabaseFunctionContext> {
120+
let size = MemoryLayout<Unmanaged<AggregateDatabaseFunctionContext>>.size
121+
let pointer = sqlite3_aggregate_context(context, Int32(size))!
122+
if pointer.load(as: Int.self) == 0 {
123+
let definition = Unmanaged<AggregateDatabaseFunctionDefinition>
124+
.fromOpaque(sqlite3_user_data(context))
125+
.takeUnretainedValue()
126+
let context = AggregateDatabaseFunctionContext(definition.function)
127+
let unmanagedContext = Unmanaged.passRetained(context)
128+
pointer
129+
.assumingMemoryBound(to: Unmanaged<AggregateDatabaseFunctionContext>.self)
130+
.pointee = unmanagedContext
131+
return unmanagedContext
132+
} else {
133+
return
134+
pointer
135+
.assumingMemoryBound(to: Unmanaged<AggregateDatabaseFunctionContext>.self)
136+
.pointee
137+
}
138+
}
139+
let iterator: any AggregateDatabaseFunctionIteratorProtocol
140+
init(_ body: some AggregateDatabaseFunction) {
141+
self.iterator = AggregateDatabaseFunctionIterator(body)
142+
}
143+
}
144+
145+
private protocol AggregateDatabaseFunctionIteratorProtocol<Body> {
146+
associatedtype Body: AggregateDatabaseFunction
147+
148+
var body: Body { get }
149+
var stream: Stream<Body.Element> { get }
150+
func start()
151+
func step(_ decoder: inout some QueryDecoder) throws
152+
func finish()
153+
var result: QueryBinding { get throws }
154+
}
155+
156+
private final class AggregateDatabaseFunctionIterator<
157+
Body: AggregateDatabaseFunction
158+
>: AggregateDatabaseFunctionIteratorProtocol {
159+
let body: Body
160+
let stream = Stream<Body.Element>()
161+
let queue: DispatchQueue
162+
var _result: QueryBinding?
163+
init(_ body: Body) {
164+
self.body = body
165+
self.queue = DispatchQueue(
166+
label: "co.pointfree.StructuredQueriesSQLite.AggregateDatabaseFunction.\(body.name)"
167+
)
168+
nonisolated(unsafe) let iterator: any AggregateDatabaseFunctionIteratorProtocol = self
169+
queue.async {
170+
iterator.start()
171+
}
172+
}
173+
func start() {
174+
do {
175+
_result = try body.invoke(stream)
176+
} catch {
177+
_result = .invalid(error)
178+
}
179+
}
180+
func step(_ decoder: inout some QueryDecoder) throws {
181+
try stream.send(body.step(&decoder))
182+
}
183+
func finish() {
184+
stream.finish()
185+
}
186+
var result: QueryBinding {
187+
get throws {
188+
while true {
189+
if let result = queue.sync(execute: { _result }) {
190+
return result
84191
}
85-
case SQLITE_FLOAT:
86-
return .double(sqlite3_value_double(value))
87-
case SQLITE_INTEGER:
88-
return .int(sqlite3_value_int64(value))
89-
case SQLITE_NULL:
90-
return .null
91-
case SQLITE_TEXT:
92-
return .text(String(cString: UnsafePointer(sqlite3_value_text(value))))
93-
default:
94-
return .invalid(UnknownType())
95192
}
96193
}
97194
}
195+
}
196+
197+
private final class Stream<Element>: Sequence {
198+
let condition = NSCondition()
199+
private var buffer: [Element] = []
200+
private var isFinished = false
98201

99-
private struct UnknownType: Error {}
202+
func send(_ element: Element) {
203+
condition.withLock {
204+
buffer.append(element)
205+
condition.signal()
206+
}
207+
}
208+
209+
func finish() {
210+
condition.withLock {
211+
isFinished = true
212+
condition.broadcast()
213+
}
214+
}
215+
216+
func makeIterator() -> Iterator { Iterator(base: self) }
217+
218+
struct Iterator: IteratorProtocol {
219+
fileprivate let base: Stream
220+
mutating func next() -> Element? {
221+
base.condition.withLock {
222+
while base.buffer.isEmpty && !base.isFinished {
223+
base.condition.wait()
224+
}
225+
guard !base.buffer.isEmpty else { return nil }
226+
return base.buffer.removeFirst()
227+
}
228+
}
229+
}
100230
}
101231

102232
extension QueryBinding {
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import DependenciesTestSupport
2+
import Foundation
3+
import SQLiteData
4+
import SQLiteDataTestSupport
5+
import Testing
6+
7+
struct DatabaseFunctionTests {
8+
@DatabaseFunction
9+
func exclaim(_ text: String) -> String {
10+
text + "!"
11+
}
12+
@Test func scalarFunction() async throws {
13+
var configuration = Configuration()
14+
configuration.prepareDatabase { db in
15+
db.add(function: $exclaim)
16+
}
17+
let database = try DatabaseQueue(configuration: configuration)
18+
assertQuery(Values($exclaim("Blob")), database: database) {
19+
"""
20+
┌─────────┐
21+
"Blob!"
22+
└─────────┘
23+
"""
24+
}
25+
}
26+
27+
@Test(.dependency(\.defaultDatabase, try .database())) func aggregateFunction() async throws {
28+
assertQuery(Record.select { $sum($0.id) }) {
29+
"""
30+
┌───┐
31+
│ 6 │
32+
└───┘
33+
"""
34+
}
35+
}
36+
}
37+
38+
@Table
39+
private struct Record: Equatable {
40+
let id: Int
41+
}
42+
43+
@DatabaseFunction
44+
func sum(_ xs: some Sequence<Int>) -> Int {
45+
xs.reduce(0, +)
46+
}
47+
48+
extension DatabaseWriter where Self == DatabaseQueue {
49+
fileprivate static func database() throws -> DatabaseQueue {
50+
var configuration = Configuration()
51+
configuration.prepareDatabase { db in
52+
db.add(function: $sum)
53+
}
54+
let database = try DatabaseQueue(configuration: configuration)
55+
try database.write { db in
56+
try #sql(
57+
"""
58+
CREATE TABLE "records" (
59+
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT
60+
)
61+
"""
62+
)
63+
.execute(db)
64+
for _ in 1...3 {
65+
_ = try Record.insert { Record.Draft() }.execute(db)
66+
}
67+
}
68+
return database
69+
}
70+
}

0 commit comments

Comments
 (0)