Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve race conditions in MQTT connection logic #1060

Merged
merged 5 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bin/test_redirect
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ log bin/reset_config $site_path $project_spec $device_id shutdown_config.json
bin/reset_config $site_path $project_spec $device_id shutdown_config.json

log And let it settle for last start termination...
sleep 120
sleep 125

tail out/pubber.log.2

Expand Down
41 changes: 12 additions & 29 deletions pubber/src/main/java/daq/pubber/Pubber.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,9 @@
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.http.ConnectionClosedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import udmi.lib.base.MqttDevice;
import udmi.lib.base.MqttPublisher.PublisherException;
import udmi.lib.client.DeviceManager;
import udmi.lib.client.SystemManager;
import udmi.lib.intf.FamilyProvider;
Expand Down Expand Up @@ -103,7 +101,6 @@ public class Pubber extends PubberManager implements PubberUdmiPublisher {
private SchemaVersion targetSchema;
private int deviceUpdateCount = -1;
private PubberDeviceManager deviceManager;
private boolean isConnected;
private boolean isGatewayDevice;

/**
Expand Down Expand Up @@ -393,7 +390,7 @@ private void processDeviceMetadata(Metadata metadata) {
}

@Override
public void periodicUpdate() {
public synchronized void periodicUpdate() {
try {
deviceUpdateCount++;
checkSmokyFailure();
Expand All @@ -403,6 +400,7 @@ public void periodicUpdate() {
flushDirtyState();
} catch (Exception e) {
error("Fatal error during execution", e);
resetConnection(getWorkingEndpoint());
}
}

Expand All @@ -427,18 +425,13 @@ public void startConnection(Function<String, Boolean> connectionDone) {

private boolean attemptConnection() {
try {
isConnected = false;
deviceManager.stop();
super.stop();
if (deviceTarget == null || !deviceTarget.isActive()) {
error("Mqtt publisher not active");
disconnectMqtt();
initializeMqtt();
}
disconnectMqtt();
initializeMqtt();
registerMessageHandlers();
connect();
configLatchWait();
isConnected = true;
deviceManager.activate();
return true;
} catch (Exception e) {
Expand Down Expand Up @@ -515,22 +508,14 @@ public byte[] ensureKeyBytes() {
}

@Override
public void publisherException(Exception toReport) {
if (toReport instanceof PublisherException report) {
publisherHandler(report.getType(), report.getPhase(), report.getCause(),
report.getDeviceId());
} else if (toReport instanceof ConnectionClosedException) {
error("Connection closed, attempting reconnect...");
while (retriesRemaining.getAndDecrement() > 0) {
if (attemptConnection()) {
return;
}
public synchronized void reconnect() {
while (retriesRemaining.getAndDecrement() > 0) {
if (attemptConnection()) {
return;
}
error("Connection retry failed, giving up.");
deviceManager.systemLifecycle(SystemMode.TERMINATE);
} else {
error("Unknown exception type " + toReport.getClass(), toReport);
}
error("Connection retry failed, giving up.");
deviceManager.systemLifecycle(SystemMode.TERMINATE);
}

@Override
Expand All @@ -541,12 +526,10 @@ public void persistEndpoint(EndpointConfiguration endpoint) {
}

@Override
public void resetConnection(String targetEndpoint) {
public synchronized void resetConnection(String targetEndpoint) {
try {
config.endpoint = fromJsonString(targetEndpoint,
EndpointConfiguration.class);
disconnectMqtt();
initializeMqtt();
retriesRemaining.set(CONNECT_RETRIES);
startConnection(connectionDone);
} catch (Exception e) {
Expand Down Expand Up @@ -700,7 +683,7 @@ public void setConfigLatch(CountDownLatch countDownLatch) {

@Override
public boolean isConnected() {
return isConnected;
return deviceTarget != null && deviceTarget.isActive();
}

@Override
Expand Down
6 changes: 4 additions & 2 deletions pubber/src/main/java/daq/pubber/PubberGatewayManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import udmi.lib.ProtocolFamily;
import udmi.lib.client.GatewayManager;
import udmi.lib.client.ProxyDeviceHost;
Expand Down Expand Up @@ -47,8 +48,9 @@ public void setMetadata(Metadata metadata) {

@Override
public void activate() {
ifNotNullThen(proxyDevices, p -> p.values()
.parallelStream().forEach(ProxyDeviceHost::activate));
ifNotNullThen(proxyDevices, p -> CompletableFuture.runAsync(() -> p.values()
.parallelStream()
.forEach(ProxyDeviceHost::activate)));
}

@Override
Expand Down
39 changes: 27 additions & 12 deletions pubber/src/main/java/daq/pubber/PubberUdmiPublisher.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@
import java.util.concurrent.locks.Lock;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.http.ConnectionClosedException;
import udmi.lib.base.GatewayError;
import udmi.lib.base.MqttDevice;
import udmi.lib.base.MqttPublisher.FakeTopic;
import udmi.lib.base.MqttPublisher.InjectedMessage;
import udmi.lib.base.MqttPublisher.InjectedState;
import udmi.lib.base.MqttPublisher.PublisherException;
import udmi.lib.client.DeviceManager;
import udmi.lib.client.PointsetManager;
import udmi.lib.client.PointsetManager.ExtraPointsetEvent;
Expand Down Expand Up @@ -242,9 +244,9 @@ default void captureExceptions(String action, Runnable runnable) {
*/
default void disconnectMqtt() {
if (getDeviceTarget() != null) {
captureExceptions("closing mqtt publisher", () -> getDeviceTarget().close());
captureExceptions("shutting down mqtt publisher executor",
captureExceptions("Shutting down MQTT publisher executor",
() -> getDeviceTarget().shutdown());
captureExceptions("Closing MQTT publisher", () -> getDeviceTarget().close());
setDeviceTarget(null);
}
}
Expand Down Expand Up @@ -805,16 +807,12 @@ default void publishSynchronousState() {
}
}

default boolean publisherActive() {
return getDeviceTarget() != null && getDeviceTarget().isActive();
}

/**
* Publishes the current device state as a message to the publisher if the publisher is active. If
* the publisher is not active, it marks the state as dirty and returns without publishing.
*/
default void publishStateMessage() {
if (!publisherActive()) {
if (!isConnected()) {
markStateDirty(-1);
return;
}
Expand Down Expand Up @@ -898,8 +896,8 @@ private void publishDeviceMessage(String targetId, Object message) {
* configured.
*/
default void publishDeviceMessage(String targetId, Object message, Runnable callback) {
if (getDeviceTarget() == null) {
error("publisher not active");
if (!isConnected()) {
error(format("Publisher not active (%s)", targetId));
return;
}
String topicSuffix = MESSAGE_TOPIC_SUFFIX_MAP.get(message.getClass());
Expand Down Expand Up @@ -989,6 +987,10 @@ default void debug(String message, String detail) {

void startConnection(Function<String, Boolean> connectionDone);

void reconnect();

void resetConnection(String targetEndpoint);

/**
* Flushes the dirty state by publishing an asynchronous state change.
*/
Expand All @@ -1000,12 +1002,25 @@ default void flushDirtyState() {

byte[] ensureKeyBytes();

void publisherException(Exception toReport);
/**
* Handles exceptions related to the publisher and
* takes appropriate actions based on the exception type.
*
* @param toReport the exception to be handled;
*/
default void publisherException(Exception toReport) {
if (toReport instanceof PublisherException r) {
publisherHandler(r.getType(), r.getPhase(), r.getCause(), r.getDeviceId());
} else if (toReport instanceof ConnectionClosedException) {
warn("Connection closed, attempting reconnect...");
reconnect();
} else {
error("Unknown exception type " + toReport.getClass(), toReport);
}
}

void persistEndpoint(EndpointConfiguration endpoint);

void resetConnection(String targetEndpoint);

String traceTimestamp(String messageBase);

/**
Expand Down
77 changes: 40 additions & 37 deletions pubber/src/main/java/udmi/lib/base/MqttPublisher.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import javax.net.SocketFactory;
import javax.net.ssl.SSLSocketFactory;
Expand Down Expand Up @@ -100,6 +101,7 @@ public class MqttPublisher implements Publisher {

private final Map<String, MqttClient> mqttClients = new ConcurrentHashMap<>();
private final Map<String, Instant> reauthTimes = new ConcurrentHashMap<>();
ReentrantLock reconnectLock = new ReentrantLock();

private final ExecutorService publisherExecutor =
Executors.newFixedThreadPool(PUBLISH_THREAD_COUNT);
Expand Down Expand Up @@ -215,22 +217,32 @@ private void publishCore(String deviceId, String topicSuffix, Object data, Runna
callback.run();
}
} catch (Exception e) {
if (!isActive()) {
return;
}
errorCounter.incrementAndGet();
warn(format("Publish %s failed for %s: %s", topicSuffix, deviceId, e));
if (getGatewayId() == null) {
closeMqttClient(deviceId);
if (mqttClients.isEmpty()) {
warn("Last client closed, shutting down connection.");
close();
shutdown();
reconnect();
}
} else if (getGatewayId().equals(deviceId)) {
reconnect();
}
}
}

private synchronized void reconnect() {
if (isActive()) {
if (reconnectLock.tryLock()) {
try {
// Force reconnect to address potential bad states
onError.accept(new ConnectionClosedException());
} finally {
reconnectLock.unlock();
}
} else if (getGatewayId().equals(deviceId)) {
close();
shutdown();
// Force reconnect to address potential bad states
onError.accept(new ConnectionClosedException());
}
}
}
Expand Down Expand Up @@ -268,7 +280,7 @@ private void closeMqttClient(String deviceId) {
if (removed != null) {
try {
if (removed.isConnected()) {
removed.disconnect();
removed.disconnectForcibly();
}
removed.close();
} catch (Exception e) {
Expand Down Expand Up @@ -298,7 +310,7 @@ public void close() {
@Override
public void shutdown() {
if (isActive()) {
publisherExecutor.shutdown();
publisherExecutor.shutdownNow();
}
}

Expand Down Expand Up @@ -532,7 +544,7 @@ private String getDeviceId(String topic) {
return topic.split("/")[splitIndex];
}

public void connect(String targetId, boolean clean) {
public synchronized void connect(String targetId, boolean clean) {
ifTrueThen(clean, () -> closeMqttClient(targetId));
getConnectedClient(targetId);
}
Expand Down Expand Up @@ -569,8 +581,10 @@ private boolean sendMessage(String deviceId, String mqttTopic,
return true;
}

private MqttClient getActiveClient(String targetId) {
checkAuthentication(targetId);
private synchronized MqttClient getActiveClient(String targetId) {
if (!checkAuthentication(targetId)) {
return null;
}
MqttClient client = getConnectedClient(targetId);
if (client.isConnected()) {
return client;
Expand All @@ -586,24 +600,16 @@ private void safeSleep(long timeoutMs) {
}
}

private void checkAuthentication(String targetId) {
private boolean checkAuthentication(String targetId) {
String authId = ofNullable(getGatewayId()).orElse(targetId);
Instant reAuthTime = reauthTimes.get(authId);
if (reAuthTime == null || Instant.now().isBefore(reAuthTime)) {
return;
return true;
}
warn("Authentication retry time reached for " + authId);
reauthTimes.remove(authId);
synchronized (mqttClients) {
try {
close();
shutdown();
// Force reconnect to address potential bad states
onError.accept(new ConnectionClosedException());
} catch (Exception e) {
throw new RuntimeException("While trying to reconnect mqtt client", e);
}
}
reconnect();
return false;
}

private MqttClient getConnectedClient(String deviceId) {
Expand Down Expand Up @@ -721,26 +727,23 @@ private class MqttCallbackHandler implements MqttCallback {

@Override
public void connectionLost(Throwable cause) {
boolean connected = cleanClients(deviceId).isConnected();
warn("MQTT Connection Lost: " + connected + cause);
close();
shutdown();
// Force reconnect to address potential bad states
onError.accept(new ConnectionClosedException());
if (isActive()) {
boolean connected = cleanClients(deviceId).isConnected();
warn(format("MQTT Connection Lost: %s %s", connected, cause));
reconnect();
}
}

@Override
public void deliveryComplete(IMqttDeliveryToken token) {
}

@Override
public void messageArrived(String topic, MqttMessage message) {
synchronized (MqttPublisher.this) {
try {
messageArrivedCore(topic, message);
} catch (Exception e) {
error("While processing message", deviceId, null, "handle", e);
}
public synchronized void messageArrived(String topic, MqttMessage message) {
try {
messageArrivedCore(topic, message);
} catch (Exception e) {
error("While processing message", deviceId, null, "handle", e);
}
}

Expand Down
Loading