Skip to content

Commit c3a83a3

Browse files
committed
Add GraphEnvironmentTrait
1 parent 47e75dc commit c3a83a3

File tree

5 files changed

+310
-29
lines changed

5 files changed

+310
-29
lines changed

Tests/OpenGraphCompatibilityTests/Attribute/Attribute/AttributeCompatibilityTests.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import Testing
66

77
#if canImport(Darwin)
8-
@Suite(.disabled(if: !compatibilityTestEnabled, "Attribute is not implemented"))
9-
final class AttributeCompatibilityTests: AttributeTestBase {
8+
@MainActor
9+
@Suite(.disabled(if: !compatibilityTestEnabled, "Attribute is not implemented"), .graphScope)
10+
struct AttributeCompatibilityTests {
1011
@Test
1112
func initWithValue() {
1213
let intAttribute = Attribute(value: 0)

Tests/OpenGraphCompatibilityTests/Attribute/AttributeTestBase.swift

Lines changed: 0 additions & 24 deletions
This file was deleted.
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
// Copyright (C) 2022 Gwendal Roué
2+
//
3+
// Permission is hereby granted, free of charge, to any person obtaining a
4+
// copy of this software and associated documentation files (the
5+
// "Software"), to deal in the Software without restriction, including
6+
// without limitation the rights to use, copy, modify, merge, publish,
7+
// distribute, sublicense, and/or sell copies of the Software, and to
8+
// permit persons to whom the Software is furnished to do so, subject to
9+
// the following conditions:
10+
//
11+
// The above copyright notice and this permission notice shall be included
12+
// in all copies or substantial portions of the Software.
13+
//
14+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15+
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16+
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
17+
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
18+
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
19+
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
20+
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21+
22+
import Foundation
23+
24+
/// An object that controls access to a resource across multiple execution
25+
/// contexts through use of a traditional counting semaphore.
26+
///
27+
/// You increment a semaphore count by calling the ``signal()`` method, and
28+
/// decrement a semaphore count by calling ``wait()`` or one of its variants.
29+
///
30+
/// ## Topics
31+
///
32+
/// ### Creating a Semaphore
33+
///
34+
/// - ``init(value:)``
35+
///
36+
/// ### Signaling the Semaphore
37+
///
38+
/// - ``signal()``
39+
///
40+
/// ### Waiting for the Semaphore
41+
///
42+
/// - ``wait()``
43+
/// - ``waitUnlessCancelled()``
44+
public final class AsyncSemaphore: @unchecked Sendable {
45+
/// `Suspension` is the state of a task waiting for a signal.
46+
///
47+
/// It is a class because instance identity helps `waitUnlessCancelled()`
48+
/// deal with both early and late cancellation.
49+
///
50+
/// We make it @unchecked Sendable in order to prevent compiler warnings:
51+
/// instances are always protected by the semaphore's lock.
52+
private class Suspension: @unchecked Sendable {
53+
enum State {
54+
/// Initial state. Next is suspendedUnlessCancelled, or cancelled.
55+
case pending
56+
57+
/// Waiting for a signal, with support for cancellation.
58+
case suspendedUnlessCancelled(UnsafeContinuation<Void, Error>)
59+
60+
/// Waiting for a signal, with no support for cancellation.
61+
case suspended(UnsafeContinuation<Void, Never>)
62+
63+
/// Cancelled before we have started waiting.
64+
case cancelled
65+
}
66+
67+
var state: State
68+
69+
init(state: State) {
70+
self.state = state
71+
}
72+
}
73+
74+
// MARK: - Internal State
75+
76+
/// The semaphore value.
77+
private var value: Int
78+
79+
/// As many elements as there are suspended tasks waiting for a signal.
80+
private var suspensions: [Suspension] = []
81+
82+
/// The lock that protects `value` and `suspensions`.
83+
///
84+
/// It is recursive in order to handle cancellation (see the implementation
85+
/// of ``waitUnlessCancelled()``).
86+
private let _lock = NSRecursiveLock()
87+
88+
// MARK: - Creating a Semaphore
89+
90+
/// Creates a semaphore.
91+
///
92+
/// - parameter value: The starting value for the semaphore. Do not pass a
93+
/// value less than zero.
94+
public init(value: Int) {
95+
precondition(value >= 0, "AsyncSemaphore requires a value equal or greater than zero")
96+
self.value = value
97+
}
98+
99+
deinit {
100+
precondition(suspensions.isEmpty, "AsyncSemaphore is deallocated while some task(s) are suspended waiting for a signal.")
101+
}
102+
103+
// MARK: - Locking
104+
105+
// Let's hide the locking primitive in order to avoid a compiler warning:
106+
//
107+
// > Instance method 'lock' is unavailable from asynchronous contexts;
108+
// > Use async-safe scoped locking instead; this is an error in Swift 6.
109+
//
110+
// We're not sweeping bad stuff under the rug. We really need to protect
111+
// our inner state (`value` and `suspension`) across the calls to
112+
// `withUnsafeContinuation`. Unfortunately, this method introduces a
113+
// suspension point. So we need a lock.
114+
private func lock() { _lock.lock() }
115+
private func unlock() { _lock.unlock() }
116+
117+
// MARK: - Waiting for the Semaphore
118+
119+
/// Waits for, or decrements, a semaphore.
120+
///
121+
/// Decrement the counting semaphore. If the resulting value is less than
122+
/// zero, this function suspends the current task until a signal occurs,
123+
/// without blocking the underlying thread. Otherwise, no suspension happens.
124+
public func wait() async {
125+
lock()
126+
127+
value -= 1
128+
if value >= 0 {
129+
unlock()
130+
return
131+
}
132+
133+
await withUnsafeContinuation { continuation in
134+
// Register the continuation that `signal` will resume.
135+
let suspension = Suspension(state: .suspended(continuation))
136+
suspensions.insert(suspension, at: 0) // FIFO
137+
unlock()
138+
}
139+
}
140+
141+
/// Waits for, or decrements, a semaphore, with support for cancellation.
142+
///
143+
/// Decrement the counting semaphore. If the resulting value is less than
144+
/// zero, this function suspends the current task until a signal occurs,
145+
/// without blocking the underlying thread. Otherwise, no suspension happens.
146+
///
147+
/// If the task is canceled before a signal occurs, this function
148+
/// throws `CancellationError`.
149+
public func waitUnlessCancelled() async throws {
150+
lock()
151+
152+
value -= 1
153+
if value >= 0 {
154+
defer { unlock() }
155+
156+
do {
157+
// All code paths check for cancellation
158+
try Task.checkCancellation()
159+
} catch {
160+
// Cancellation is like a signal: we don't really "consume"
161+
// the semaphore, and restore the value.
162+
value += 1
163+
throw error
164+
}
165+
166+
return
167+
}
168+
169+
// Get ready for being suspended waiting for a continuation, or for
170+
// early cancellation.
171+
let suspension = Suspension(state: .pending)
172+
173+
try await withTaskCancellationHandler {
174+
try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation<Void, Error>) in
175+
if case .cancelled = suspension.state {
176+
// Early cancellation: waitUnlessCancelled() is called from
177+
// a cancelled task, and the `onCancel` closure below
178+
// has marked the suspension as cancelled.
179+
// Resume with a CancellationError.
180+
unlock()
181+
continuation.resume(throwing: CancellationError())
182+
} else {
183+
// Current task is not cancelled: register the continuation
184+
// that `signal` will resume.
185+
suspension.state = .suspendedUnlessCancelled(continuation)
186+
suspensions.insert(suspension, at: 0) // FIFO
187+
unlock()
188+
}
189+
}
190+
} onCancel: {
191+
// withTaskCancellationHandler may immediately call this block (if
192+
// the current task is cancelled), or call it later (if the task is
193+
// cancelled later). In the first case, we're still holding the lock,
194+
// waiting for the continuation. In the second case, we do not hold
195+
// the lock. Being able to handle both situations is the reason why
196+
// we use a recursive lock.
197+
lock()
198+
199+
// We're no longer waiting for a signal
200+
value += 1
201+
if let index = suspensions.firstIndex(where: { $0 === suspension }) {
202+
suspensions.remove(at: index)
203+
}
204+
205+
if case let .suspendedUnlessCancelled(continuation) = suspension.state {
206+
// Late cancellation: the task is cancelled while waiting
207+
// from the semaphore. Resume with a CancellationError.
208+
unlock()
209+
continuation.resume(throwing: CancellationError())
210+
} else {
211+
// Early cancellation: waitUnlessCancelled() is called from
212+
// a cancelled task.
213+
//
214+
// The next step is the `withTaskCancellationHandler`
215+
// operation closure right above.
216+
suspension.state = .cancelled
217+
unlock()
218+
}
219+
}
220+
}
221+
222+
// MARK: - Signaling the Semaphore
223+
224+
/// Signals (increments) a semaphore.
225+
///
226+
/// Increment the counting semaphore. If the previous value was less than
227+
/// zero, this function resumes a task currently suspended in ``wait()``
228+
/// or ``waitUnlessCancelled()``.
229+
///
230+
/// - returns: This function returns true if a suspended task is
231+
/// resumed. Otherwise, the result is false, meaning that no task was
232+
/// waiting for the semaphore.
233+
@discardableResult
234+
public func signal() -> Bool {
235+
lock()
236+
237+
value += 1
238+
239+
switch suspensions.popLast()?.state { // FIFO
240+
case let .suspendedUnlessCancelled(continuation):
241+
unlock()
242+
continuation.resume()
243+
return true
244+
case let .suspended(continuation):
245+
unlock()
246+
continuation.resume()
247+
return true
248+
default:
249+
unlock()
250+
return false
251+
}
252+
}
253+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//
2+
// AttributeTestBase.swift
3+
// OpenGraphCompatibilityTests
4+
5+
import Testing
6+
import Foundation
7+
8+
/// Base class for Attribute Related test case
9+
@available(*, deprecated, message: "Use GraphEnvironmentTrait instead")
10+
class AttributeTestBase {
11+
private static let sharedGraph = Graph()
12+
private var graph: Graph
13+
private var subgraph: Subgraph
14+
15+
init() {
16+
graph = Graph(shared: Self.sharedGraph)
17+
subgraph = Subgraph(graph: graph)
18+
Subgraph.current = subgraph
19+
}
20+
21+
deinit {
22+
Subgraph.current = nil
23+
}
24+
}
25+
26+
struct GraphEnvironmentTrait: TestTrait, TestScoping, SuiteTrait {
27+
private static let sharedGraph = Graph()
28+
private let semaphore = AsyncSemaphore(value: 1)
29+
30+
31+
@MainActor
32+
func provideScope(for test: Test, testCase: Test.Case?, performing function: @Sendable () async throws -> Void) async throws {
33+
await semaphore.wait()
34+
defer { semaphore.signal() }
35+
let graph = Graph(shared: Self.sharedGraph)
36+
let subgraph = Subgraph(graph: graph)
37+
let oldSubgraph = Subgraph.current
38+
39+
Subgraph.current = subgraph
40+
try await function()
41+
Subgraph.current = oldSubgraph
42+
}
43+
44+
var isRecursive: Bool {
45+
true
46+
}
47+
}
48+
49+
extension Trait where Self == GraphEnvironmentTrait {
50+
static var graphScope: Self {
51+
GraphEnvironmentTrait()
52+
}
53+
}

Tests/OpenGraphCompatibilityTests/Attribute/AttributeTestHelper.swift renamed to Tests/OpenGraphCompatibilityTests/Helper/AttributeTestHelper.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
//
22
// AttributeTestHelper.swift
3-
//
4-
//
5-
//
3+
// OpenGraphCompatibilityTests
64

75
struct Tuple<A, B> {
86
var first: A

0 commit comments

Comments
 (0)