diff --git a/src/Kralizek.Lambda.Template.Sns/SnsEventHandler.cs b/src/Kralizek.Lambda.Template.Sns/SnsEventHandler.cs index 937227b..32f58f8 100644 --- a/src/Kralizek.Lambda.Template.Sns/SnsEventHandler.cs +++ b/src/Kralizek.Lambda.Template.Sns/SnsEventHandler.cs @@ -25,12 +25,12 @@ public async Task HandleAsync(SNSEvent input, ILambdaContext context) { foreach (var record in input.Records) { - using (_serviceProvider.CreateScope()) + using (var scope = _serviceProvider.CreateScope()) { var message = record.Sns.Message; var notification = JsonSerializer.Deserialize(message); - - var handler = _serviceProvider.GetService>(); + + var handler = scope.ServiceProvider.GetService>(); if (handler == null) { diff --git a/src/Kralizek.Lambda.Template.Sqs/SqsEventHandler.cs b/src/Kralizek.Lambda.Template.Sqs/SqsEventHandler.cs index ac23537..d6f8daa 100644 --- a/src/Kralizek.Lambda.Template.Sqs/SqsEventHandler.cs +++ b/src/Kralizek.Lambda.Template.Sqs/SqsEventHandler.cs @@ -23,12 +23,12 @@ public async Task HandleAsync(SQSEvent input, ILambdaContext context) { foreach (var record in input.Records) { - using (_serviceProvider.CreateScope()) + using (var scope = _serviceProvider.CreateScope()) { var sqsMessage = record.Body; var message = JsonSerializer.Deserialize(sqsMessage); - var handler = _serviceProvider.GetService>(); + var handler = scope.ServiceProvider.GetService>(); if (handler == null) { diff --git a/tests/Tests.Lambda.Template/Sns/SnsEventHandlerDisposalTests.cs b/tests/Tests.Lambda.Template/Sns/SnsEventHandlerDisposalTests.cs new file mode 100644 index 0000000..5934279 --- /dev/null +++ b/tests/Tests.Lambda.Template/Sns/SnsEventHandlerDisposalTests.cs @@ -0,0 +1,90 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Amazon.Lambda.Core; +using Amazon.Lambda.SNSEvents; +using Amazon.Lambda.SQSEvents; +using Amazon.Lambda.TestUtilities; +using Kralizek.Lambda; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging.Abstractions; +using NUnit.Framework; + +namespace Tests.Lambda.Sns +{ + public class SnsEventHandlerDisposalTests + { + [Test] + public async Task EventHandler_Should_Use_Scoped_Object_In_ForEach_Loop() + { + var snsEvent = new SNSEvent + { + Records = new List + { + new SNSEvent.SNSRecord + { + Sns = new SNSEvent.SNSMessage + { + Message = "{}" + } + }, + new SNSEvent.SNSRecord + { + Sns = new SNSEvent.SNSMessage + { + Message = "{}" + } + } + } + }; + + var dependency = new DisposableDependency(); + + var services = new ServiceCollection(); + + services.AddScoped(_ => dependency); + + var tcs = new TaskCompletionSource(); + services.AddTransient, SqsEventHandler>(); + + services.AddTransient, + TestNotificationScopedHandler>(provider => + new TestNotificationScopedHandler(provider.GetRequiredService(), tcs)); + + var sp = services.BuildServiceProvider(); + var snsEventHandler = new SnsEventHandler(sp, new NullLoggerFactory()); + + var task = snsEventHandler.HandleAsync(snsEvent, new TestLambdaContext()); + + Assert.That(dependency.Disposed, Is.False, "Dependency should not be disposed"); + Assert.That(task.IsCompleted, Is.False, "The task should not be completed"); + + tcs.SetResult(new TestNotification()); + + await task; + + Assert.That(dependency.Disposed, Is.True, "Dependency should be disposed"); + Assert.That(task.IsCompleted, Is.True, "The task should be completed"); + } + + private class DisposableDependency : IDisposable + { + public bool Disposed { get; private set; } + public void Dispose() => Disposed = true; + } + + private class TestNotificationScopedHandler: INotificationHandler + { + private readonly DisposableDependency _dependency; + private readonly TaskCompletionSource _tcs; + + public TestNotificationScopedHandler(DisposableDependency dependency, TaskCompletionSource tcs) + { + _dependency = dependency; + _tcs = tcs; + } + + public Task HandleAsync(TestNotification message, ILambdaContext context) => _tcs.Task; + } + } +} diff --git a/tests/Tests.Lambda.Template/Sns/SnsEventHandlerTests.cs b/tests/Tests.Lambda.Template/Sns/SnsEventHandlerTests.cs index 624e8a5..e89044b 100644 --- a/tests/Tests.Lambda.Template/Sns/SnsEventHandlerTests.cs +++ b/tests/Tests.Lambda.Template/Sns/SnsEventHandlerTests.cs @@ -20,6 +20,8 @@ public class SnsEventHandlerTests private Mock mockServiceScopeFactory; private Mock mockServiceProvider; private Mock mockLoggerFactory; + private Mock mockServiceScope; + [SetUp] public void Initialize() @@ -27,9 +29,11 @@ public void Initialize() mockNotificationHandler = new Mock>(); mockNotificationHandler.Setup(p => p.HandleAsync(It.IsAny(), It.IsAny())) .Returns(Task.CompletedTask); - + + mockServiceScope = new Mock(); + mockServiceScopeFactory = new Mock(); - mockServiceScopeFactory.Setup(p => p.CreateScope()).Returns(Mock.Of()); + mockServiceScopeFactory.Setup(p => p.CreateScope()).Returns(mockServiceScope.Object); mockServiceProvider = new Mock(); mockServiceProvider.Setup(p => p.GetService(typeof(INotificationHandler))) @@ -37,6 +41,9 @@ public void Initialize() mockServiceProvider.Setup(p => p.GetService(typeof(IServiceScopeFactory))) .Returns(mockServiceScopeFactory.Object); + mockServiceScope.Setup(p => p.ServiceProvider).Returns(mockServiceProvider.Object); + + mockLoggerFactory = new Mock(); mockLoggerFactory.Setup(p => p.CreateLogger(It.IsAny())) .Returns(Mock.Of()); @@ -174,6 +181,7 @@ public void HandleAsync_throws_InvalidOperation_if_NotificationHandler_is_not_re mockServiceProvider = new Mock(); mockServiceProvider.Setup(p => p.GetService(typeof(IServiceScopeFactory))).Returns(mockServiceScopeFactory.Object); + mockServiceScope.Setup(p => p.ServiceProvider).Returns(mockServiceProvider.Object); var sut = CreateSystemUnderTest(); diff --git a/tests/Tests.Lambda.Template/Sqs/SqsEventHandlerDisposalTests.cs b/tests/Tests.Lambda.Template/Sqs/SqsEventHandlerDisposalTests.cs new file mode 100644 index 0000000..89937f7 --- /dev/null +++ b/tests/Tests.Lambda.Template/Sqs/SqsEventHandlerDisposalTests.cs @@ -0,0 +1,83 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Amazon.Lambda.Core; +using Amazon.Lambda.SQSEvents; +using Amazon.Lambda.TestUtilities; +using Kralizek.Lambda; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging.Abstractions; +using NUnit.Framework; + +namespace Tests.Lambda.Sqs +{ + public class SqsEventHandlerDisposalTests + { + [Test] + public async Task EventHandler_Should_Use_Scoped_Object_In_ForEach_Loop() + { + var sqsEvent = new SQSEvent + { + Records = new List + { + new SQSEvent.SQSMessage + { + Body = "{}" + }, + new SQSEvent.SQSMessage + { + Body = "{}" + }, + } + }; + + var dependency = new DisposableDependency(); + + var services = new ServiceCollection(); + + services.AddScoped(_ => dependency); + + var tcs = new TaskCompletionSource(); + services.AddTransient, SqsEventHandler>(); + + services.AddTransient, + TestMessageScopedHandler>(provider => + new TestMessageScopedHandler(provider.GetRequiredService(), tcs)); + + var sp = services.BuildServiceProvider(); + var sqsEventHandler = new SqsEventHandler(sp, new NullLoggerFactory()); + + var task = sqsEventHandler.HandleAsync(sqsEvent, new TestLambdaContext()); + + Assert.That(dependency.Disposed, Is.False, "Dependency should not be disposed"); + Assert.That(task.IsCompleted, Is.False, "The task should not be completed"); + + tcs.SetResult(new TestMessage()); + await task; + Assert.That(dependency.Disposed, Is.True, "Dependency should be disposed"); + Assert.That(task.IsCompleted, Is.True, "The task should be completed"); + + + } + + private class DisposableDependency : IDisposable + { + public bool Disposed { get; private set; } + public void Dispose() => Disposed = true; + } + + private class TestMessageScopedHandler : IMessageHandler + { + private readonly DisposableDependency _dependency; + private readonly TaskCompletionSource _tcs; + + public TestMessageScopedHandler(DisposableDependency dependency, TaskCompletionSource tcs) + { + _dependency = dependency; + _tcs = tcs; + } + + public Task HandleAsync(TestMessage message, ILambdaContext context) => _tcs.Task; + } + } +} diff --git a/tests/Tests.Lambda.Template/Sqs/SqsEventHandlerTests.cs b/tests/Tests.Lambda.Template/Sqs/SqsEventHandlerTests.cs index a16a5a0..b0aa0ad 100644 --- a/tests/Tests.Lambda.Template/Sqs/SqsEventHandlerTests.cs +++ b/tests/Tests.Lambda.Template/Sqs/SqsEventHandlerTests.cs @@ -7,6 +7,7 @@ using Kralizek.Lambda; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Moq; using NUnit.Framework; @@ -18,6 +19,7 @@ public class SqsEventHandlerTests private Mock mockServiceScopeFactory; private Mock mockServiceProvider; private Mock mockLoggerFactory; + private Mock mockServiceScope; [SetUp] @@ -26,8 +28,11 @@ public void Initialize() mockMessageHandler = new Mock>(); mockMessageHandler.Setup(p => p.HandleAsync(It.IsAny(), It.IsAny())).Returns(Task.CompletedTask); + mockServiceScope = new Mock(); + mockServiceScopeFactory = new Mock(); - mockServiceScopeFactory.Setup(p => p.CreateScope()).Returns(Mock.Of()); + + mockServiceScopeFactory.Setup(p => p.CreateScope()).Returns(mockServiceScope.Object); mockServiceProvider = new Mock(); mockServiceProvider.Setup(p => p.GetService(typeof(IMessageHandler))) @@ -35,6 +40,8 @@ public void Initialize() mockServiceProvider.Setup(p => p.GetService(typeof(IServiceScopeFactory))) .Returns(mockServiceScopeFactory.Object); + mockServiceScope.Setup(p => p.ServiceProvider).Returns(mockServiceProvider.Object); + mockLoggerFactory = new Mock(); mockLoggerFactory.Setup(p => p.CreateLogger(It.IsAny())) .Returns(Mock.Of()); @@ -148,6 +155,8 @@ public void HandleAsync_throws_InvalidOperation_if_NotificationHandler_is_not_re mockServiceProvider = new Mock(); mockServiceProvider.Setup(p => p.GetService(typeof(IServiceScopeFactory))).Returns(mockServiceScopeFactory.Object); + + mockServiceScope.Setup(p => p.ServiceProvider).Returns(mockServiceProvider.Object); var sut = CreateSystemUnderTest();