diff --git a/Test/DurableTask.Core.Tests/PropertiesMiddlewareTests.cs b/Test/DurableTask.Core.Tests/PropertiesMiddlewareTests.cs new file mode 100644 index 000000000..a7b648250 --- /dev/null +++ b/Test/DurableTask.Core.Tests/PropertiesMiddlewareTests.cs @@ -0,0 +1,147 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- +#nullable enable +namespace DurableTask.Core.Tests +{ + using System; + using System.Diagnostics; + using System.Threading.Tasks; + using DurableTask.Emulator; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class PropertiesMiddlewareTests + { + private const string PropertyKey = "Test"; + private const string PropertyValue = "Value"; + + TaskHubWorker worker = null!; + TaskHubClient client = null!; + + [TestInitialize] + public async Task Initialize() + { + var service = new LocalOrchestrationService(); + this.worker = new TaskHubWorker(service); + + await this.worker + .AddTaskOrchestrations(typeof(NoActivities), typeof(RunActivityOrchestrator)) + .AddTaskActivities(typeof(ReturnPropertyActivity)) + .StartAsync(); + + this.client = new TaskHubClient(service); + } + + [TestCleanup] + public async Task TestCleanup() + { + await this.worker.StopAsync(true); + } + + private sealed class NoActivities : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, string input) + { + return Task.FromResult(context.GetProperty(PropertyKey)!); + } + } + + private sealed class ReturnPropertyActivity : TaskActivity + { + protected override string Execute(TaskContext context, string input) + { + return context.GetProperty(PropertyKey)!; + } + } + + private sealed class RunActivityOrchestrator : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, string input) + { + return context.ScheduleTask(typeof(ReturnPropertyActivity)); + } + } + + [TestMethod] + public async Task OrchestrationGetsProperties() + { + this.worker.AddOrchestrationDispatcherMiddleware((context, next) => + { + context.SetProperty(PropertyKey, PropertyValue); + + return next(); + }); + + OrchestrationInstance instance = await this.client.CreateOrchestrationInstanceAsync(typeof(NoActivities), null); + + TimeSpan timeout = TimeSpan.FromSeconds(Debugger.IsAttached ? 1000 : 10); + var state = await this.client.WaitForOrchestrationAsync(instance, timeout); + + Assert.AreEqual($"\"{PropertyValue}\"", state.Output); + } + + [TestMethod] + public async Task OrchestrationDoesNotGetPropertiesFromActivityMiddleware() + { + this.worker.AddActivityDispatcherMiddleware((context, next) => + { + context.SetProperty(PropertyKey, PropertyValue); + + return next(); + }); + + OrchestrationInstance instance = await this.client.CreateOrchestrationInstanceAsync(typeof(NoActivities), null); + + TimeSpan timeout = TimeSpan.FromSeconds(Debugger.IsAttached ? 1000 : 10); + var state = await this.client.WaitForOrchestrationAsync(instance, timeout); + + Assert.IsNull(state.Output); + } + + [TestMethod] + public async Task ActivityGetsProperties() + { + this.worker.AddActivityDispatcherMiddleware((context, next) => + { + context.SetProperty(PropertyKey, PropertyValue); + + return next(); + }); + + OrchestrationInstance instance = await this.client.CreateOrchestrationInstanceAsync(typeof(RunActivityOrchestrator), null); + + TimeSpan timeout = TimeSpan.FromSeconds(Debugger.IsAttached ? 1000 : 10); + var state = await this.client.WaitForOrchestrationAsync(instance, timeout); + + Assert.AreEqual($"\"{PropertyValue}\"", state.Output); + } + + [TestMethod] + public async Task ActivityDoesNotGetPropertiesFromOrchestratorMiddleware() + { + this.worker.AddOrchestrationDispatcherMiddleware((context, next) => + { + context.SetProperty(PropertyKey, PropertyValue); + + return next(); + }); + + OrchestrationInstance instance = await this.client.CreateOrchestrationInstanceAsync(typeof(RunActivityOrchestrator), null); + + TimeSpan timeout = TimeSpan.FromSeconds(Debugger.IsAttached ? 1000 : 10); + var state = await this.client.WaitForOrchestrationAsync(instance, timeout); + + Assert.IsNull(state.Output); + } + } +} diff --git a/Test/DurableTask.Core.Tests/RetryInterceptorTests.cs b/Test/DurableTask.Core.Tests/RetryInterceptorTests.cs index a8a711724..1a6649817 100644 --- a/Test/DurableTask.Core.Tests/RetryInterceptorTests.cs +++ b/Test/DurableTask.Core.Tests/RetryInterceptorTests.cs @@ -89,7 +89,7 @@ sealed class MockOrchestrationContext : TaskOrchestrationContext readonly List delays = new List(); public MockOrchestrationContext(OrchestrationInstance orchestrationInstance, TaskScheduler taskScheduler) - : base(orchestrationInstance, taskScheduler) + : base(orchestrationInstance, new PropertiesDictionary(), taskScheduler) { CurrentUtcDateTime = DateTime.UtcNow; } diff --git a/src/DurableTask.Core/ContextPropertiesExtensions.cs b/src/DurableTask.Core/ContextPropertiesExtensions.cs new file mode 100644 index 000000000..f2d57da32 --- /dev/null +++ b/src/DurableTask.Core/ContextPropertiesExtensions.cs @@ -0,0 +1,80 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +#nullable enable + +using System; +using System.Collections.Generic; + +namespace DurableTask.Core +{ + /// + /// Extension methods that help get properties from . + /// + public static class ContextPropertiesExtensions + { + /// + /// Sets a property value to the context using the full name of the type as the key. + /// + /// The type of the property. + /// Properties to set property for. + /// The value of the property. + public static void SetProperty(this IContextProperties properties, T? value) => properties.SetProperty(typeof(T).FullName, value); + + /// + /// Sets a named property value to the context. + /// + /// The type of the property. + /// Properties to set property for. + /// The name of the property. + /// The value of the property. + public static void SetProperty(this IContextProperties properties, string key, T? value) + { + if (value is null) + { + properties.Properties.Remove(key); + } + else + { + properties.Properties[key] = value; + } + } + + /// + /// Gets a property value from the context using the full name of . + /// + /// The type of the property. + /// Properties to get property from. + /// The value of the property or default(T) if the property is not defined. + public static T? GetProperty(this IContextProperties properties) => properties.GetProperty(typeof(T).FullName); + + internal static T GetRequiredProperty(this IContextProperties properties) + => properties.GetProperty() ?? throw new InvalidOperationException($"Could not find property for {typeof(T).FullName}"); + + /// + /// Gets a named property value from the context. + /// + /// + /// Properties to get property from. + /// The name of the property value. + /// The value of the property or default(T) if the property is not defined. + public static T? GetProperty(this IContextProperties properties, string key) => properties.Properties.TryGetValue(key, out object value) ? (T)value : default; + + /// + /// Gets the tags from the current properties. + /// + /// + /// + public static IDictionary GetTags(this IContextProperties properties) => properties.GetRequiredProperty().OrchestrationTags; + } +} \ No newline at end of file diff --git a/src/DurableTask.Core/IContextProperties.cs b/src/DurableTask.Core/IContextProperties.cs new file mode 100644 index 000000000..69ee6c2f5 --- /dev/null +++ b/src/DurableTask.Core/IContextProperties.cs @@ -0,0 +1,30 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System.Collections.Generic; + +#nullable enable + +namespace DurableTask.Core +{ + /// + /// Collection of properties for context objects to store arbitrary state. + /// + public interface IContextProperties + { + /// + /// Gets the properties of the current instance + /// + IDictionary Properties { get; } + } +} \ No newline at end of file diff --git a/src/DurableTask.Core/Middleware/DispatchMiddlewareContext.cs b/src/DurableTask.Core/Middleware/DispatchMiddlewareContext.cs index 256d3e7dc..399915175 100644 --- a/src/DurableTask.Core/Middleware/DispatchMiddlewareContext.cs +++ b/src/DurableTask.Core/Middleware/DispatchMiddlewareContext.cs @@ -13,13 +13,12 @@ namespace DurableTask.Core.Middleware { - using System; using System.Collections.Generic; /// /// Context data that can be used to share data between middleware. /// - public class DispatchMiddlewareContext + public class DispatchMiddlewareContext : IContextProperties { /// /// Sets a property value to the context using the full name of the type as the key. @@ -28,7 +27,7 @@ public class DispatchMiddlewareContext /// The value of the property. public void SetProperty(T value) { - SetProperty(typeof(T).FullName, value); + ContextPropertiesExtensions.SetProperty(this, value); } /// @@ -39,7 +38,7 @@ public void SetProperty(T value) /// The value of the property. public void SetProperty(string key, T value) { - Properties[key] = value; + ContextPropertiesExtensions.SetProperty(this, key, value); } /// @@ -49,7 +48,7 @@ public void SetProperty(string key, T value) /// The value of the property or default(T) if the property is not defined. public T GetProperty() { - return GetProperty(typeof(T).FullName); + return ContextPropertiesExtensions.GetProperty(this); } /// @@ -60,12 +59,12 @@ public T GetProperty() /// The value of the property or default(T) if the property is not defined. public T GetProperty(string key) { - return Properties.TryGetValue(key, out object value) ? (T)value : default(T); + return ContextPropertiesExtensions.GetProperty(this, key); } /// /// Gets a key/value collection that can be used to share data between middleware. /// - public IDictionary Properties { get; } = new Dictionary(StringComparer.Ordinal); + public IDictionary Properties { get; } = new PropertiesDictionary(); } } diff --git a/src/DurableTask.Core/OrchestrationContext.cs b/src/DurableTask.Core/OrchestrationContext.cs index 52238bbc2..9fe4bb915 100644 --- a/src/DurableTask.Core/OrchestrationContext.cs +++ b/src/DurableTask.Core/OrchestrationContext.cs @@ -23,13 +23,16 @@ namespace DurableTask.Core /// /// Context for an orchestration containing the instance, replay status, orchestration methods and proxy methods /// - public abstract class OrchestrationContext + public abstract class OrchestrationContext : IContextProperties { /// /// Used in generating proxy interfaces and classes. /// private static readonly ProxyGenerator ProxyGenerator = new ProxyGenerator(); + /// + public virtual IDictionary Properties { get; } = new Dictionary(); + /// /// Thread-static variable used to signal whether the calling thread is the orchestrator thread. /// The primary use case is for detecting illegal async usage in orchestration code. @@ -40,12 +43,20 @@ public abstract class OrchestrationContext /// /// JsonDataConverter for message serialization settings /// - public JsonDataConverter MessageDataConverter { get; set; } + public JsonDataConverter MessageDataConverter + { + get => this.GetProperty(nameof(MessageDataConverter)) ?? this.GetProperty() ?? JsonDataConverter.Default; + set => this.SetProperty(nameof(MessageDataConverter), value); + } /// /// JsonDataConverter for error serialization settings /// - public JsonDataConverter ErrorDataConverter { get; set; } + public JsonDataConverter ErrorDataConverter + { + get => this.GetProperty(nameof(ErrorDataConverter)) ?? this.GetProperty() ?? JsonDataConverter.Default; + set => this.SetProperty(nameof(ErrorDataConverter), value); + } /// /// Instance of the currently executing orchestration diff --git a/src/DurableTask.Core/PropertiesDictionary.cs b/src/DurableTask.Core/PropertiesDictionary.cs new file mode 100644 index 000000000..1b13792b6 --- /dev/null +++ b/src/DurableTask.Core/PropertiesDictionary.cs @@ -0,0 +1,28 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core +{ + using System; + using System.Collections.Generic; + + internal sealed class PropertiesDictionary : Dictionary, IContextProperties + { + public PropertiesDictionary() + : base(StringComparer.Ordinal) + { + } + + IDictionary IContextProperties.Properties => this; + } +} diff --git a/src/DurableTask.Core/ReflectionBasedTaskActivity.cs b/src/DurableTask.Core/ReflectionBasedTaskActivity.cs index b935f8c1c..474a9046e 100644 --- a/src/DurableTask.Core/ReflectionBasedTaskActivity.cs +++ b/src/DurableTask.Core/ReflectionBasedTaskActivity.cs @@ -38,7 +38,6 @@ public class ReflectionBasedTaskActivity : TaskActivity /// The Reflection.methodInfo for invoking the method on the activity object public ReflectionBasedTaskActivity(object activityObject, MethodInfo methodInfo) { - DataConverter = JsonDataConverter.Default; ActivityObject = activityObject; MethodInfo = methodInfo; genericArguments = methodInfo.GetGenericArguments(); @@ -47,12 +46,17 @@ public ReflectionBasedTaskActivity(object activityObject, MethodInfo methodInfo) /// /// The DataConverter to use for input and output serialization/deserialization /// - public DataConverter DataConverter + public DataConverter DataConverter2 { - get => dataConverter; + get => dataConverter ?? JsonDataConverter.Default; set => dataConverter = value ?? throw new ArgumentNullException(nameof(value)); } + private DataConverter GetConverter(TaskContext context) + { + return dataConverter ?? context.GetProperty(); + } + /// /// The activity object to invoke methods on /// @@ -81,6 +85,7 @@ public override string Run(TaskContext context, string input) /// Serialized output from the execution public override async Task RunAsync(TaskContext context, string input) { + var converter = GetConverter(context); var jArray = Utils.ConvertToJArray(input); int parameterCount = jArray.Count - this.genericArguments.Length; @@ -93,7 +98,7 @@ public override async Task RunAsync(TaskContext context, string input) } Type[] genericTypeArguments = this.GetGenericTypeArguments(jArray); - object[] inputParameters = this.GetInputParameters(jArray, parameterCount, methodParameters, genericTypeArguments); + object[] inputParameters = this.GetInputParameters(jArray, parameterCount, methodParameters, genericTypeArguments, converter); string serializedReturn = string.Empty; Exception exception = null; @@ -108,7 +113,7 @@ public override async Task RunAsync(TaskContext context, string input) Type returnType = Utils.GetGenericReturnType(this.MethodInfo, genericTypeArguments); PropertyInfo resultProperty = typeof(Task<>).MakeGenericType(returnType).GetProperty("Result"); - serializedReturn = this.DataConverter.Serialize(resultProperty.GetValue(invocationTask)); + serializedReturn = converter.Serialize(resultProperty.GetValue(invocationTask)); } else { @@ -118,7 +123,7 @@ public override async Task RunAsync(TaskContext context, string input) } else { - serializedReturn = DataConverter.Serialize(invocationResult); + serializedReturn = converter.Serialize(invocationResult); } } catch (TargetInvocationException e) @@ -136,7 +141,7 @@ public override async Task RunAsync(TaskContext context, string input) FailureDetails failureDetails = null; if (context.ErrorPropagationMode == ErrorPropagationMode.SerializeExceptions) { - details = Utils.SerializeCause(exception, DataConverter); + details = Utils.SerializeCause(exception, converter); } else { @@ -194,7 +199,7 @@ private Type[] GetGenericTypeArguments(JArray jArray) return genericParameters.ToArray(); } - private object[] GetInputParameters(JArray jArray, int parameterCount, ParameterInfo[] methodParameters, Type[] genericArguments) + private object[] GetInputParameters(JArray jArray, int parameterCount, ParameterInfo[] methodParameters, Type[] genericArguments, DataConverter converter) { var inputParameters = new object[methodParameters.Length]; for (var i = 0; i < methodParameters.Length; i++) @@ -214,7 +219,7 @@ private object[] GetInputParameters(JArray jArray, int parameterCount, Parameter else { string serializedValue = jToken.ToString(); - inputParameters[i] = this.DataConverter.Deserialize(serializedValue, parameterType); + inputParameters[i] = converter.Deserialize(serializedValue, parameterType); } } else diff --git a/src/DurableTask.Core/TaskActivity.cs b/src/DurableTask.Core/TaskActivity.cs index 5cf79c09d..829a5f7a8 100644 --- a/src/DurableTask.Core/TaskActivity.cs +++ b/src/DurableTask.Core/TaskActivity.cs @@ -54,12 +54,13 @@ public virtual Task RunAsync(TaskContext context, string input) /// Output type of the activity public abstract class AsyncTaskActivity : TaskActivity { + private DataConverter dataConverter; + /// /// Creates a new AsyncTaskActivity with the default DataConverter /// protected AsyncTaskActivity() { - DataConverter = JsonDataConverter.Default; } /// @@ -68,13 +69,22 @@ protected AsyncTaskActivity() /// protected AsyncTaskActivity(DataConverter dataConverter) { - DataConverter = dataConverter ?? JsonDataConverter.Default; + this.dataConverter = dataConverter; } /// /// The DataConverter to use for input and output serialization/deserialization /// - public DataConverter DataConverter { get; protected set; } + public DataConverter DataConverter + { + get => dataConverter ?? JsonDataConverter.Default; + set => dataConverter = value; + } + + private DataConverter GetDataConverter(TaskContext context) + { + return DataConverter ?? context.GetProperty(); + } /// /// Synchronous execute method, blocked for AsyncTaskActivity @@ -102,6 +112,7 @@ public override string Run(TaskContext context, string input) /// Serialized output from the execution public override async Task RunAsync(TaskContext context, string input) { + var converter = GetDataConverter(context); TInput parameter = default(TInput); var jArray = Utils.ConvertToJArray(input); @@ -112,7 +123,7 @@ public override async Task RunAsync(TaskContext context, string input) throw new TaskFailureException( "TaskActivity implementation cannot be invoked due to more than expected input parameters. Signature mismatch."); } - + if (parameterCount == 1) { JToken jToken = jArray[0]; @@ -123,7 +134,7 @@ public override async Task RunAsync(TaskContext context, string input) else { string serializedValue = jToken.ToString(); - parameter = DataConverter.Deserialize(serializedValue); + parameter = converter.Deserialize(serializedValue); } } @@ -138,7 +149,7 @@ public override async Task RunAsync(TaskContext context, string input) FailureDetails failureDetails = null; if (context.ErrorPropagationMode == ErrorPropagationMode.SerializeExceptions) { - details = Utils.SerializeCause(e, DataConverter); + details = Utils.SerializeCause(e, converter); } else { @@ -149,7 +160,7 @@ public override async Task RunAsync(TaskContext context, string input) .WithFailureDetails(failureDetails); } - string serializedResult = DataConverter.Serialize(result); + string serializedResult = converter.Serialize(result); return serializedResult; } } diff --git a/src/DurableTask.Core/TaskActivityDispatcher.cs b/src/DurableTask.Core/TaskActivityDispatcher.cs index 1c0b8464f..b3b1db60d 100644 --- a/src/DurableTask.Core/TaskActivityDispatcher.cs +++ b/src/DurableTask.Core/TaskActivityDispatcher.cs @@ -22,6 +22,7 @@ namespace DurableTask.Core using DurableTask.Core.History; using DurableTask.Core.Logging; using DurableTask.Core.Middleware; + using DurableTask.Core.Serializing; using DurableTask.Core.Tracing; /// @@ -33,6 +34,7 @@ public sealed class TaskActivityDispatcher readonly WorkItemDispatcher dispatcher; readonly IOrchestrationService orchestrationService; readonly DispatchMiddlewarePipeline dispatchPipeline; + readonly JsonDataConverter dataConverter; readonly LogHelper logHelper; readonly ErrorPropagationMode errorPropagationMode; @@ -40,12 +42,14 @@ internal TaskActivityDispatcher( IOrchestrationService orchestrationService, INameVersionObjectManager objectManager, DispatchMiddlewarePipeline dispatchPipeline, + JsonDataConverter dataConverter, LogHelper logHelper, ErrorPropagationMode errorPropagationMode) { this.orchestrationService = orchestrationService ?? throw new ArgumentNullException(nameof(orchestrationService)); this.objectManager = objectManager ?? throw new ArgumentNullException(nameof(objectManager)); this.dispatchPipeline = dispatchPipeline ?? throw new ArgumentNullException(nameof(dispatchPipeline)); + this.dataConverter = dataConverter; this.logHelper = logHelper; this.errorPropagationMode = errorPropagationMode; @@ -157,6 +161,8 @@ async Task OnProcessWorkItemAsync(TaskActivityWorkItem workItem) dispatchContext.SetProperty(taskMessage.OrchestrationInstance); dispatchContext.SetProperty(taskActivity); dispatchContext.SetProperty(scheduledEvent); + dispatchContext.SetProperty(dataConverter); + dispatchContext.SetProperty(dataConverter); // In transitionary phase (activity queued from old code, accessed in new code) context can be null. if (taskMessage.OrchestrationExecutionContext != null) @@ -174,7 +180,7 @@ async Task OnProcessWorkItemAsync(TaskActivityWorkItem workItem) ActivityExecutionResult? result; try { - await this.dispatchPipeline.RunAsync(dispatchContext, async _ => + await this.dispatchPipeline.RunAsync(dispatchContext, async dispatchContext => { if (taskActivity == null) { @@ -185,7 +191,7 @@ await this.dispatchPipeline.RunAsync(dispatchContext, async _ => throw new TypeMissingException($"TaskActivity {scheduledEvent.Name} version {scheduledEvent.Version} was not found"); } - var context = new TaskContext(taskMessage.OrchestrationInstance); + var context = new TaskContext(taskMessage.OrchestrationInstance, dispatchContext.Properties); context.ErrorPropagationMode = this.errorPropagationMode; HistoryEvent? responseEvent; diff --git a/src/DurableTask.Core/TaskContext.cs b/src/DurableTask.Core/TaskContext.cs index f959ef12c..2fd09016f 100644 --- a/src/DurableTask.Core/TaskContext.cs +++ b/src/DurableTask.Core/TaskContext.cs @@ -11,20 +11,34 @@ // limitations under the License. // ---------------------------------------------------------------------------------- +using System.Collections.Generic; + namespace DurableTask.Core { /// /// Task context /// - public class TaskContext + public class TaskContext : IContextProperties { /// - /// Creates a new TaskContext with the supplied OrchestrationInstance + /// Creates a new with the supplied /// /// public TaskContext(OrchestrationInstance orchestrationInstance) { OrchestrationInstance = orchestrationInstance; + Properties = new PropertiesDictionary(); + } + + /// + /// Creates a new with the supplied and properties + /// + /// + /// + public TaskContext(OrchestrationInstance orchestrationInstance, IDictionary properties) + { + OrchestrationInstance = orchestrationInstance; + Properties = properties; } /// @@ -32,6 +46,11 @@ public TaskContext(OrchestrationInstance orchestrationInstance) /// public OrchestrationInstance OrchestrationInstance { get; private set; } + /// + /// Gets the properties of the current instance + /// + public IDictionary Properties { get; } + /// /// Gets or sets a value indicating how to propagate unhandled exception metadata. /// diff --git a/src/DurableTask.Core/TaskHubWorker.cs b/src/DurableTask.Core/TaskHubWorker.cs index 1b93bd62e..c98cd7e06 100644 --- a/src/DurableTask.Core/TaskHubWorker.cs +++ b/src/DurableTask.Core/TaskHubWorker.cs @@ -24,6 +24,7 @@ namespace DurableTask.Core using DurableTask.Core.Exceptions; using DurableTask.Core.Logging; using DurableTask.Core.Middleware; + using DurableTask.Core.Serializing; using Microsoft.Extensions.Logging; /// @@ -178,17 +179,20 @@ public async Task StartAsync() this.logHelper.TaskHubWorkerStarting(); var sw = Stopwatch.StartNew(); + var defaultSerializer = JsonDataConverter.Default; // TODO: Can we get this from the TaskHubClient? this.orchestrationDispatcher = new TaskOrchestrationDispatcher( this.orchestrationService, this.orchestrationManager, this.orchestrationDispatchPipeline, + defaultSerializer, this.logHelper, this.ErrorPropagationMode); this.activityDispatcher = new TaskActivityDispatcher( this.orchestrationService, this.activityManager, this.activityDispatchPipeline, + defaultSerializer, this.logHelper, this.ErrorPropagationMode); diff --git a/src/DurableTask.Core/TaskOrchestration.cs b/src/DurableTask.Core/TaskOrchestration.cs index c198c0855..b7513aead 100644 --- a/src/DurableTask.Core/TaskOrchestration.cs +++ b/src/DurableTask.Core/TaskOrchestration.cs @@ -68,18 +68,28 @@ public abstract class TaskOrchestration : TaskOrchestrationOutput Type for GetStatus calls public abstract class TaskOrchestration : TaskOrchestration { + private DataConverter dataConverter; + /// /// Creates a new TaskOrchestration with the default DataConverter /// protected TaskOrchestration() { - DataConverter = JsonDataConverter.Default; } /// /// The DataConverter to use for input and output serialization/deserialization /// - public DataConverter DataConverter { get; protected set; } + public DataConverter DataConverter + { + get => dataConverter ?? JsonDataConverter.Default; + protected set => dataConverter = value; + } + + private DataConverter GetConverter(OrchestrationContext context) + { + return dataConverter ?? context.GetProperty(); + } /// /// Method for executing an orchestration based on the context and serialized input @@ -89,7 +99,9 @@ protected TaskOrchestration() /// Serialized output from the execution public override async Task Execute(OrchestrationContext context, string input) { - var parameter = DataConverter.Deserialize(input); + var converter = GetConverter(context); + + var parameter = converter.Deserialize(input); TResult result; try { @@ -101,7 +113,7 @@ public override async Task Execute(OrchestrationContext context, string FailureDetails failureDetails = null; if (context.ErrorPropagationMode == ErrorPropagationMode.SerializeExceptions) { - details = Utils.SerializeCause(e, DataConverter); + details = Utils.SerializeCause(e, converter); } else { @@ -114,7 +126,7 @@ public override async Task Execute(OrchestrationContext context, string }; } - return DataConverter.Serialize(result); + return converter.Serialize(result); } /// @@ -125,7 +137,7 @@ public override async Task Execute(OrchestrationContext context, string /// The serialized input public override void RaiseEvent(OrchestrationContext context, string name, string input) { - var parameter = DataConverter.Deserialize(input); + var parameter = GetConverter(context).Deserialize(input); OnEvent(context, name, parameter); } diff --git a/src/DurableTask.Core/TaskOrchestrationContext.cs b/src/DurableTask.Core/TaskOrchestrationContext.cs index a9831ff47..dc86ea375 100644 --- a/src/DurableTask.Core/TaskOrchestrationContext.cs +++ b/src/DurableTask.Core/TaskOrchestrationContext.cs @@ -23,7 +23,6 @@ namespace DurableTask.Core using DurableTask.Core.Common; using DurableTask.Core.Exceptions; using DurableTask.Core.History; - using DurableTask.Core.Serializing; using DurableTask.Core.Tracing; internal class TaskOrchestrationContext : OrchestrationContext @@ -44,8 +43,11 @@ public void AddEventToNextIteration(HistoryEvent he) continueAsNew.CarryoverEvents.Add(he); } + public override IDictionary Properties { get; } + public TaskOrchestrationContext( OrchestrationInstance orchestrationInstance, + IContextProperties properties, TaskScheduler taskScheduler, ErrorPropagationMode errorPropagationMode = ErrorPropagationMode.SerializeExceptions) { @@ -54,12 +56,11 @@ public TaskOrchestrationContext( this.openTasks = new Dictionary(); this.orchestratorActionsMap = new SortedDictionary(); this.idCounter = 0; - this.MessageDataConverter = JsonDataConverter.Default; - this.ErrorDataConverter = JsonDataConverter.Default; OrchestrationInstance = orchestrationInstance; IsReplaying = false; ErrorPropagationMode = errorPropagationMode; this.eventsWhileSuspended = new Queue(); + Properties = properties.Properties; } public IEnumerable OrchestratorActions => this.orchestratorActionsMap.Values; diff --git a/src/DurableTask.Core/TaskOrchestrationDispatcher.cs b/src/DurableTask.Core/TaskOrchestrationDispatcher.cs index 6ebeaf89b..814eb325b 100644 --- a/src/DurableTask.Core/TaskOrchestrationDispatcher.cs +++ b/src/DurableTask.Core/TaskOrchestrationDispatcher.cs @@ -36,6 +36,7 @@ public class TaskOrchestrationDispatcher { static readonly Task CompletedTask = Task.FromResult(0); + readonly JsonDataConverter dataConverter; readonly INameVersionObjectManager objectManager; readonly IOrchestrationService orchestrationService; readonly WorkItemDispatcher dispatcher; @@ -48,12 +49,14 @@ internal TaskOrchestrationDispatcher( IOrchestrationService orchestrationService, INameVersionObjectManager objectManager, DispatchMiddlewarePipeline dispatchPipeline, + JsonDataConverter dataConverter, LogHelper logHelper, ErrorPropagationMode errorPropagationMode) { this.objectManager = objectManager ?? throw new ArgumentNullException(nameof(objectManager)); this.orchestrationService = orchestrationService ?? throw new ArgumentNullException(nameof(orchestrationService)); this.dispatchPipeline = dispatchPipeline ?? throw new ArgumentNullException(nameof(dispatchPipeline)); + this.dataConverter = dataConverter; this.logHelper = logHelper ?? throw new ArgumentNullException(nameof(logHelper)); this.errorPropagationMode = errorPropagationMode; @@ -661,6 +664,8 @@ async Task ExecuteOrchestrationAsync(Orchestration TaskOrchestration? taskOrchestration = this.objectManager.GetObject(runtimeState.Name, runtimeState.Version!); var dispatchContext = new DispatchMiddlewareContext(); + dispatchContext.SetProperty(dataConverter); + dispatchContext.SetProperty(dataConverter); dispatchContext.SetProperty(runtimeState.OrchestrationInstance); dispatchContext.SetProperty(taskOrchestration); dispatchContext.SetProperty(runtimeState); @@ -669,7 +674,7 @@ async Task ExecuteOrchestrationAsync(Orchestration TaskOrchestrationExecutor? executor = null; - await this.dispatchPipeline.RunAsync(dispatchContext, _ => + await this.dispatchPipeline.RunAsync(dispatchContext, dispatchContext => { // Check to see if the custom middleware intercepted and substituted the orchestration execution // with its own execution behavior, providing us with the end results. If so, we can terminate @@ -680,7 +685,7 @@ await this.dispatchPipeline.RunAsync(dispatchContext, _ => return CompletedTask; } - if (taskOrchestration == null) + if (dispatchContext.GetProperty() is null) { throw TraceHelper.TraceExceptionInstance( TraceEventType.Error, @@ -690,8 +695,7 @@ await this.dispatchPipeline.RunAsync(dispatchContext, _ => } executor = new TaskOrchestrationExecutor( - runtimeState, - taskOrchestration, + dispatchContext, this.orchestrationService.EventBehaviourForContinueAsNew, this.errorPropagationMode); OrchestratorExecutionResult resultFromOrchestrator = executor.Execute(); diff --git a/src/DurableTask.Core/TaskOrchestrationExecutor.cs b/src/DurableTask.Core/TaskOrchestrationExecutor.cs index b0ca99976..3cff3ea49 100644 --- a/src/DurableTask.Core/TaskOrchestrationExecutor.cs +++ b/src/DurableTask.Core/TaskOrchestrationExecutor.cs @@ -24,6 +24,7 @@ namespace DurableTask.Core using DurableTask.Core.Common; using DurableTask.Core.Exceptions; using DurableTask.Core.History; + using DurableTask.Core.Middleware; /// /// Utility for executing task orchestrators. @@ -49,14 +50,39 @@ public TaskOrchestrationExecutor( TaskOrchestration taskOrchestration, BehaviorOnContinueAsNew eventBehaviourForContinueAsNew, ErrorPropagationMode errorPropagationMode = ErrorPropagationMode.SerializeExceptions) + : this(CreateProperties(orchestrationRuntimeState, taskOrchestration), eventBehaviourForContinueAsNew, errorPropagationMode) { + } + + private static IContextProperties CreateProperties(OrchestrationRuntimeState orchestrationRuntimeState, TaskOrchestration taskOrchestration) + { + var properties = new PropertiesDictionary(); + + properties.SetProperty(orchestrationRuntimeState); + properties.SetProperty(taskOrchestration); + + return properties; + } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + public TaskOrchestrationExecutor( + IContextProperties properties, + BehaviorOnContinueAsNew eventBehaviourForContinueAsNew, + ErrorPropagationMode errorPropagationMode = ErrorPropagationMode.SerializeExceptions) + { + this.orchestrationRuntimeState = properties.GetRequiredProperty(); + this.taskOrchestration = properties.GetRequiredProperty(); this.decisionScheduler = new SynchronousTaskScheduler(); this.context = new TaskOrchestrationContext( orchestrationRuntimeState.OrchestrationInstance, + properties, this.decisionScheduler, errorPropagationMode); - this.orchestrationRuntimeState = orchestrationRuntimeState; - this.taskOrchestration = taskOrchestration; this.skipCarryOverEvents = eventBehaviourForContinueAsNew == BehaviorOnContinueAsNew.Ignore; } @@ -144,7 +170,7 @@ void ProcessEvents(IEnumerable events) // Let this exception propagate out to be handled by the dispatcher ExceptionDispatchInfo.Capture(exception).Throw(); } - + this.context.FailOrchestration(exception); } else diff --git a/test/DurableTask.Core.Tests/DispatcherMiddlewareTests.cs b/test/DurableTask.Core.Tests/DispatcherMiddlewareTests.cs index ad89efc92..5f44ca599 100644 --- a/test/DurableTask.Core.Tests/DispatcherMiddlewareTests.cs +++ b/test/DurableTask.Core.Tests/DispatcherMiddlewareTests.cs @@ -250,7 +250,7 @@ public void EnsureOrchestrationExecutionContextSupportsDataContractSerialization [TestMethod] public async Task EnsureSubOrchestrationDispatcherMiddlewareHasAccessToRuntimeState() { - ConcurrentBag capturedContexts = new ConcurrentBag(); + ConcurrentBag capturedContexts = new ConcurrentBag(); for (var i = 0; i < 10; i++) {