Skip to content

Commit

Permalink
fix(sdk-dotnet): Fix task worker connection manager to send hosts and…
Browse files Browse the repository at this point in the history
… ports available and refactor code
  • Loading branch information
KarlaCarvajal committed Dec 11, 2024
1 parent 453199b commit 5009c00
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 155 deletions.
2 changes: 1 addition & 1 deletion sdk-dotnet/Examples/BasicExample/MyWorker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace Examples.BasicExample
{
public class MyWorker
{
[LHTaskMethod("greet-dotnet")]
[LHTaskMethod("greet")]
public string Greeting(string name)
{
var message = $"Hello team, This is a Dotnet Worker";
Expand Down
4 changes: 2 additions & 2 deletions sdk-dotnet/Examples/BasicExample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ static void Main(string[] args)
{
var loggerFactory = _serviceProvider.GetRequiredService<ILoggerFactory>();
var config = GetLHConfig(args, loggerFactory);

MyWorker executable = new MyWorker();
var taskWorker = new LHTaskWorker<MyWorker>(executable, "greet-dotnet", config);
var taskWorker = new LHTaskWorker<MyWorker>(executable, "greet", config);

taskWorker.RegisterTaskDef();

Expand Down
18 changes: 9 additions & 9 deletions sdk-dotnet/LittleHorse.Sdk.Tests/Worker/VariableMappingTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void VariableMapping_WithValidLHTypes_ShouldBeBuiltSuccessfully()
foreach (var type in testAllowedTypes)
{
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var result = new VariableMapping(taskDef, position, type, paramName);

Expand All @@ -53,7 +53,7 @@ public void VariableMapping_WithMismatchTypesInt_ShouldThrowException()
Type type1 = typeof(Int64);
Type type2 = typeof(string);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand All @@ -67,7 +67,7 @@ public void VariableMapping_WithMismatchTypeDouble_ShouldThrowException()
Type type1 = typeof(double);
Type type2 = typeof(Int64);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand All @@ -81,7 +81,7 @@ public void VariableMapping_WithMismatchTypeString_ShouldThrowException()
Type type1 = typeof(string);
Type type2 = typeof(double);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand All @@ -95,7 +95,7 @@ public void VariableMapping_WithMismatchTypeBool_ShouldThrowException()
Type type1 = typeof(bool);
Type type2 = typeof(string);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand All @@ -109,7 +109,7 @@ public void VariableMapping_WithMismatchTypeBytes_ShouldThrowException()
Type type1 = typeof(byte[]);
Type type2 = typeof(string);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand Down Expand Up @@ -302,11 +302,11 @@ public void VariableMapping_WithAssignJsonStringValue_ShouldReturnCustomObject()
Assert.Equal(expectedObject.Cars!.Count, actualObject.Cars!.Count);
}

private TaskDef getTaskDefForTest(VariableType type)
private TaskDef? getTaskDefForTest(VariableType type)
{
var inputVar = new VariableDef();
inputVar.Type = type;
TaskDef taskDef = new TaskDef();
TaskDef? taskDef = new TaskDef();
TaskDefId taskDefId = new TaskDefId();
taskDef.Id = taskDefId;
taskDef.InputVars.Add(inputVar);
Expand All @@ -317,7 +317,7 @@ private TaskDef getTaskDefForTest(VariableType type)
private VariableMapping getVariableMappingForTest(Type type, string paramName, int position)
{
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var variableMapping = new VariableMapping(taskDef, position, type, paramName);

Expand Down
2 changes: 1 addition & 1 deletion sdk-dotnet/LittleHorse.Sdk/Helper/LHHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace LittleHorse.Sdk.Helper
{
public static class LHHelper
{
public static WfRunId GetWFRunId(TaskRunSource taskRunSource)
public static WfRunId? GetWfRunId(TaskRunSource taskRunSource)
{
switch (taskRunSource.TaskRunSourceCase)
{
Expand Down
10 changes: 5 additions & 5 deletions sdk-dotnet/LittleHorse.Sdk/LHConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ private bool IsOAuth
}
}

public LittleHorseClient GetGrcpClientInstance()
public LittleHorseClient GetGrpcClientInstance()
{
return GetGrcpClientInstance(BootstrapHost, BootstrapPort);
return GetGrpcClientInstance(BootstrapHost, BootstrapPort);
}

public LittleHorseClient GetGrcpClientInstance(string host, int port)
public LittleHorseClient GetGrpcClientInstance(string host, int port)
{
string channelKey = BootstrapServer;
string channelKey = $"{BootstrapProtocol}://{host}:{port}";

if (_createdChannels.ContainsKey(channelKey))
{
Expand Down Expand Up @@ -208,7 +208,7 @@ public TaskDef GetTaskDef(string taskDefName)
{
try
{
var client = GetGrcpClientInstance();
var client = GetGrpcClientInstance();
var taskDefId = new TaskDefId()
{
Name = taskDefName
Expand Down
17 changes: 9 additions & 8 deletions sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ namespace LittleHorse.Sdk.Worker.Internal
{
public class LHServerConnection<T> : IDisposable
{
private LHServerConnectionManager<T> _connectionManager;
private LHHostInfo _hostInfo;
private readonly LHServerConnectionManager<T> _connectionManager;
private readonly LHHostInfo _hostInfo;
private bool _running;
private LittleHorseClient _client;
private readonly LittleHorseClient _client;
private AsyncDuplexStreamingCall<PollTaskRequest, PollTaskResponse> _call;
private ILogger? _logger;
private readonly ILogger? _logger;

public LHHostInfo HostInfo { get { return _hostInfo; } }

Expand All @@ -22,7 +22,7 @@ public LHServerConnection(LHServerConnectionManager<T> connectionManager, LHHost
_connectionManager = connectionManager;
_hostInfo = hostInfo;
_logger = LHLoggerFactoryProvider.GetLogger<LHServerConnection<T>>();
_client = _connectionManager.Config.GetGrcpClientInstance();
_client = _connectionManager.Config.GetGrpcClientInstance(hostInfo.Host, hostInfo.Port);
_call = _client.PollTask();
}

Expand All @@ -34,6 +34,7 @@ public void Connect()

private async Task RequestMoreWorkAsync()
{
_logger?.LogWarning($"Task worker in polling {_connectionManager.Config.WorkerId}");
var request = new PollTaskRequest
{
ClientId = _connectionManager.Config.WorkerId,
Expand All @@ -48,12 +49,12 @@ private async Task RequestMoreWorkAsync()
if (taskToDo.Result != null)
{
var scheduledTask = taskToDo.Result;
var wFRunId = LHHelper.GetWFRunId(scheduledTask.Source);
_logger?.LogDebug($"Received task schedule request for wfRun {wFRunId.Id}");
var wFRunId = LHHelper.GetWfRunId(scheduledTask.Source);
_logger?.LogDebug($"Received task schedule request for wfRun {wFRunId?.Id}");

_connectionManager.SubmitTaskForExecution(scheduledTask, _client);

_logger?.LogDebug($"Scheduled task on threadpool for wfRun {wFRunId.Id}");
_logger?.LogDebug($"Scheduled task on threadpool for wfRun {wFRunId?.Id}");
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,24 @@ public class LHServerConnectionManager<T> : IDisposable
private const int MAX_REPORT_RETRIES = 5;

private LHConfig _config;
private MethodInfo _taskMethod;
private TaskDef _taskDef;
private List<VariableMapping> _mappings;
private T _executable;
private ILogger? _logger;
private LittleHorseClient _bootstrapClient;
private bool _running;
private List<LHServerConnection<T>> _runningConnections;
private Thread _rebalanceThread;
private SemaphoreSlim _semaphore;
private readonly SemaphoreSlim _semaphore;
private LHTask<T> _task;

public LHConfig Config { get { return _config; } }
public TaskDef TaskDef { get { return _taskDef; } }
public TaskDef TaskDef { get { return _task.TaskDef!; } }

public LHServerConnectionManager(LHConfig config,
MethodInfo taskMethod,
TaskDef taskDef,
List<VariableMapping> mappings,
T executable)
LHTask<T> task)
{
_config = config;
_taskMethod = taskMethod;
_taskDef = taskDef;
_mappings = mappings;
_executable = executable;
_logger = LHLoggerFactoryProvider.GetLogger<LHServerConnectionManager<T>>();

_bootstrapClient = config.GetGrcpClientInstance();
_task = task;
_bootstrapClient = config.GetGrpcClientInstance();

_running = false;
_runningConnections = new List<LHServerConnection<T>>();
Expand Down Expand Up @@ -83,25 +73,26 @@ private void DoHeartBeat()
{
try
{
_logger!.LogWarning($"Doing heartbeat... Task Worker ID: {_config.WorkerId}");
var request = new RegisterTaskWorkerRequest
{
TaskDefId = _taskDef.Id,
TaskDefId = _task.TaskDef!.Id,
TaskWorkerId = _config.WorkerId,
};

var response = _bootstrapClient.RegisterTaskWorker(request);

HandleRegisterTaskWorkResponse(response);

HandleRegisterTaskWorkerResponse(response);
}
catch (Exception ex)
{
_logger?.LogError(ex, $"Failed contacting bootstrap host {_config.BootstrapHost}:{_config.BootstrapPort}");
}
}

private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response)
private void HandleRegisterTaskWorkerResponse(RegisterTaskWorkerResponse response)
{

response.YourHosts.ToList().ForEach(host =>
{
if (!IsAlreadyRunning(host))
Expand All @@ -111,7 +102,7 @@ private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response)
var newConnection = new LHServerConnection<T>(this, host);
newConnection.Connect();
_runningConnections.Add(newConnection);
_logger?.LogInformation($"Adding connection to: {host.Host}:{host.Port} for task '{_taskDef.Id}'");
_logger?.LogInformation($"Adding connection to: {host.Host}:{host.Port} for task '{_task.TaskDef!.Id}'");
}
catch (IOException ex)
{
Expand Down Expand Up @@ -156,28 +147,33 @@ public async void SubmitTaskForExecution(ScheduledTask scheduledTask, LittleHors
private void DoTask(ScheduledTask scheduledTask, LittleHorseClient client)
{
ReportTaskRun result = ExecuteTask(scheduledTask, LHMappingHelper.MapDateTimeFromProtoTimeStamp(scheduledTask.CreatedAt));
_semaphore.Release();

var wfRunId = LHHelper.GetWFRunId(scheduledTask.Source);
var wfRunId = LHHelper.GetWfRunId(scheduledTask.Source);

try
{
var retriesLeft = MAX_REPORT_RETRIES;

_logger?.LogDebug($"Going to report task for wfRun {wfRunId.Id}");
_logger?.LogDebug($"Going to report task for wfRun {wfRunId?.Id}");
Policy.Handle<Exception>().WaitAndRetry(MAX_REPORT_RETRIES,
retryAttempt => TimeSpan.FromSeconds(5),
onRetry: (exception, timeSpan, retryCount, context) =>
{
--retriesLeft;
_logger?.LogDebug($"Failed to report task for wfRun {wfRunId}: {exception.Message}. Retries left: {retriesLeft}");
_logger?.LogDebug($"Retrying reportTask rpc on taskRun {LHHelper.TaskRunIdToString(result.TaskRunId)}");
}).Execute(() => RunReportTask(result));
{
--retriesLeft;
_logger?.LogDebug(
$"Failed to report task for wfRun {wfRunId}: {exception.Message}. Retries left: {retriesLeft}");
_logger?.LogDebug(
$"Retrying reportTask rpc on taskRun {LHHelper.TaskRunIdToString(result.TaskRunId)}");
}).Execute(() => RunReportTask(result));
}
catch (Exception ex)
{
_logger?.LogDebug($"Failed to report task for wfRun {wfRunId}: {ex.Message}. No retries left.");
}
finally
{
_semaphore.Release();
}
}

private void RunReportTask(ReportTaskRun reportedTask)
Expand Down Expand Up @@ -278,9 +274,9 @@ private ReportTaskRun ExecuteTask(ScheduledTask scheduledTask, DateTime? schedul

private object? Invoke(ScheduledTask scheduledTask, LHWorkerContext workerContext)
{
var inputs = _mappings.Select(mapping => mapping.Assign(scheduledTask, workerContext)).ToArray();
var inputs = _task.TaskMethodMappings.Select(mapping => mapping.Assign(scheduledTask, workerContext)).ToArray();

return _taskMethod.Invoke(_executable, inputs);
return _task.TaskMethod!.Invoke(_task.Executable, inputs);
}

public void CloseConnection(LHServerConnection<T> connection)
Expand Down
Loading

0 comments on commit 5009c00

Please sign in to comment.