Skip to content

Commit af3020a

Browse files
authored
Add get_classical_any_host() and tests (#160)
1 parent 187b861 commit af3020a

File tree

3 files changed

+241
-0
lines changed

3 files changed

+241
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import unittest
2+
from qunetsim.components import Host, Network
3+
from qunetsim.backends import EQSNBackend
4+
import time
5+
6+
network = Network.get_instance()
7+
hosts = {}
8+
9+
class TestGetClassicalAnyHost(unittest.TestCase):
10+
11+
MAX_WAIT_TIME = 10
12+
13+
@classmethod
14+
def setUpClass(cls):
15+
global network
16+
global hosts
17+
nodes = ["Alice", "Bob"]
18+
backend = EQSNBackend()
19+
network.start(nodes=nodes, backend=backend)
20+
hosts = {'alice': Host('Alice', backend),
21+
'bob': Host('Bob', backend)}
22+
hosts['alice'].add_connection('Bob')
23+
hosts['bob'].add_connection('Alice')
24+
hosts['alice'].start()
25+
hosts['bob'].start()
26+
for h in hosts.values():
27+
network.add_host(h)
28+
29+
def setUp(self) -> None:
30+
hosts['bob']._classical_messages.empty()
31+
hosts['alice']._classical_messages.empty()
32+
33+
@classmethod
34+
def tearDownClass(cls):
35+
global network
36+
global hosts
37+
network.stop(stop_hosts=True)
38+
39+
40+
41+
def test_get_all_with_wait_time(self):
42+
def listen_with_wait_time(s):
43+
time.sleep(2)
44+
msgs = hosts['bob'].get_classical_any_host(seq_num=None, wait=self.MAX_WAIT_TIME)
45+
self.assertEqual([x.content for x in msgs],['3','2','1'])
46+
47+
def send_some_with_delay(s):
48+
hosts['alice'].send_classical(hosts['bob'].host_id,str(1),await_ack=True)
49+
hosts['alice'].send_classical(hosts['bob'].host_id,str(2),await_ack=True)
50+
time.sleep(5)
51+
hosts['alice'].send_classical(hosts['bob'].host_id,str(3),await_ack=True)
52+
53+
t1 = hosts['bob'].run_protocol(listen_with_wait_time)
54+
t2 = hosts['alice'].run_protocol(send_some_with_delay)
55+
56+
t1.join()
57+
t2.join()
58+
59+
def test_get_seq_with_wait_time(self):
60+
def listen_with_wait_time(s):
61+
time.sleep(2)
62+
msg = hosts['bob'].get_classical_any_host(seq_num=2, wait=self.MAX_WAIT_TIME)
63+
self.assertEqual(msg.content,'3')
64+
65+
def send_some_with_delay(s):
66+
hosts['alice'].send_classical(hosts['bob'].host_id,str(1),await_ack=True)
67+
hosts['alice'].send_classical(hosts['bob'].host_id,str(2),await_ack=True)
68+
time.sleep(5)
69+
hosts['alice'].send_classical(hosts['bob'].host_id,str(3),await_ack=True)
70+
71+
t1 = hosts['bob'].run_protocol(listen_with_wait_time)
72+
t2 = hosts['alice'].run_protocol(send_some_with_delay)
73+
74+
t1.join()
75+
t2.join()
76+
77+
def test_get_seq_with_wait_time_none_value(self):
78+
def listen_with_wait_time(s):
79+
time.sleep(2)
80+
msg = hosts['bob'].get_classical_any_host(seq_num=3, wait=self.MAX_WAIT_TIME)
81+
self.assertEqual(msg,None)#seq_num not present
82+
83+
def send_some_with_delay(s):
84+
hosts['alice'].send_classical(hosts['bob'].host_id,str(1),await_ack=True)
85+
hosts['alice'].send_classical(hosts['bob'].host_id,str(2),await_ack=True)
86+
time.sleep(5)
87+
hosts['alice'].send_classical(hosts['bob'].host_id,str(3),await_ack=True)
88+
89+
t1 = hosts['bob'].run_protocol(listen_with_wait_time)
90+
t2 = hosts['alice'].run_protocol(send_some_with_delay)
91+
92+
t1.join()
93+
t2.join()
94+
95+
def test_get_all_with_wait_time_empty_arr(self):
96+
def listen_with_wait_time(s):
97+
msgs = hosts['bob'].get_classical_any_host(None, wait=self.MAX_WAIT_TIME)
98+
self.assertEqual(msgs,[])
99+
100+
def send_after_max_wait(s):
101+
time.sleep(self.MAX_WAIT_TIME+2)
102+
hosts['alice'].send_classical(hosts['bob'].host_id,str(1),await_ack=True)
103+
104+
t1 = hosts['bob'].run_protocol(listen_with_wait_time)
105+
t2 = hosts['alice'].run_protocol(send_after_max_wait)
106+
107+
t1.join()
108+
t2.join()
109+
110+
def test_get_all_no_wait_time(self):
111+
# no msgs yet
112+
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, None)
113+
rec_msgs = hosts['bob'].get_classical_any_host(None, wait=0)
114+
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, None)
115+
self.assertEqual(len(rec_msgs), 0)
116+
117+
# with some msgs
118+
for c in range(5):
119+
hosts['alice'].send_classical(hosts['bob'].host_id, str(c), await_ack=True)
120+
rec_msgs = hosts['bob'].get_classical_any_host(None, wait=0)
121+
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, 'Alice')
122+
self.assertEqual(len(rec_msgs), 5)
123+
124+
def test_get_seq_no_wait_time(self):
125+
# no msgs yet
126+
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, None)
127+
rec_msg = hosts['bob'].get_classical_any_host(0, wait=0)
128+
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, None)
129+
self.assertEqual(rec_msg, None)
130+
131+
# with some msgs
132+
for c in range(5):
133+
hosts['alice'].send_classical(hosts['bob'].host_id, str(c), await_ack=True)
134+
rec_msg = hosts['bob'].get_classical_any_host(4, wait=0)
135+
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, 'Alice')
136+
self.assertEqual(rec_msg.content, '4')
137+
138+
def test_wait_data_type(self):
139+
self.assertRaises(Exception, hosts['bob'].get_classical_any_host, None, "1")
140+
141+

