Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
//

using System;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -29,6 +31,7 @@
using Microsoft.Extensions.Options;
using Remora.Discord.Caching.Abstractions;
using Remora.Discord.Caching.Abstractions.Services;
using Remora.Discord.Rest;
using Remora.Results;

namespace Remora.Discord.Caching.Redis.Services;
Expand All @@ -41,16 +44,40 @@ public class RedisCacheProvider : ICacheProvider
{
private readonly IDistributedCache _cache;
private readonly JsonSerializerOptions _jsonOptions;
private readonly string? _tokenHash;

/// <summary>
/// Initializes a new instance of the <see cref="RedisCacheProvider"/> class.
/// </summary>
/// <param name="cache">The redis cache.</param>
/// <param name="jsonOptions">The JSON options.</param>
public RedisCacheProvider(IDistributedCache cache, IOptionsMonitor<JsonSerializerOptions> jsonOptions)
/// <param name="tokenStore">The token store, if one is available.</param>
public RedisCacheProvider
(
IDistributedCache cache,
IOptionsMonitor<JsonSerializerOptions> jsonOptions,
ITokenStore? tokenStore = null
)
{
_cache = cache;
_jsonOptions = jsonOptions.Get("Discord");

if (tokenStore is null)
{
_tokenHash = null;
return;
}

using var hasher = SHA256.Create();
var hashBuilder = new StringBuilder(64);
var hash = hasher.ComputeHash(Encoding.UTF8.GetBytes(tokenStore.Token));

foreach (var value in hash)
{
hashBuilder.Append(value.ToString("x2"));
}

_tokenHash = hashBuilder.ToString();
}

/// <inheritdoc cref="ICacheProvider.CacheAsync{TInstance}"/>
Expand All @@ -77,7 +104,7 @@ public virtual async ValueTask CacheAsync<TInstance>

var serialized = JsonSerializer.SerializeToUtf8Bytes(instance, _jsonOptions);

await _cache.SetAsync(key.ToCanonicalString(), serialized, options, ct);
await _cache.SetAsync(CreateTokenScopedKey(key), serialized, options, ct);
}

