From f34d4efb3768f6c0dd836e0626727226747d0b9b Mon Sep 17 00:00:00 2001 From: Quahu Date: Sat, 3 Aug 2024 19:06:45 +0200 Subject: [PATCH] Refactor VoiceExtension API --- .../Voice/BasicVoice/AudioPlayerService.cs | 9 +- .../VoiceExtension.cs | 120 +++++++++++++++--- 2 files changed, 108 insertions(+), 21 deletions(-) diff --git a/examples/Voice/BasicVoice/AudioPlayerService.cs b/examples/Voice/BasicVoice/AudioPlayerService.cs index 84c25de23..f9e0cce83 100644 --- a/examples/Voice/BasicVoice/AudioPlayerService.cs +++ b/examples/Voice/BasicVoice/AudioPlayerService.cs @@ -28,13 +28,17 @@ public override async Task StopAsync(CancellationToken cancellationToken) { await base.StopAsync(cancellationToken); + var voiceExtension = Bot.GetRequiredExtension(); + await _semaphore.WaitAsync(cancellationToken); try { - foreach (var (_, player) in _players) + foreach (var (guildID, player) in _players) { player.Stop(); await player.DisposeAsync(); + + await voiceExtension.DisconnectAsync(guildID); } _players.Clear(); @@ -111,6 +115,9 @@ public async Task DisposePlayerAsync(Snowflake guildId) { player.Stop(); await player.DisposeAsync(); + + var voiceExtension = Bot.GetRequiredExtension(); + await voiceExtension.DisconnectAsync(guildId); } } finally diff --git a/src/Disqord.Extensions.Voice/VoiceExtension.cs b/src/Disqord.Extensions.Voice/VoiceExtension.cs index 7e8f49969..a2e27c990 100644 --- a/src/Disqord.Extensions.Voice/VoiceExtension.cs +++ b/src/Disqord.Extensions.Voice/VoiceExtension.cs @@ -1,22 +1,23 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Disqord.Gateway; +using Disqord.Utilities.Threading; using Disqord.Voice; using Microsoft.Extensions.Logging; using Qommon; -using Qommon.Collections.ReadOnly; using Qommon.Collections.ThreadSafe; namespace Disqord.Extensions.Voice; +// TODO: keyed semaphore public class VoiceExtension : DiscordClientExtension { - public IReadOnlyDictionary Connections => _connections.ReadOnly(); - private readonly IVoiceConnectionFactory _connectionFactory; - private readonly IThreadSafeDictionary _connections; + private readonly IThreadSafeDictionary _connections; public VoiceExtension( ILogger logger, @@ -25,7 +26,7 @@ public VoiceExtension( { _connectionFactory = connectionFactory; - _connections = ThreadSafeDictionary.Monitor.Create(); + _connections = ThreadSafeDictionary.Monitor.Create(); } /// @@ -39,28 +40,59 @@ protected override ValueTask InitializeAsync(CancellationToken cancellationToken private Task VoiceServerUpdatedAsync(object? sender, VoiceServerUpdatedEventArgs e) { - if (_connections.TryGetValue(e.GuildId, out var connection)) - { - connection.OnVoiceServerUpdate(e.Token, e.Endpoint); - } - + GetConnection(e.GuildId)?.OnVoiceServerUpdate(e.Token, e.Endpoint); return Task.CompletedTask; } private Task VoiceStateUpdatedAsync(object? sender, VoiceStateUpdatedEventArgs e) { if (Client.CurrentUser.Id != e.NewVoiceState.MemberId) - return Task.CompletedTask; - - if (_connections.TryGetValue(e.GuildId, out var connection)) { - var voiceState = e.NewVoiceState; - connection.OnVoiceStateUpdate(voiceState.ChannelId, voiceState.SessionId); + return Task.CompletedTask; } + var voiceState = e.NewVoiceState; + GetConnection(e.GuildId)?.OnVoiceStateUpdate(voiceState.ChannelId, voiceState.SessionId); return Task.CompletedTask; } + /// + /// Gets the voice connection for the guild with the given ID. + /// + /// The ID of the guild. + /// + /// The voice connection or if the connection does not exist. + /// + public IVoiceConnection? GetConnection(Snowflake guildId) + { + return _connections.TryGetValue(guildId, out var connectionInfo) + ? connectionInfo.Connection + : null; + } + + /// + /// Gets all maintained voice connections. + /// + /// + /// A dictionary of all connections keyed by the IDs of the guilds. + /// + public IReadOnlyDictionary GetConnections() + { + return _connections.ToDictionary(static kvp => kvp.Key, static kvp => kvp.Value.Connection); + } + + /// + /// Connects to the channel with the given ID. + /// + /// + /// To disconnect the voice connection, use . + /// + /// The ID of the guild the channel is in. + /// The ID of the channel. + /// The cancellation token to observe. This is only used for the initial connection. + /// + /// A with the result being the created voice connection. + /// public async ValueTask ConnectAsync(Snowflake guildId, Snowflake channelId, CancellationToken cancellationToken = default) { var connection = _connectionFactory.Create(guildId, channelId, Client.CurrentUser.Id, @@ -74,12 +106,60 @@ public async ValueTask ConnectAsync(Snowflake guildId, Snowfla return new(shard.SetVoiceStateAsync(guildId, channelId, false, true, cancellationToken)); }); - _connections[guildId] = connection; + var connectionInfo = new VoiceConnectionInfo(connection, Cts.Linked(Client.StoppingToken)); + _connections[guildId] = connectionInfo; + try + { + var readyTask = connection.WaitUntilReadyAsync(cancellationToken); + _ = connection.RunAsync(connectionInfo.Cts.Token); - var readyTask = connection.WaitUntilReadyAsync(cancellationToken); - _ = connection.RunAsync(Client.StoppingToken); + await readyTask.ConfigureAwait(false); + } + catch + { + _connections.Remove(guildId); + await connectionInfo.DisposeAsync(); + throw; + } - await readyTask.ConfigureAwait(false); return connection; } + + /// + /// Disconnects the voice connection, if one exists, for the guild with the given ID. + /// + /// + /// This will render the relevant voice connection unusable. + /// Use to obtain a new connection afterward. + /// + /// The ID of the guild. + public ValueTask DisconnectAsync(Snowflake guildId) + { + if (!_connections.TryRemove(guildId, out var connectionInfo)) + { + return default; + } + + return connectionInfo.DisposeAsync(); + } + + private readonly struct VoiceConnectionInfo : IAsyncDisposable + { + public IVoiceConnection Connection { get; } + + public Cts Cts { get; } + + public VoiceConnectionInfo(IVoiceConnection connection, Cts cts) + { + Connection = connection; + Cts = cts; + } + + public async ValueTask DisposeAsync() + { + Cts.Cancel(); + Cts.Dispose(); + await Connection.DisposeAsync(); + } + } }