Skip to content
181 changes: 181 additions & 0 deletions Sources/Basics/Concurrency/ConcurrencyHelpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Dispatch
import class Foundation.NSLock
import class Foundation.ProcessInfo
import struct Foundation.URL
import struct Foundation.UUID
import func TSCBasic.tsc_await

public enum Concurrency {
Expand Down Expand Up @@ -76,3 +77,183 @@ extension DispatchQueue {
}
}
}

/// A queue for running async operations with a limit on the number of concurrent tasks.
public final class AsyncOperationQueue: @unchecked Sendable {

// This implementation is identical to the AsyncOperationQueue in swift-build.
// Any modifications made here should also be made there.
// https://github.com/swiftlang/swift-build/blob/main/Sources/SWBUtil/AsyncOperationQueue.swift#L13

fileprivate typealias ID = UUID
fileprivate typealias WaitingContinuation = CheckedContinuation<Void, any Error>

private let concurrentTasks: Int
private var activeTasks: Int = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this counter at all or can we just use waitingTasks.count? Keeping the two in sync can be tricky and I had to triple check that the code does it correctly. That requires us to change the WaitingTask to be just WorkTask with a new case but IMO that would make this code a lot clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can, but it that would mean the waitingTasks would no longer be able to be modelled as a queue since the actively running tasks would be contained within it.

I think an OrderedDictionary could work instead, but runs in to the swift-collections in swift-build question.

private var waitingTasks: [WaitingTask] = []
private let waitingTasksLock = NSLock()

fileprivate enum WaitingTask {
case creating(ID)
case waiting(ID, WaitingContinuation)
case cancelled(ID)

var id: ID {
switch self {
case .creating(let id), .waiting(let id, _), .cancelled(let id):
return id
}
}

var continuation: WaitingContinuation? {
guard case .waiting(_, let continuation) = self else {
return nil
}
return continuation
}
}

/// Creates an `AsyncOperationQueue` with a specified number of concurrent tasks.
/// - Parameter concurrentTasks: The maximum number of concurrent tasks that can be executed concurrently.
public init(concurrentTasks: Int) {
self.concurrentTasks = concurrentTasks
}

deinit {
waitingTasksLock.withLock {
if !waitingTasks.filter({ $0.continuation != nil }).isEmpty {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it unexpected if there is anything at all in waitingTasks, not just those with a non-nil continuation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, you're right, I've updated to revert this check.

preconditionFailure("Deallocated with waiting tasks")
}
}
}

/// Executes an asynchronous operation, ensuring that the number of concurrent tasks
// does not exceed the specified limit.
/// - Parameter operation: The asynchronous operation to execute.
/// - Returns: The result of the operation.
/// - Throws: An error thrown by the operation, or a `CancellationError` if the operation is cancelled.
public func withOperation<ReturnValue>(
_ operation: @Sendable () async throws -> sending ReturnValue
) async throws -> ReturnValue {
try await waitIfNeeded()
defer { signalCompletion() }
return try await operation()
}

private func waitIfNeeded() async throws {
guard waitingTasksLock.withLock({
let shouldWait = activeTasks >= concurrentTasks
activeTasks += 1
return shouldWait
}) else {
return // Less tasks are in flight than the limit.
}

let taskId = ID()
waitingTasksLock.withLock {
waitingTasks.append(.creating(taskId))
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could move this into the above lock and return an optional UUID. This way we don't need to acquire the lock more than once for the initial setup


enum TaskAction {
case start(WaitingContinuation)
case cancel(WaitingContinuation)
}

try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { (continuation: WaitingContinuation) -> Void in
let action: TaskAction? = waitingTasksLock.withLock {
guard let index = waitingTasks.firstIndex(where: { $0.id == taskId }) else {
// If the task was cancelled in onCancelled it will have been removed from the waiting tasks list.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is not true right? We are always going to get an index by just looking at the code in onCancel. Either we find a .cancelled case or we have run before. There is no way for onCancel and this check to interleave in the way you describe here. Now, what can interleave is the signalCompletion that also appears to remove cancelled() tasks. Can we update this comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've tightend up this comment to reflect how this guard code would actually get called.

return .cancel(continuation)
}

// If the task was cancelled in between creating the task cancellation handler and aquiring the lock,
// we should resume the continuation with a `CancellationError`.
if case .cancelled = waitingTasks[index] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of writing an if case I would prefer if we exhaustively switch over the returned task here similar to how you did it in the onCancel case. I always recommend doing that since it makes sure every case is properly handled.

return .cancel(continuation)
}

// A task may have completed since we iniitally checked if we should wait. Check again in this locked
// section and if we can start it, remove it from the waiting tasks and start it immediately.
let shouldWait = activeTasks >= concurrentTasks
if shouldWait {
waitingTasks[index] = .waiting(taskId, continuation)
return nil
} else {
waitingTasks.remove(at: index)

// activeTasks isn't decremented in the `signalCompletion` method
// when the next task to start is .creating, so we decrement it here.
activeTasks -= 1
return .start(continuation)
}
}

switch action {
case .some(.cancel(let continuation)):
continuation.resume(throwing: _Concurrency.CancellationError())
case .some(.start(let continuation)):
continuation.resume()
case .none:
return
}

}
} onCancel: {
let continuation: WaitingContinuation? = self.waitingTasksLock.withLock {
guard let taskIndex = self.waitingTasks.firstIndex(where: { $0.id == taskId }) else {
return nil
}

switch self.waitingTasks[taskIndex] {
case .waiting(_, let continuation):
self.waitingTasks.remove(at: taskIndex)

// If the parent task is cancelled then we need to manually handle resuming the
// continuation for the waiting task with a `CancellationError`. Return the continuation
// here so it can be resumed once the `waitingTasksLock` is released.
return continuation
case .creating:
// If the task was still being created, mark it as cancelled in the queue so that
// withCheckedThrowingContinuation can immediately cancel it.
self.waitingTasks[taskIndex] = .cancelled(taskId)
activeTasks -= 1
return nil
case .cancelled:
preconditionFailure("Attempting to cancel a task that was already cancelled")
}
}

continuation?.resume(throwing: _Concurrency.CancellationError())
}
}

