Skip to content

Commit c1539d3

Browse files
authored
Merge pull request #4 from wickwirew/manually-run-migrations
Add ability to run migrations manually
2 parents c0e918e + cb48c3d commit c1539d3

File tree

7 files changed

+126
-20
lines changed

7 files changed

+126
-20
lines changed

Sources/PureSQL/Connection.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,22 @@ public protocol Connection: Sendable {
1616
/// Cancels the observation for the given subscriber
1717
func cancel(subscriber: DatabaseSubscriber)
1818

19+
/// Begins a transaction and passes it to the `execute` function.
20+
/// If no error is thrown the changes are automatically commited.
21+
/// If an error is thrown the changes are rolled back.
1922
func begin<Output>(
2023
_ kind: Transaction.Kind,
2124
execute: @Sendable (borrowing Transaction) throws -> Output
2225
) async throws -> Output
26+
27+
/// Gets a raw connection to the database and allows for direct
28+
/// SQL access. No transaction is automatically started.
29+
///
30+
/// This is likely not the API you want, and should just use `begin`.
31+
func withConnection<Output>(
32+
isWrite: Bool,
33+
execute: @Sendable (borrowing RawConnection) throws -> Output
34+
) async throws -> Output
2335
}
2436

2537
/// A no operation database connection that does nothing.
@@ -36,6 +48,13 @@ public struct NoopConnection: Connection {
3648
) async throws -> Output {
3749
try execute(Transaction(connection: NoopRawConnection(), kind: kind))
3850
}
51+
52+
public func withConnection<Output>(
53+
isWrite: Bool,
54+
execute: @Sendable (borrowing RawConnection) throws -> Output
55+
) async throws -> Output {
56+
try execute(NoopRawConnection())
57+
}
3958
}
4059

4160
/// A type that has a database connection.
@@ -62,4 +81,11 @@ public extension ConnectionWrapper {
6281
) async throws -> Output {
6382
try await connection.begin(kind, execute: execute)
6483
}
84+
85+
func withConnection<Output>(
86+
isWrite: Bool,
87+
execute: @Sendable (borrowing RawConnection) throws -> Output
88+
) async throws -> Output {
89+
try await connection.withConnection(isWrite: isWrite, execute: execute)
90+
}
6591
}

