@@ -2,20 +2,20 @@ import Foundation
22import GRDBSQLite
33
44extension Database {
5- /// Adds a user-defined `@DatabaseFunction` to a connection.
5+ /// Adds a user-defined scalar `@DatabaseFunction` to a connection.
66 ///
7- /// - Parameter function: A database function to add.
7+ /// - Parameter function: A scalar database function to add.
88 public func add( function: some ScalarDatabaseFunction ) {
99 sqlite3_create_function_v2 (
1010 sqliteConnection,
1111 function. name,
1212 function. argumentCount,
1313 function. textEncoding,
14- Unmanaged . passRetained ( ScalarDatabaseFunctionBox ( function) ) . toOpaque ( ) ,
14+ Unmanaged . passRetained ( ScalarDatabaseFunctionDefinition ( function) ) . toOpaque ( ) ,
1515 { context, argumentCount, arguments in
1616 do {
1717 var decoder = SQLiteFunctionDecoder ( argumentCount: argumentCount, arguments: arguments)
18- try Unmanaged < ScalarDatabaseFunctionBox >
18+ try Unmanaged < ScalarDatabaseFunctionDefinition >
1919 . fromOpaque ( sqlite3_user_data ( context) )
2020 . takeUnretainedValue ( )
2121 . function
@@ -27,17 +27,56 @@ extension Database {
2727 } ,
2828 nil ,
2929 nil ,
30- { box in
31- guard let box else { return }
32- Unmanaged < ScalarDatabaseFunctionBox > . fromOpaque ( box) . release ( )
30+ { context in
31+ guard let context else { return }
32+ Unmanaged < ScalarDatabaseFunctionDefinition > . fromOpaque ( context) . release ( )
33+ }
34+ )
35+ }
36+
37+ /// Adds a user-defined aggregate `@DatabaseFunction` to a connection.
38+ ///
39+ /// - Parameter function: An aggregate database function to add.
40+ public func add( function: some AggregateDatabaseFunction ) {
41+ let body = Unmanaged . passRetained ( AggregateDatabaseFunctionDefinition ( function) ) . toOpaque ( )
42+ sqlite3_create_function_v2 (
43+ sqliteConnection,
44+ function. name,
45+ function. argumentCount,
46+ function. textEncoding,
47+ body,
48+ nil ,
49+ { context, argumentCount, arguments in
50+ var decoder = SQLiteFunctionDecoder ( argumentCount: argumentCount, arguments: arguments)
51+ let function = AggregateDatabaseFunctionContext [ context] . takeUnretainedValue ( )
52+ do {
53+ try function. iterator. step ( & decoder)
54+ } catch {
55+ sqlite3_result_error ( context, error. localizedDescription, - 1 )
56+ }
57+ } ,
58+ { context in
59+ let unmanagedFunction = AggregateDatabaseFunctionContext [ context]
60+ let function = unmanagedFunction. takeUnretainedValue ( )
61+ unmanagedFunction. release ( )
62+ function. iterator. finish ( )
63+ do {
64+ try function. iterator. result. result ( db: context)
65+ } catch {
66+ sqlite3_result_error ( context, error. localizedDescription, - 1 )
67+ }
68+ } ,
69+ { context in
70+ guard let context else { return }
71+ Unmanaged < AggregateDatabaseFunctionContext > . fromOpaque ( context) . release ( )
3372 }
3473 )
3574 }
3675
3776 /// Deletes a user-defined `@DatabaseFunction` from a connection.
3877 ///
3978 /// - Parameter function: A database function to delete.
40- public func remove( function: some ScalarDatabaseFunction ) {
79+ public func remove( function: some DatabaseFunction ) {
4180 sqlite3_create_function_v2 (
4281 sqliteConnection,
4382 function. name,
@@ -52,7 +91,7 @@ extension Database {
5291 }
5392}
5493
55- extension ScalarDatabaseFunction {
94+ extension DatabaseFunction {
5695 fileprivate var argumentCount : Int32 {
5796 Int32 ( argumentCount ?? - 1 )
5897 }
@@ -62,41 +101,132 @@ extension ScalarDatabaseFunction {
62101 }
63102}
64103
65- private final class ScalarDatabaseFunctionBox {
104+ private final class ScalarDatabaseFunctionDefinition {
66105 let function : any ScalarDatabaseFunction
67106 init ( _ function: some ScalarDatabaseFunction ) {
68107 self . function = function
69108 }
70109}
71110
72- extension [ QueryBinding ] {
73- fileprivate init ( argumentCount: Int32 , arguments: UnsafeMutablePointer < OpaquePointer ? > ? ) {
74- self = ( 0 ..< argumentCount) . map { offset in
75- let value = arguments ? [ Int ( offset) ]
76- switch sqlite3_value_type ( value) {
77- case SQLITE_BLOB:
78- if let blob = sqlite3_value_blob ( value) {
79- let count = Int ( sqlite3_value_bytes ( value) )
80- let buffer = UnsafeRawBufferPointer ( start: blob, count: count)
81- return . blob( [ UInt8] ( buffer) )
82- } else {
83- return . blob( [ ] )
111+ private final class AggregateDatabaseFunctionDefinition {
112+ let function : any AggregateDatabaseFunction
113+ init ( _ function: some AggregateDatabaseFunction ) {
114+ self . function = function
115+ }
116+ }
117+
118+ private final class AggregateDatabaseFunctionContext {
119+ static subscript( context: OpaquePointer ? ) -> Unmanaged < AggregateDatabaseFunctionContext > {
120+ let size = MemoryLayout< Unmanaged< AggregateDatabaseFunctionContext>>. size
121+ let pointer = sqlite3_aggregate_context ( context, Int32 ( size) ) !
122+ if pointer. load ( as: Int . self) == 0 {
123+ let definition = Unmanaged < AggregateDatabaseFunctionDefinition >
124+ . fromOpaque ( sqlite3_user_data ( context) )
125+ . takeUnretainedValue ( )
126+ let context = AggregateDatabaseFunctionContext ( definition. function)
127+ let unmanagedContext = Unmanaged . passRetained ( context)
128+ pointer
129+ . assumingMemoryBound ( to: Unmanaged< AggregateDatabaseFunctionContext> . self )
130+ . pointee = unmanagedContext
131+ return unmanagedContext
132+ } else {
133+ return
134+ pointer
135+ . assumingMemoryBound ( to: Unmanaged< AggregateDatabaseFunctionContext> . self )
136+ . pointee
137+ }
138+ }
139+ let iterator : any AggregateDatabaseFunctionIteratorProtocol
140+ init ( _ body: some AggregateDatabaseFunction ) {
141+ self . iterator = AggregateDatabaseFunctionIterator ( body)
142+ }
143+ }
144+
145+ private protocol AggregateDatabaseFunctionIteratorProtocol < Body> {
146+ associatedtype Body : AggregateDatabaseFunction
147+
148+ var body : Body { get }
149+ var stream : Stream < Body . Element > { get }
150+ func start( )
151+ func step( _ decoder: inout some QueryDecoder ) throws
152+ func finish( )
153+ var result : QueryBinding { get throws }
154+ }
155+
156+ private final class AggregateDatabaseFunctionIterator <
157+ Body: AggregateDatabaseFunction
158+ > : AggregateDatabaseFunctionIteratorProtocol {
159+ let body : Body
160+ let stream = Stream < Body . Element > ( )
161+ let queue : DispatchQueue
162+ var _result : QueryBinding ?
163+ init ( _ body: Body ) {
164+ self . body = body
165+ self . queue = DispatchQueue (
166+ label: " co.pointfree.StructuredQueriesSQLite.AggregateDatabaseFunction. \( body. name) "
167+ )
168+ nonisolated ( unsafe) let iterator : any AggregateDatabaseFunctionIteratorProtocol = self
169+ queue. async {
170+ iterator. start ( )
171+ }
172+ }
173+ func start( ) {
174+ do {
175+ _result = try body. invoke ( stream)
176+ } catch {
177+ _result = . invalid( error)
178+ }
179+ }
180+ func step( _ decoder: inout some QueryDecoder ) throws {
181+ try stream. send ( body. step ( & decoder) )
182+ }
183+ func finish( ) {
184+ stream. finish ( )
185+ }
186+ var result : QueryBinding {
187+ get throws {
188+ while true {
189+ if let result = queue. sync ( execute: { _result } ) {
190+ return result
84191 }
85- case SQLITE_FLOAT:
86- return . double( sqlite3_value_double ( value) )
87- case SQLITE_INTEGER:
88- return . int( sqlite3_value_int64 ( value) )
89- case SQLITE_NULL:
90- return . null
91- case SQLITE_TEXT:
92- return . text( String ( cString: UnsafePointer ( sqlite3_value_text ( value) ) ) )
93- default :
94- return . invalid( UnknownType ( ) )
95192 }
96193 }
97194 }
195+ }
196+
197+ private final class Stream < Element> : Sequence {
198+ let condition = NSCondition ( )
199+ private var buffer : [ Element ] = [ ]
200+ private var isFinished = false
98201
99- private struct UnknownType : Error { }
202+ func send( _ element: Element ) {
203+ condition. withLock {
204+ buffer. append ( element)
205+ condition. signal ( )
206+ }
207+ }
208+
209+ func finish( ) {
210+ condition. withLock {
211+ isFinished = true
212+ condition. broadcast ( )
213+ }
214+ }
215+
216+ func makeIterator( ) -> Iterator { Iterator ( base: self ) }
217+
218+ struct Iterator : IteratorProtocol {
219+ fileprivate let base : Stream
220+ mutating func next( ) -> Element ? {
221+ base. condition. withLock {
222+ while base. buffer. isEmpty && !base. isFinished {
223+ base. condition. wait ( )
224+ }
225+ guard !base. buffer. isEmpty else { return nil }
226+ return base. buffer. removeFirst ( )
227+ }
228+ }
229+ }
100230}
101231
102232extension QueryBinding {
0 commit comments