Skip to content

Commit

Permalink
Use a safe eventLoop wrapper to avoid thread shutdown and fail the ta…
Browse files Browse the repository at this point in the history
…sk correctly
  • Loading branch information
shangm2 committed Jan 23, 2025
1 parent 15fd5d8 commit 1cf4320
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 19 deletions.
5 changes: 5 additions & 0 deletions presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,11 @@
<artifactId>netty-transport</artifactId>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-common</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
import com.sun.management.ThreadMXBean;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.netty.channel.EventLoop;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import org.joda.time.DateTime;

Expand Down Expand Up @@ -143,7 +142,7 @@
* This class now uses an event loop concurrency model to eliminate the need for explicit synchronization:
* <ul>
* <li>All mutable state access and modifications are performed on a single dedicated event loop thread</li>
* <li>External threads submit operations to the event loop using {@code taskEventLoop.execute()}</li>
* <li>External threads submit operations to the event loop using {@code safeExecuteOnEventLoop()}</li>
* <li>The event loop serializes all operations, eliminating race conditions without using locks</li>
* </ul>
* <p>
Expand Down Expand Up @@ -230,9 +229,9 @@ public final class HttpRemoteTask
private final DecayCounter taskUpdateRequestSize;
private final SchedulerStatsTracker schedulerStatsTracker;

private final EventLoop taskEventLoop;
private final HttpRemoteTaskFactory.SafeEventLoop taskEventLoop;