private func signalCompletion() {
let continuationToResume = waitingTasksLock.withLock { () -> WaitingContinuation? in
guard !waitingTasks.isEmpty else {
activeTasks -= 1
return nil
}

while let lastTask = waitingTasks.first {
switch lastTask {
case .creating:
// If the next task is in the process of being created, let the
return Optional<WaitingContinuation>.none
case .waiting:
activeTasks -= 1
// Begin the next waiting task
return waitingTasks.remove(at: 0).continuation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are operating under FIFO here it is worth using a Deque instead of an array since this remove(at:) is going to reallocate the entire array.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to keep in mind this will make its way back in to swift-build, which has no dependency on swift-collections and looks to be trying to keep its dependencies minimal. @jakepetroules is that accurate, or would swift-build accept a swift-collections dependency for this?

case .cancelled:
// If the next task is cancelled, continue removing cancelled
// tasks until we find one that hasn't run yet or we run out.
_ = waitingTasks.remove(at: 0)
continue
}
}
return nil
}

continuationToResume?.resume()
}
}
24 changes: 9 additions & 15 deletions Sources/PackageGraph/PackageContainer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,33 +197,27 @@ extension PackageContainerConstraint: CustomStringConvertible {
/// An interface for resolving package containers.
public protocol PackageContainerProvider {
/// Get the container for a particular identifier asynchronously.

@available(*, noasync, message: "Use the async alternative")
func getContainer(
for package: PackageReference,
updateStrategy: ContainerUpdateStrategy,
observabilityScope: ObservabilityScope,
on queue: DispatchQueue,
completion: @escaping @Sendable (Result<PackageContainer, Error>) -> Void
)
observabilityScope: ObservabilityScope
) async throws -> PackageContainer
}