Sources/PureSQL/ConnectionPool.swift

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ public actor ConnectionPool: Sendable {
3636
public init(
3737
path: String,
3838
limit: Int,
39-
migrations: [String]
39+
migrations: [String],
40+
runMigrations: Bool = true
4041
) throws {
4142
guard limit > 0 else {
4243
throw SQLError.poolCannotHaveZeroConnections
@@ -50,7 +51,11 @@ public actor ConnectionPool: Sendable {
5051

5152
// Turn on WAL mode
5253
try connection.execute(sql: "PRAGMA journal_mode=WAL;")
53-
try MigrationRunner.execute(migrations: migrations, connection: connection)
54+
55+
if runMigrations {
56+
try MigrationRunner.execute(migrations: migrations, connection: connection)
57+
}
58+
5459
self.availableConnections = [connection]
5560
}
5661

@@ -63,26 +68,32 @@ public actor ConnectionPool: Sendable {
6368
private func begin(
6469
_ kind: Transaction.Kind
6570
) async throws(SQLError) -> sending Transaction {
66-
// Writes must be exclusive, make sure to wait on any pending writes.
67-
if kind == .write {
68-
await writeLock.lock()
69-
}
70-
71-
return try await Transaction(connection: getConnection(), kind: kind)
71+
return try await Transaction(
72+
connection: getConnection(isWrite: kind == .write),
73+
kind: kind
74+
)
7275
}
7376

7477
/// Gives the connection back to the pool.
75-
private func reclaim(connection: RawConnection, kind: Transaction.Kind) async {
78+
private func reclaim(
79+
connection: RawConnection,
80+
isWrite: Bool
81+
) async {
7682
availableConnections.append(connection)
7783
alertAnyWaitersOfAvailableConnection()
7884

79-
if kind == .write {
85+
if isWrite {
8086
await writeLock.unlock()
8187
}
8288
}
8389

8490
/// Will get, wait or create a connection to the database
85-
private func getConnection() async throws(SQLError) -> RawConnection {
91+
private func getConnection(isWrite: Bool) async throws(SQLError) -> RawConnection {
92+
// Writes must be exclusive, make sure to wait on any pending writes.
93+
if isWrite {
94+
await writeLock.lock()
95+
}
96+
8697
guard availableConnections.isEmpty else {
8798
// Have an available connection, just use it
8899
return availableConnections.removeLast()
@@ -169,11 +180,22 @@ extension ConnectionPool: Connection {
169180

170181
do {
171182
let output = try await execute(tx)
172-
await reclaim(connection: conn, kind: kind)
183+
await reclaim(connection: conn, isWrite: kind == .write)
173184
return output
174185
} catch {
175-
await reclaim(connection: conn, kind: kind)
186+
await reclaim(connection: conn, isWrite: kind == .write)
176187
throw error
177188
}
178189
}
190+
191+
/// Gets a connection to the database. No tx is started.
192+
public func withConnection<Output: Sendable>(
193+
isWrite: Bool,
194+
execute: @Sendable (borrowing RawConnection) throws -> Output
195+
) async throws -> Output {
196+
let conn = try await getConnection(isWrite: isWrite)
197+
let output = try execute(conn)
198+
await reclaim(connection: conn, isWrite: isWrite)
199+
return output
200+
}
179201
}

Sources/PureSQL/Database.swift

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,15 @@ public extension Database {
4848
try ConnectionPool(
4949
path: path,
5050
limit: config.maxConnectionCount,
51-
migrations: Self.sanitizedMigrations
51+
migrations: Self.sanitizedMigrations,
52+
runMigrations: config.autoMigrate
5253
)
5354
} else {
5455
try ConnectionPool(
5556
path: ":memory:",
5657
limit: 1,
57-
migrations: Self.sanitizedMigrations
58+
migrations: Self.sanitizedMigrations,
59+
runMigrations: config.autoMigrate
5860
)
5961
}
6062

@@ -65,6 +67,20 @@ public extension Database {
6567
static func inMemory(adapters: Adapters) throws -> Self {
6668
return try Self(config: DatabaseConfig(path: nil), adapters: adapters)
6769
}
70+
71+
/// Runs the migrations up to and including the `maxMigration`.
72+
///
73+
/// The `maxMigration` number is not equal to the filename, but
74+
/// rather the zero based index.
75+
func migrate(upTo maxMigration: Int? = nil) async throws {
76+
try await connection.withConnection(isWrite: true) { conn in
77+
try MigrationRunner.execute(
78+
migrations: Self.sanitizedMigrations,
79+
connection: conn,
80+
upTo: maxMigration
81+
)
82+
}
83+
}
6884
}
6985

7086

Sources/PureSQL/DatabaseConfig.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@ public struct DatabaseConfig {
1515
/// In memory databases will be overriden to `1` regardless
1616
/// of the input
1717
public var maxConnectionCount: Int
18+
/// If `true` the migrations will run when the connection is opened.
19+
/// Default is `true`.
20+
public var autoMigrate: Bool
1821

1922
public init(
2023
path: String?,
21-
maxConnectionCount: Int = 5
24+
maxConnectionCount: Int = 5,
25+
autoMigrate: Bool = true
2226
) {
2327
self.path = path
2428
self.maxConnectionCount = maxConnectionCount
29+
self.autoMigrate = autoMigrate
2530
}
2631
}

Sources/PureSQL/Migration.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,21 @@
99
enum MigrationRunner {
1010
static let migrationTableName = "__puresqlMigration"
1111

12-
static func execute(migrations: [String], connection: SQLiteConnection) throws {
13-
let previouslyRunMigrations = try runMigrations(connection: connection)
12+
static func execute(
13+
migrations: [String],
14+
connection: RawConnection,
15+
upTo maxMigration: Int? = nil
16+
) throws {
17+
let previouslyRunMigrations = try getRanMigrations(connection: connection)
1418
let lastMigration = previouslyRunMigrations.last ?? Int.min
1519

1620
let pendingMigrations = migrations.enumerated()
1721
.map { (number: $0.offset, migration: $0.element) }
1822
.filter { $0.number > lastMigration }
1923

2024
for (number, migration) in pendingMigrations {
25+
if let maxMigration, number > maxMigration { return }
26+
2127
// Run each migration in it's own transaction.
2228
let tx = try Transaction(connection: connection, kind: .write)
2329

@@ -50,7 +56,7 @@ enum MigrationRunner {
5056
}
5157

5258
/// Creates the migrations table and gets the last migration that ran.
53-
private static func runMigrations(connection: SQLiteConnection) throws -> [Int] {
59+
private static func getRanMigrations(connection: RawConnection) throws -> [Int] {
5460
let tx = try Transaction(connection: connection, kind: .write)
5561

5662
// Create the migration table if need be.

Sources/PureSQL/SQLiteConnection.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ import Collections
99
import Foundation
1010
import SQLite3
1111

12-
protocol RawConnection: Sendable {
12+
/// Represents a raw connection to the SQLite database
13+
public protocol RawConnection: Sendable {
14+
/// Initializes a SQLite prepared statement
1315
func prepare(sql: String) throws(SQLError) -> OpaquePointer
16+
/// Executes the SQL statement.
17+
/// Equivalent to `sqlite3_exec`
1418
func execute(sql: String) throws(SQLError)
1519
}
1620

Tests/PureSQLTests/MigrationRunnerTests.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,33 @@ struct MigrationRunnerTests: ~Copyable {
108108
#expect(tables.contains("foo"))
109109
}
110110

111+
@Test func canRunMigrationsUpToCertainNumber() async throws {
112+
let migrations = [
113+
"CREATE TABLE foo (bar INTEGER);",
114+
"CREATE TABLE bar (baz TEXT);",
115+
]
116+
117+
try MigrationRunner.execute(
118+
migrations: migrations,
119+
connection: connection,
120+
upTo: 0 // Dont run the last migration
121+
)
122+
123+
let onlyFirstMigration = try tableNames()
124+
#expect(onlyFirstMigration.contains("foo"))
125+
#expect(!onlyFirstMigration.contains("bar"))
126+
127+
try MigrationRunner.execute(
128+
migrations: migrations,
129+
connection: connection,
130+
upTo: nil // Now run all migrations
131+
)
132+
133+
let allMigrations = try tableNames()
134+
#expect(allMigrations.contains("foo"))
135+
#expect(allMigrations.contains("bar"))
136+
}
137+
111138
private func runMigrations() throws -> [Int] {
112139
return try query("SELECT * FROM \(MigrationRunner.migrationTableName) ORDER BY number ASC") { try $0.fetchAll() }
113140
}

0 commit comments

Comments
 (0)