/// <inheritdoc cref="ICacheProvider.RetrieveAsync{TInstance}"/>
Expand All @@ -95,7 +122,7 @@ public virtual async ValueTask<Result<TInstance>> RetrieveAsync<TInstance>
)
where TInstance : class
{
var keyString = key.ToCanonicalString();
var keyString = CreateTokenScopedKey(key);

var value = await _cache.GetAsync(keyString, ct);

Expand All @@ -114,7 +141,7 @@ public virtual async ValueTask<Result<TInstance>> RetrieveAsync<TInstance>
/// <inheritdoc cref="ICacheProvider.EvictAsync" />
public async ValueTask<Result> EvictAsync(CacheKey key, CancellationToken ct = default)
{
var keyString = key.ToCanonicalString();
var keyString = CreateTokenScopedKey(key);

var existingValue = await _cache.GetAsync(keyString, ct);

Expand Down Expand Up @@ -143,7 +170,7 @@ public virtual async ValueTask<Result<TInstance>> EvictAsync<TInstance>
)
where TInstance : class
{
var keyString = key.ToCanonicalString();
var keyString = CreateTokenScopedKey(key);

var existingValue = await _cache.GetAsync(keyString, ct);

Expand All @@ -158,4 +185,13 @@ public virtual async ValueTask<Result<TInstance>> EvictAsync<TInstance>

return deserialized;
}

/// <summary>
/// Creates a cache key scoped to a specific token.
/// </summary>
/// <param name="key">The key.</param>
/// <returns>The scoped key.</returns>
private string CreateTokenScopedKey(CacheKey key) => _tokenHash is not null
? $"{_tokenHash}:{key.ToCanonicalString()}"
: key.ToCanonicalString();
}
6 changes: 5 additions & 1 deletion Backend/Remora.Discord.Caching/Services/CacheService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ public class CacheService
/// </summary>
/// <param name="cacheProvider">The cache provider.</param>
/// <param name="cacheSettings">The cache settings.</param>
public CacheService(ICacheProvider cacheProvider, ImmutableCacheSettings cacheSettings)
public CacheService
(
ICacheProvider cacheProvider,
ImmutableCacheSettings cacheSettings
)
{
_cacheProvider = cacheProvider;
_cacheSettings = cacheSettings;
Expand Down
42 changes: 35 additions & 7 deletions Backend/Remora.Discord.Rest/Caching/MemoryCacheProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//

using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
Expand All @@ -37,14 +39,33 @@ namespace Remora.Discord.Rest.Caching;
public class MemoryCacheProvider : ICacheProvider
{
private readonly IMemoryCache _memoryCache;
private readonly string? _tokenHash;

/// <summary>
/// Initializes a new instance of the <see cref="MemoryCacheProvider"/> class.
/// </summary>
/// <param name="memoryCache">The memory cache.</param>
public MemoryCacheProvider(IMemoryCache memoryCache)
/// <param name="tokenStore">The token store, if one is available.</param>
public MemoryCacheProvider(IMemoryCache memoryCache, ITokenStore? tokenStore = null)
{
_memoryCache = memoryCache;

if (tokenStore is null)
{
_tokenHash = null;
return;
}

using var hasher = SHA256.Create();
var hashBuilder = new StringBuilder(64);
var hash = hasher.ComputeHash(Encoding.UTF8.GetBytes(tokenStore.Token));

foreach (var value in hash)
{
hashBuilder.Append(value.ToString("x2"));
}

_tokenHash = hashBuilder.ToString();
}

/// <inheritdoc cref="ICacheProvider.CacheAsync{TInstance}" />
Expand All @@ -57,7 +78,7 @@ public ValueTask CacheAsync<TInstance>
)
where TInstance : class
{
_memoryCache.Set(key, instance, options);
_memoryCache.Set(CreateTokenScopedKey(key), instance, options);

return default;
}
Expand All @@ -66,7 +87,7 @@ public ValueTask CacheAsync<TInstance>
public ValueTask<Result<TInstance>> RetrieveAsync<TInstance>(CacheKey key, CancellationToken ct = default)
where TInstance : class
{
if (_memoryCache.TryGetValue<TInstance>(key, out var instance))
if (_memoryCache.TryGetValue<TInstance>(CreateTokenScopedKey(key), out var instance))
{
return new(instance);
}
Expand All @@ -77,25 +98,32 @@ public ValueTask<Result<TInstance>> RetrieveAsync<TInstance>(CacheKey key, Cance
/// <inheritdoc cref="ICacheProvider.EvictAsync" />
public ValueTask<Result> EvictAsync(CacheKey key, CancellationToken ct = default)
{
if (!_memoryCache.TryGetValue(key, out _))
if (!_memoryCache.TryGetValue(CreateTokenScopedKey(key), out _))
{
return new(new NotFoundError($"The key \"{key}\" did not contain a value in cache."));
}

_memoryCache.Remove(key);
_memoryCache.Remove(CreateTokenScopedKey(key));
return new(Result.FromSuccess());
}

/// <inheritdoc cref="ICacheProvider.EvictAsync{TInstance}"/>
public ValueTask<Result<TInstance>> EvictAsync<TInstance>(CacheKey key, CancellationToken ct = default)
where TInstance : class
{
if (!_memoryCache.TryGetValue(key, out TInstance? existingValue))
if (!_memoryCache.TryGetValue(CreateTokenScopedKey(key), out TInstance? existingValue))
{
return new(new NotFoundError($"The key \"{key}\" did not contain a value in cache."));
}

_memoryCache.Remove(key);
_memoryCache.Remove(CreateTokenScopedKey(key));
return new(existingValue);
}

/// <summary>
/// Creates a cache key scoped to a specific token.
/// </summary>
/// <param name="key">The key.</param>
/// <returns>The scoped key.</returns>
private object CreateTokenScopedKey(CacheKey key) => (_tokenHash, key);
}