qunetsim/components/host.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,16 @@ def get_classical(self, host_id, seq_num=None, wait=0):
14541454
cla = self._classical_messages.get_all_from_sender(host_id, wait)
14551455
return sorted(cla, key=lambda x: x.seq_num, reverse=True)
14561456

1457+
def get_classical_any_host(self, seq_num=None, wait=0):
1458+
if not isinstance(wait, float) and not isinstance(wait, int):
1459+
raise Exception('wait parameter should be a number')
1460+
1461+
if seq_num is not None:
1462+
return self._classical_messages.get_with_seq_num_from_any_sender(seq_num,wait)
1463+
1464+
cla = self._classical_messages.get_all_from_any_sender(wait)
1465+
return sorted(cla, key=lambda x: x.seq_num, reverse=True)
1466+
14571467
def get_next_classical(self, sender_id, wait=-1):
14581468
"""
14591469
Gets the next classical message available from a sender.

qunetsim/objects/storage/classical_storage.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ class ClassicalStorage(object):
1111
GET_NEXT = 1
1212
GET_ALL = 2
1313
GET_WITH_SEQ_NUM = 3
14+
GET_ALL_MSGS_ANY_HOST = 4
15+
GET_WITH_SEQ_NUM_ANY_HOST = 5
1416

1517
def __init__(self):
1618
self._host_to_msg_dict = {}
1719
self._host_to_read_index = {}
20+
self.last_msg_added_to_host = None
1821

1922
# read write lock, for threaded access
2023
self._lock = RWLock()
@@ -43,6 +46,12 @@ def _check_all_requests(self):
4346
ret = self._get_all_from_sender(args[1])
4447
elif args[2] == ClassicalStorage.GET_WITH_SEQ_NUM:
4548
ret = self._get_with_seq_num_from_sender(args[1], args[3])
49+
elif args[2] == ClassicalStorage.GET_ALL_MSGS_ANY_HOST:
50+
ret = self._get_all_from_sender(self.last_msg_added_to_host) \
51+
if self.last_msg_added_to_host is not None else None
52+
elif args[2] == ClassicalStorage.GET_WITH_SEQ_NUM_ANY_HOST:
53+
ret = self._get_with_seq_num_from_sender(self.last_msg_added_to_host, args[3]) \
54+
if self.last_msg_added_to_host is not None else None
4655
else:
4756
raise ValueError("Internal Error, this request does not exist!")
4857

@@ -84,6 +93,7 @@ def empty(self):
8493
self._lock.acquire_write()
8594
self._host_to_msg_dict = {}
8695
self._host_to_read_index = {}
96+
self.last_msg_added_to_host = None
8797
self._lock.release_write()
8898

8999
def _add_new_host_id(self, host_id):
@@ -129,6 +139,7 @@ def add_msg_to_storage(self, message):
129139
if sender_id not in list(self._host_to_msg_dict):
130140
self._add_new_host_id(sender_id)
131141
self._host_to_msg_dict[sender_id].append(message)
142+
self.last_msg_added_to_host = sender_id
132143
self._check_all_requests()
133144
self._lock.release_write()
134145

@@ -269,6 +280,85 @@ def _get_with_seq_num_from_sender(self, sender_id, seq_num):
269280
return None
270281
msg = self._host_to_msg_dict[sender_id][seq_num]
271282
return msg
283+
284+
def get_all_from_any_sender(self,wait=0):
285+
"""
286+
Get all stored messages from any sender. If delete option is set,
287+
the returned messages are removed from the storage.
288+
289+
Args:
290+
wait (int): Default is 0. The maximum blocking time. -1 to block forever.
291+
292+
Returns:
293+
List of messages of the sender. If there are none, an empty list is
294+
returned.
295+
"""
296+
297+
# Block forever if wait is -1
298+
if wait == -1:
299+
wait = None
300+
301+
self._lock.acquire_write()
302+
msg = None
303+
if self.last_msg_added_to_host is not None:
304+
msg = self.get_all_from_sender(self.last_msg_added_to_host)
305+
306+
if wait == 0:
307+
self._lock.release_write()
308+
return msg if msg is not None else []
309+
310+
q = queue.Queue()
311+
request = [q, None, ClassicalStorage.GET_ALL_MSGS_ANY_HOST]
312+
req_id = self._add_request(request)
313+
self._lock.release_write()
314+
315+
try:
316+
msg = q.get(timeout=wait)
317+
except queue.Empty:
318+
pass
319+
320+
321+
if msg is None:
322+
self._lock.acquire_write()
323+
self._remove_request(req_id)
324+
self._lock.release_write()
325+
return []
326+
return msg
327+
328+
def get_with_seq_num_from_any_sender(self, seq_num, wait=0):
329+
'''
330+
Returns:
331+
Message object, if such a message exists, or none.
332+
'''
333+
# Block forever if wait is -1
334+
if wait == -1:
335+
wait = None
336+
337+
338+
self._lock.acquire_write()
339+
next_msg = None
340+
if self.last_msg_added_to_host is not None:
341+
next_msg = self.get_with_seq_num_from_sender(self.last_msg_added_to_host,seq_num)
342+
343+
if wait == 0:
344+
self._lock.release_write()
345+
return next_msg
346+
347+
q = queue.Queue()
348+
request = [q, None, ClassicalStorage.GET_WITH_SEQ_NUM_ANY_HOST, seq_num]
349+
req_id = self._add_request(request)
350+
self._lock.release_write()
351+
352+
try:
353+
next_msg = q.get(timeout=wait)
354+
except queue.Empty:
355+
pass
356+
357+
if next_msg is None:
358+
self._lock.acquire_write()
359+
self._remove_request(req_id)
360+
self._lock.release_write()
361+
return next_msg
272362

273363
def get_all(self):
274364
"""

0 commit comments

Comments
 (0)