diff --git a/cdap-app-fabric/pom.xml b/cdap-app-fabric/pom.xml
index 93489acfa9e7..63527d176c85 100644
--- a/cdap-app-fabric/pom.xml
+++ b/cdap-app-fabric/pom.xml
@@ -149,6 +149,7 @@
<groupId>junit</groupId>
<artifactId>junit</artifactId>
+ <version>4.13.2</version>
<scope>test</scope>
diff --git a/cdap-app-fabric/src/main/java/io/cdap/cdap/internal/app/program/MessagingProgramStatePublisher.java b/cdap-app-fabric/src/main/java/io/cdap/cdap/internal/app/program/MessagingProgramStatePublisher.java
index 20067a04f654..c1cba7557f60 100644
--- a/cdap-app-fabric/src/main/java/io/cdap/cdap/internal/app/program/MessagingProgramStatePublisher.java
+++ b/cdap-app-fabric/src/main/java/io/cdap/cdap/internal/app/program/MessagingProgramStatePublisher.java
@@ -42,6 +42,7 @@
import io.cdap.cdap.proto.id.ProgramRunId;
import io.cdap.cdap.proto.id.TopicId;
import java.io.IOException;
+import java.net.SocketTimeoutException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -142,9 +143,7 @@ public void publish(Notification.Type notificationType, Map prop
.build());
LOG.trace("Published program status notification: {}", programStatusNotification);
done = true;
- } catch (IOException | AccessException e) {
- throw Throwables.propagate(e);
- } catch (TopicNotFoundException | ServiceUnavailableException e) {
+ } catch (TopicNotFoundException | ServiceUnavailableException | SocketTimeoutException e) {
// These exceptions are retry-able due to TMS not completely started
if (startTime < 0) {
startTime = System.currentTimeMillis();
@@ -164,6 +163,8 @@ public void publish(Notification.Type notificationType, Map prop
Thread.currentThread().interrupt();
done = true;
}
+ } catch (AccessException | IOException e) {
+ throw Throwables.propagate(e);
}
}
}
diff --git a/cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/program/MessagingProgramStatePublisherTest.java b/cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/program/MessagingProgramStatePublisherTest.java
index 9ffa8a7a77eb..a6c1e25b98bf 100644
--- a/cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/program/MessagingProgramStatePublisherTest.java
+++ b/cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/program/MessagingProgramStatePublisherTest.java
@@ -29,6 +29,7 @@
import io.cdap.cdap.proto.id.ApplicationId;
import io.cdap.cdap.proto.id.ProgramRunId;
import java.io.IOException;
+import java.net.SocketTimeoutException;
import java.util.Map;
import org.junit.Assert;
import org.junit.Rule;
@@ -95,4 +96,65 @@ public void testMultipleTopics_noRun() throws TopicNotFoundException, IOExceptio
StoreRequest storeRequest = storeRequestCaptor.getValue();
Assert.assertEquals("programstatusevent0", storeRequest.getTopicId().getTopic());
}
+
+  // A SocketTimeoutException on the first attempt is retried, and the second attempt succeeds.
+  @Test
+  public void testPublishSuccessOnRetryableException() throws TopicNotFoundException, IOException {
+    CConfiguration conf = CConfiguration.create();
+    conf.setInt(Constants.AppFabric.PROGRAM_STATUS_EVENT_NUM_PARTITIONS, 1);
+    conf.setInt("system.program.state.retry.policy.max.retries", 3);
+    MessagingService tms = Mockito.mock(MessagingService.class);
+    MessagingProgramStatePublisher statePublisher = new MessagingProgramStatePublisher(conf, tms);
+    // The first publish call times out, the following one returns normally.
+    Mockito.when(tms.publish(Mockito.any()))
+        .thenThrow(new SocketTimeoutException()).thenReturn(null);
+
+    statePublisher.publish(Notification.Type.PROGRAM_STATUS, ImmutableMap.of());
+
+    Mockito.verify(tms, Mockito.times(2)).publish(storeRequestCaptor.capture());
+    Assert.assertEquals("programstatusevent",
+        storeRequestCaptor.getValue().getTopicId().getTopic());
+  }
+
+  // Persistent timeouts exhaust the retries; the timeout surfaces as a RuntimeException cause.
+  @Test
+  public void testPublishThrowsOnRetryExhausted() throws TopicNotFoundException, IOException {
+    CConfiguration conf = CConfiguration.create();
+    conf.setInt(Constants.AppFabric.PROGRAM_STATUS_EVENT_NUM_PARTITIONS, 1);
+    conf.setInt("system.program.state.retry.policy.max.retries", 3);
+    MessagingService tms = Mockito.mock(MessagingService.class);
+    MessagingProgramStatePublisher statePublisher = new MessagingProgramStatePublisher(conf, tms);
+    Mockito.when(tms.publish(Mockito.any())).thenThrow(new SocketTimeoutException());
+
+    RuntimeException thrown = Assert.assertThrows(RuntimeException.class,
+        () -> statePublisher.publish(Notification.Type.PROGRAM_STATUS, ImmutableMap.of()));
+    Throwable cause = thrown.getCause();
+    Assert.assertNotNull(cause);
+    Assert.assertTrue(cause instanceof SocketTimeoutException);
+    // Expect 4 publish calls: the initial attempt plus the 3 retries configured above.
+    Mockito.verify(tms, Mockito.times(4)).publish(storeRequestCaptor.capture());
+    Assert.assertEquals("programstatusevent",
+        storeRequestCaptor.getValue().getTopicId().getTopic());
+  }
+
+  // A plain IOException is not retried: it propagates right away as a RuntimeException cause.
+  @Test
+  public void testPublishThrowsForNonRetryableException() throws TopicNotFoundException, IOException {
+    CConfiguration conf = CConfiguration.create();
+    conf.setInt(Constants.AppFabric.PROGRAM_STATUS_EVENT_NUM_PARTITIONS, 1);
+    MessagingService tms = Mockito.mock(MessagingService.class);
+    MessagingProgramStatePublisher statePublisher = new MessagingProgramStatePublisher(conf, tms);
+    Mockito.when(tms.publish(Mockito.any()))
+        .thenThrow(new IOException()).thenReturn(null);
+
+    RuntimeException thrown = Assert.assertThrows(RuntimeException.class,
+        () -> statePublisher.publish(Notification.Type.PROGRAM_STATUS, ImmutableMap.of()));
+    Throwable cause = thrown.getCause();
+    Assert.assertNotNull(cause);
+    Assert.assertTrue(cause instanceof IOException);
+    // Default verification mode: exactly one publish call, so no retry took place.
+    Mockito.verify(tms).publish(storeRequestCaptor.capture());
+    Assert.assertEquals("programstatusevent",
+        storeRequestCaptor.getValue().getTopicId().getTopic());
+  }
}