public HttpRemoteTask(
public static HttpRemoteTask createHttpRemoteTask(
Session session,
TaskId taskId,
String nodeId,
Expand Down Expand Up @@ -267,7 +266,82 @@ public HttpRemoteTask(
HandleResolver handleResolver,
ConnectorTypeSerdeManager connectorTypeSerdeManager,
SchedulerStatsTracker schedulerStatsTracker,
EventLoop taskEventLoop)
HttpRemoteTaskFactory.SafeEventLoop taskEventLoop)
{
HttpRemoteTask task = new HttpRemoteTask(session,
taskId,
nodeId,
location,
remoteLocation,
planFragment,
initialSplits,
outputBuffers,
httpClient,
maxErrorDuration,
taskStatusRefreshMaxWait,
taskInfoRefreshMaxWait,
taskInfoUpdateInterval,
summarizeTaskInfo,
taskStatusCodec,
taskInfoCodec,
taskInfoJsonCodec,
taskUpdateRequestCodec,
planFragmentCodec,
metadataUpdatesCodec,
nodeStatsTracker,
stats,
binaryTransportEnabled,
thriftTransportEnabled,
taskInfoThriftTransportEnabled,
thriftProtocol,
tableWriteInfo,
maxTaskUpdateSizeInBytes,
metadataManager,
queryManager,
taskUpdateRequestSize,
handleResolver,
connectorTypeSerdeManager,
schedulerStatsTracker,
taskEventLoop);
task.initialize();
return task;
}

private HttpRemoteTask(Session session,
TaskId taskId,
String nodeId,
URI location,
URI remoteLocation,
PlanFragment planFragment,
Multimap<PlanNodeId, Split> initialSplits,
OutputBuffers outputBuffers,
HttpClient httpClient,
Duration maxErrorDuration,
Duration taskStatusRefreshMaxWait,
Duration taskInfoRefreshMaxWait,
Duration taskInfoUpdateInterval,
boolean summarizeTaskInfo,
Codec<TaskStatus> taskStatusCodec,
Codec<TaskInfo> taskInfoCodec,
Codec<TaskInfo> taskInfoJsonCodec,
Codec<TaskUpdateRequest> taskUpdateRequestCodec,
Codec<PlanFragment> planFragmentCodec,
Codec<MetadataUpdates> metadataUpdatesCodec,
NodeStatsTracker nodeStatsTracker,
RemoteTaskStats stats,
boolean binaryTransportEnabled,
boolean thriftTransportEnabled,
boolean taskInfoThriftTransportEnabled,
Protocol thriftProtocol,
TableWriteInfo tableWriteInfo,
int maxTaskUpdateSizeInBytes,
MetadataManager metadataManager,
QueryManager queryManager,
DecayCounter taskUpdateRequestSize,
HandleResolver handleResolver,
ConnectorTypeSerdeManager connectorTypeSerdeManager,
SchedulerStatsTracker schedulerStatsTracker,
HttpRemoteTaskFactory.SafeEventLoop taskEventLoop)
{
requireNonNull(session, "session is null");
requireNonNull(taskId, "taskId is null");
Expand Down Expand Up @@ -389,7 +463,11 @@ public HttpRemoteTask(
handleResolver,
connectorTypeSerdeManager,
thriftProtocol);
}

// this is a separate method to ensure that the `this` reference is not leaked during construction
private void initialize()
{
taskStatusFetcher.addStateChangeListener(newStatus -> {
verify(taskEventLoop.inEventLoop());

Expand All @@ -404,7 +482,7 @@ public HttpRemoteTask(
});

updateTaskStats();
taskEventLoop.execute(this::updateSplitQueueSpace);
safeExecuteOnEventLoop(this::updateSplitQueueSpace);
}

public PlanFragment getPlanFragment()
Expand Down Expand Up @@ -445,7 +523,7 @@ public URI getRemoteTaskLocation()
@Override
public void start()
{
taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
// to start we just need to trigger an update
started = true;
scheduleUpdate();
Expand All @@ -465,7 +543,7 @@ public void addSplits(Multimap<PlanNodeId, Split> splitsBySource)
return;
}

taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
boolean updateNeeded = false;
for (Entry<PlanNodeId, Collection<Split>> entry : splitsBySource.asMap().entrySet()) {
PlanNodeId sourceId = entry.getKey();
Expand Down Expand Up @@ -502,7 +580,7 @@ public void addSplits(Multimap<PlanNodeId, Split> splitsBySource)
@Override
public void noMoreSplits(PlanNodeId sourceId)
{
taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
if (noMoreSplits.containsKey(sourceId)) {
return;
}
Expand All @@ -516,7 +594,7 @@ public void noMoreSplits(PlanNodeId sourceId)
@Override
public void noMoreSplits(PlanNodeId sourceId, Lifespan lifespan)
{
taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
if (pendingNoMoreSplitsForLifespan.put(sourceId, lifespan)) {
needsUpdate = true;
scheduleUpdate();
Expand All @@ -531,7 +609,7 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers)
return;
}

taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
if (newOutputBuffers.getVersion() > outputBuffers.getVersion()) {
outputBuffers = newOutputBuffers;
needsUpdate = true;
Expand Down Expand Up @@ -696,7 +774,7 @@ public ListenableFuture<?> whenSplitQueueHasSpace(long weightThreshold)
return immediateFuture(null);
}
SettableFuture<?> future = SettableFuture.create();
taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
if (whenSplitQueueHasSpaceThreshold.isPresent()) {
checkArgument(weightThreshold == whenSplitQueueHasSpaceThreshold.getAsLong(), "Multiple split queue space notification thresholds not supported");
}
Expand Down Expand Up @@ -866,7 +944,7 @@ private void scheduleUpdate()

private void sendUpdate()
{
taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
TaskStatus taskStatus = getTaskStatus();
// don't update if the task hasn't been started yet or if it is already finished
if (!started || !needsUpdate || taskStatus.getState().isDone()) {
Expand Down Expand Up @@ -987,7 +1065,7 @@ private TaskSource getSource(PlanNodeId planNodeId)
@Override
public void cancel()
{
taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
TaskStatus taskStatus = getTaskStatus();
if (taskStatus.getState().isDone()) {
return;
Expand All @@ -1007,7 +1085,7 @@ public void cancel()

private void cleanUpTask()
{
taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
checkState(getTaskStatus().getState().isDone(), "attempt to clean up a task that is not done yet");

// clear pending splits to free memory
Expand Down Expand Up @@ -1055,7 +1133,7 @@ public void abort()

private void abort(TaskStatus status)
{
taskEventLoop.execute(() -> {
safeExecuteOnEventLoop(() -> {
checkState(status.getState().isDone(), "cannot abort task with an incomplete status");

taskStatusFetcher.updateTaskStatus(status);
Expand Down Expand Up @@ -1317,4 +1395,19 @@ public void onFailure(Throwable throwable)
onFailureTaskInfo(throwable, this.action, this.request, this.cleanupBackoff);
}
}

/***
* Wrap the task execution on event loop to fail the entire task on any failure.
*/
private void safeExecuteOnEventLoop(Runnable r)
{
taskEventLoop.execute(() -> {
try {
r.run();
}
catch (Throwable t) {
failTask(t);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.airlift.json.Codec;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.json.smile.SmileCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.airlift.stats.DecayCounter;
import com.facebook.airlift.stats.ExponentialDecay;
import com.facebook.drift.codec.ThriftCodec;
Expand Down Expand Up @@ -49,20 +50,26 @@
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.airlift.units.Duration;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import org.weakref.jmx.Managed;

import javax.annotation.PreDestroy;
import javax.inject.Inject;

import java.util.concurrent.Executor;
import java.util.function.Consumer;

import static com.facebook.presto.server.thrift.ThriftCodecWrapper.wrapThriftCodec;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

public class HttpRemoteTaskFactory
implements RemoteTaskFactory
{
private static final Logger log = Logger.get(HttpRemoteTaskFactory.class);
private final HttpClient httpClient;
private final LocationFactory locationFactory;
private final Codec<TaskStatus> taskStatusCodec;
Expand All @@ -89,7 +96,14 @@ public class HttpRemoteTaskFactory
private final QueryManager queryManager;
private final DecayCounter taskUpdateRequestSize;
private final EventLoopGroup eventLoopGroup = new DefaultEventLoopGroup(Runtime.getRuntime().availableProcessors(),
new ThreadFactoryBuilder().setNameFormat("task-event-loop-%s").setDaemon(true).build());
new ThreadFactoryBuilder().setNameFormat("task-event-loop-%s").setDaemon(true).build())
{
@Override
protected EventLoop newChild(Executor executor, Object... args)
{
return new SafeEventLoop(this, executor);
}
};

@Inject
public HttpRemoteTaskFactory(
Expand Down Expand Up @@ -194,7 +208,7 @@ public RemoteTask createRemoteTask(
TableWriteInfo tableWriteInfo,
SchedulerStatsTracker schedulerStatsTracker)
{
return new HttpRemoteTask(
return HttpRemoteTask.createHttpRemoteTask(
session,
taskId,
node.getNodeIdentifier(),
Expand Down Expand Up @@ -229,6 +243,39 @@ public RemoteTask createRemoteTask(
handleResolver,
connectorTypeSerdeManager,
schedulerStatsTracker,
eventLoopGroup.next());
(SafeEventLoop) eventLoopGroup.next());
}

/***
* One observation about event loop is if submitted task fails, it could kill the thread but the event loop group will not create a new one.
* Here, we wrap it as safe event loop so that if any submitted job fails, we chose to log the error and fail the entire task.
*/
static class SafeEventLoop
extends DefaultEventLoop
{
private Consumer<Throwable> onFail;

public SafeEventLoop(EventLoopGroup parent, Executor executor)
{
super(parent, executor);
}

@Override
protected void run()
{
do {
Runnable task = takeTask();
if (task != null) {
try {
runTask(task);
}
catch (Throwable t) {
log.error("Error executing task on event loop", t);
}
updateLastExecutionTime();
}
}
while (!this.confirmShutdown());
}
}
}

0 comments on commit 1cf4320

Please sign in to comment.