public extension PackageContainerProvider {
@available(*, noasync, message: "Use the async alternative")
func getContainer(
for package: PackageReference,
updateStrategy: ContainerUpdateStrategy,
observabilityScope: ObservabilityScope,
on queue: DispatchQueue
) async throws -> PackageContainer {
try await withCheckedThrowingContinuation { continuation in
self.getContainer(
on queue: DispatchQueue,
completion: @escaping @Sendable (Result<PackageContainer, Error>) -> Void
) {
queue.asyncResult(completion) {
try await self.getContainer(
for: package,
updateStrategy: updateStrategy,
observabilityScope: observabilityScope,
on: queue,
completion: {
continuation.resume(with: $0)
}
observabilityScope: observabilityScope
)
}
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/PackageMetadata/PackageMetadata.swift
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,15 @@ public struct PackageSearchClient {
let fetchStandalonePackageByURL = { (error: Error?) async throws -> [Package] in
let url = SourceControlURL(query)
do {
return try withTemporaryDirectory(removeTreeOnDeinit: true) { (tempDir: AbsolutePath) in
return try await withTemporaryDirectory(removeTreeOnDeinit: true) { (tempDir: AbsolutePath) in
let tempPath = tempDir.appending(component: url.lastPathComponent)
let repositorySpecifier = RepositorySpecifier(url: url)
try self.repositoryProvider.fetch(
repository: repositorySpecifier,
to: tempPath,
progressHandler: nil
)
guard try self.repositoryProvider.isValidDirectory(tempPath), let repository = try self.repositoryProvider.open(
guard try self.repositoryProvider.isValidDirectory(tempPath), let repository = try await self.repositoryProvider.open(
repository: repositorySpecifier,
at: tempPath
) as? GitRepository else {
Expand Down
61 changes: 42 additions & 19 deletions Sources/PackageRegistry/RegistryDownloadsManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class RegistryDownloadsManager: AsyncCancellable {
private let path: Basics.AbsolutePath
private let cachePath: Basics.AbsolutePath?
private let registryClient: RegistryClient
private let delegate: Delegate?
private let delegate: RegistryDownloadManagerDelegateProxy?

struct PackageLookup: Hashable {
let package: PackageIdentity
Expand All @@ -48,14 +48,13 @@ public class RegistryDownloadsManager: AsyncCancellable {
self.path = path
self.cachePath = cachePath
self.registryClient = registryClient
self.delegate = delegate
self.delegate = RegistryDownloadManagerDelegateProxy(delegate)
}

public func lookup(
package: PackageIdentity,
version: Version,
observabilityScope: ObservabilityScope,
delegateQueue: DispatchQueue
observabilityScope: ObservabilityScope
) async throws -> Basics.AbsolutePath {
let packageRelativePath: Basics.RelativePath
let packagePath: Basics.AbsolutePath
Expand All @@ -82,9 +81,9 @@ public class RegistryDownloadsManager: AsyncCancellable {
// inform delegate that we are starting to fetch
// calculate if cached (for delegate call) outside queue as it may change while queue is processing
let isCached = self.cachePath.map { self.fileSystem.exists($0.appending(packageRelativePath)) } ?? false
delegateQueue.async { [delegate = self.delegate] in
Task {
let details = FetchDetails(fromCache: isCached, updatedCache: false)
delegate?.willFetch(package: package, version: version, fetchDetails: details)
await delegate?.willFetch(package: package, version: version, fetchDetails: details)
}

// make sure destination is free.
Expand All @@ -96,18 +95,17 @@ public class RegistryDownloadsManager: AsyncCancellable {
package: package,
version: version,
packagePath: packagePath,
observabilityScope: observabilityScope,
delegateQueue: delegateQueue
observabilityScope: observabilityScope
)
// inform delegate that we finished to fetch
let duration = start.distance(to: .now())
delegateQueue.async { [delegate = self.delegate] in
delegate?.didFetch(package: package, version: version, result: .success(result), duration: duration)
Task {
await delegate?.didFetch(package: package, version: version, result: .success(result), duration: duration)
}
} catch {
let duration = start.distance(to: .now())
delegateQueue.async { [delegate = self.delegate] in
delegate?.didFetch(package: package, version: version, result: .failure(error), duration: duration)
Task {
await delegate?.didFetch(package: package, version: version, result: .failure(error), duration: duration)
}
throw error
}
Expand All @@ -126,16 +124,14 @@ public class RegistryDownloadsManager: AsyncCancellable {
package: PackageIdentity,
version: Version,
observabilityScope: ObservabilityScope,
delegateQueue: DispatchQueue,
callbackQueue: DispatchQueue,
completion: @escaping @Sendable (Result<Basics.AbsolutePath, Error>) -> Void
) {
callbackQueue.asyncResult(completion) {
try await self.lookup(
package: package,
version: version,
observabilityScope: observabilityScope,
delegateQueue: delegateQueue
observabilityScope: observabilityScope
)
}
}
Expand All @@ -149,8 +145,7 @@ public class RegistryDownloadsManager: AsyncCancellable {
package: PackageIdentity,
version: Version,
packagePath: Basics.AbsolutePath,
observabilityScope: ObservabilityScope,
delegateQueue: DispatchQueue
observabilityScope: ObservabilityScope
) async throws -> FetchDetails {
if let cachePath {
do {
Expand Down Expand Up @@ -238,8 +233,8 @@ public class RegistryDownloadsManager: AsyncCancellable {
// utility to update progress

@Sendable func updateDownloadProgress(downloaded: Int64, total: Int64?) {
delegateQueue.async { [delegate = self.delegate] in
delegate?.fetching(
Task {
await delegate?.fetching(
package: package,
version: version,
bytesDownloaded: downloaded,
Expand Down Expand Up @@ -327,6 +322,34 @@ public protocol RegistryDownloadsManagerDelegate: Sendable {
func fetching(package: PackageIdentity, version: Version, bytesDownloaded: Int64, totalBytesToDownload: Int64?)
}

actor RegistryDownloadManagerDelegateProxy {
private let delegate: RegistryDownloadsManagerDelegate

init?(_ delegate: RegistryDownloadsManagerDelegate?) {
guard let delegate else {
return nil
}
self.delegate = delegate
}

func willFetch(package: PackageIdentity, version: Version, fetchDetails: RegistryDownloadsManager.FetchDetails) {
self.delegate.willFetch(package: package, version: version, fetchDetails: fetchDetails)
}

func didFetch(
package: PackageIdentity,
version: Version,
result: Result<RegistryDownloadsManager.FetchDetails, Error>,
duration: DispatchTimeInterval
) {
self.delegate.didFetch(package: package, version: version, result: result, duration: duration)
}

func fetching(package: PackageIdentity, version: Version, bytesDownloaded: Int64, totalBytesToDownload: Int64?) {
self.delegate.fetching(package: package, version: version, bytesDownloaded: bytesDownloaded, totalBytesToDownload: totalBytesToDownload)
}
}

extension Dictionary where Key == RegistryDownloadsManager.PackageLookup {
fileprivate mutating func removeValue(forPackage package: PackageIdentity) {
self.keys
Expand Down
Loading