Skip to content

Commit

Permalink
Made MessageCache add/drain/iteration thread-safe (https://issues.red…
Browse files Browse the repository at this point in the history
  • Loading branch information
belaban committed Jan 22, 2024
1 parent 8fdac01 commit 5d7290d
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 48 deletions.
55 changes: 7 additions & 48 deletions src/org/jgroups/protocols/UNICAST3.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ public class UNICAST3 extends Protocol implements AgeOutCache.Handler<Address> {
/** Keep track of when a SEND_FIRST_SEQNO message was sent to a given sender */
protected ExpiryCache<Address> last_sync_sent=null;

// Queues messages until a {@link ReceiverEntry} has been created. Queued messages are then removed from
// the cache and added to the ReceiverEntry
protected final MessageCache msg_cache=new MessageCache();

protected static final Message DUMMY_OOB_MSG=new Message().setFlag(Message.Flag.OOB);
Expand Down Expand Up @@ -485,7 +487,7 @@ public void up(MessageBatch batch) {
if(hdr.first)
entry=getReceiverEntry(sender, hdr.seqno(), hdr.first, hdr.connId());
else if(entry == null) {
msg_cache.cache(sender, msg);
msg_cache.add(sender, msg);
log.trace("%s: cached %s#%d", local_addr, sender, hdr.seqno());
}
}
Expand All @@ -495,7 +497,7 @@ else if(entry == null) {
sendRequestForFirstSeqno(sender);
else {
if(!msg_cache.isEmpty()) { // quick and dirty check
List<Message> queued_msgs=msg_cache.drain(sender);
Collection<Message> queued_msgs=msg_cache.drain(sender);
if(queued_msgs != null)
addQueuedMessages(sender, entry, queued_msgs);
}
Expand Down Expand Up @@ -745,12 +747,12 @@ public void expired(Address key) {
protected void handleDataReceived(final Address sender, long seqno, short conn_id, boolean first, final Message msg) {
ReceiverEntry entry=getReceiverEntry(sender, seqno, first, conn_id);
if(entry == null) {
msg_cache.cache(sender, msg);
msg_cache.add(sender, msg);
log.trace("%s: cached %s#%d", local_addr, sender, seqno);
return;
}
if(!msg_cache.isEmpty()) { // quick and dirty check
List<Message> queued_msgs=msg_cache.drain(sender);
Collection<Message> queued_msgs=msg_cache.drain(sender);
if(queued_msgs != null)
addQueuedMessages(sender, entry, queued_msgs);
}
Expand Down Expand Up @@ -781,7 +783,7 @@ protected void addMessage(ReceiverEntry entry, Address sender, long seqno, Messa
}
}

protected void addQueuedMessages(final Address sender, final ReceiverEntry entry, List<Message> queued_msgs) {
protected void addQueuedMessages(final Address sender, final ReceiverEntry entry, Collection<Message> queued_msgs) {
for(Message msg: queued_msgs) {
UnicastHeader3 hdr=msg.getHeader(this.id);
if(hdr.conn_id != entry.conn_id) {
Expand Down Expand Up @@ -1444,47 +1446,4 @@ public String toString() {
}
}

/**
* Used to queue messages until a {@link ReceiverEntry} has been created. Queued messages are then removed from
* the cache and added to the ReceiverEntry
*/
protected class MessageCache {
private final Map<Address,List<Message>> map=new ConcurrentHashMap<>();
private volatile boolean is_empty=true;

protected MessageCache cache(Address sender, Message msg) {
List<Message> list=map.computeIfAbsent(sender, addr -> new ArrayList<>());
list.add(msg);
is_empty=false;
return this;
}

protected List<Message> drain(Address sender) {
List<Message> list=map.remove(sender);
if(map.isEmpty())
is_empty=true;
return list;
}

protected MessageCache clear() {
map.clear();
is_empty=true;
return this;
}

/** Returns a count of all messages */
protected int size() {
return map.values().stream().mapToInt(Collection::size).sum();
}

protected boolean isEmpty() {
return is_empty;
}

public String toString() {
return String.format("%d message(s)", size());
}
}


}
55 changes: 55 additions & 0 deletions src/org/jgroups/util/MessageCache.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package org.jgroups.util;

import org.jgroups.Address;
import org.jgroups.Message;

import java.util.Collection;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

/**
* A cache associating members and messages
* @author Bela Ban
* @since 5.3.2
*/
public class MessageCache {
protected final Map<Address,Queue<Message>> map=new ConcurrentHashMap<>();
protected volatile boolean is_empty=true;

public MessageCache add(Address sender, Message msg) {
Queue<Message> list=map.computeIfAbsent(sender, addr -> new ConcurrentLinkedQueue<>());
list.add(msg);
is_empty=false;
return this;
}

public Collection<Message> drain(Address sender) {
if(sender == null)
return null;
Queue<Message> queue=map.remove(sender);
if(map.isEmpty())
is_empty=true;
return queue;
}

public MessageCache clear() {
map.clear();
is_empty=true;
return this;
}

/** Returns a count of all messages */
public int size() {
return map.values().stream().mapToInt(Collection::size).sum();
}

public boolean isEmpty() {
return is_empty;
}

public String toString() {
return String.format("%d message(s)", size());
}
}
56 changes: 56 additions & 0 deletions tests/junit-functional/org/jgroups/tests/MessageCacheTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.jgroups.tests;

import org.jgroups.Address;
import org.jgroups.Global;
import org.jgroups.Message;
import org.jgroups.ObjectMessage;
import org.jgroups.util.MessageCache;
import org.jgroups.util.Util;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.util.Collection;

/**
* Tests {@link org.jgroups.util.MessageCache}
* @author Bela Ban
* @since 5.3.2
*/
@Test(groups= Global.FUNCTIONAL,singleThreaded=true)
public class MessageCacheTest {
protected MessageCache cache;
protected static final Address A=Util.createRandomAddress("A"), B=Util.createRandomAddress("B"),
C=Util.createRandomAddress("C");

@BeforeMethod protected void setup() {
cache=new MessageCache();
}

public void testCreation() {
assert cache.isEmpty();
}

public void testAdd() {
for(int i=1; i <= 5; i++) {
cache.add(A, new ObjectMessage(A, i));
cache.add(B, new ObjectMessage(B, i+10));
}
assert !cache.isEmpty();
assert cache.size() == 10;
}

public void testDrain() {
testAdd();
Collection<Message> list=cache.drain(null);
assert list == null;
list=cache.drain(C);
assert list == null;
list=cache.drain(B);
assert list.size() == 5;
assert cache.size() == 5;
assert !cache.isEmpty();
list=cache.drain(A);
assert list.size() == 5;
assert cache.isEmpty();
}
}

0 comments on commit 5d7290d

Please sign in to comment.