Skip to content

Commit 152314c

Browse files
committed
Merge pull request #5 from jbaayen/master
MMIClient: Simplify implementation, implement missing methods, remove s...
2 parents 3cdc3b3 + 93869dc commit 152314c

File tree

3 files changed

+183
-105
lines changed

3 files changed

+183
-105
lines changed

mmi/mmi_client.py

Lines changed: 171 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,151 +1,217 @@
11
import zmq
2-
import logging
32

43
from mmi import send_array, recv_array
54
from bmi.api import IBmi
65

7-
logger = logging.getLogger(__name__)
6+
class MMIClient(IBmi):
7+
def __init__(self, zmq_address):
8+
"""
9+
Constructor
10+
"""
811

12+
# Open ZeroMQ socket
13+
context = zmq.Context()
914

10-
class MMIClient(IBmi):
11-
def __init__(self, uuid, mmi_metadata):
12-
"""
13-
The 'database' has mmi module metadata.
14-
15-
the metadata must contain the key "ports"
16-
"ports": {'PUSH': 58452, 'REQ': 53956, 'SUB': 60285}
17-
"""
18-
logger.debug("Initializing MMI Client [%s]..." % uuid)
19-
self.uuid = uuid
20-
self.database = mmi_metadata
21-
self.ports = mmi_metadata['ports']
22-
23-
self.sockets = {}
24-
self.context = zmq.Context()
25-
26-
logger.debug("Connecting to push/pull server...")
27-
if 'PUSH' in self.ports:
28-
logger.debug("MMI PUSH is available")
29-
self.sockets['PUSH'] = self.context.socket(zmq.PUSH)
30-
# TODO: is this correct?
31-
url = 'tcp://%s:%d' % (self.database['node'], self.ports['PUSH'])
32-
self.sockets['PUSH'].connect(url)
33-
34-
if 'SUB' in self.ports:
35-
logger.debug("MMI SUB is available")
36-
self.sockets['SUB'] = self.context.socket(zmq.SUB)
37-
url = 'tcp://%s:%d' % (self.database['node'], self.ports['SUB'])
38-
self.sockets['SUB'].connect(url)
39-
40-
if 'REQ' in self.ports:
41-
logger.debug("MMI REQ is available")
42-
self.sockets['REQ'] = self.context.socket(zmq.REQ)
43-
url = 'tcp://%s:%d' % (self.database['node'], self.ports['REQ'])
44-
self.sockets['REQ'].connect(url)
45-
46-
def __getitem__(self, key):
47-
"""For direct indexing the MMIClient object as a dict"""
48-
return self.database[key]
49-
50-
# from here: BMI commands that gets translated to MMI.
15+
self.socket = context.socket(zmq.REQ)
16+
self.socket.connect(zmq_address)
17+
18+
# from here: BMI commands that get translated to MMI.
5119
def initialize(self, configfile=None):
5220
"""
21+
Initialize the module
5322
"""
54-
pass
23+
24+
method = "initialize"
25+
26+
A = None
27+
metadata = {method : configfile}
28+
29+
send_array(self.socket, A, metadata)
30+
A, metadata = recv_array(self.socket)
5531

5632
def finalize(self):
5733
"""
34+
Finalize the module
5835
"""
59-
pass
36+
37+
method = "finalize"
38+
39+
A = None
40+
metadata = {method : -1}
6041

61-
def update(self, dt=-1):
62-
"""
63-
"""
64-
metadata = {'update': dt}
65-
send_array(self.sockets['REQ'], None, metadata=metadata)
66-
arr, result_meta = recv_array(self.sockets['REQ'])
67-
return result_meta['dt']
42+
send_array(self.socket, A, metadata)
43+
A, metadata = recv_array(self.socket)
6844

6945
def get_var_count(self):
7046
"""
47+
Return number of variables
7148
"""
72-
pass
49+
50+
method = "get_var_count"
51+
52+
A = None
53+
metadata = {method : -1}
54+
55+
send_array(self.socket, A, metadata)
56+
A, metadata = recv_array(self.socket)
57+
58+
return metadata[method]
7359

7460
def get_var_name(self, i):
75-
pass
61+
"""
62+
Return variable name
63+
"""
64+
65+
method = "get_var_name"
66+
67+
A = None
68+
metadata = {method : i}
69+
70+
send_array(self.socket, A, metadata)
71+
A, metadata = recv_array(self.socket)
72+
73+
return metadata[method]
7674

7775
def get_var_type(self, name):
78-
metadata = {'get_var_type': name}
79-
send_array(self.sockets['REQ'], None, metadata=metadata)
80-
arr, result_meta = recv_array(self.sockets['REQ'])
81-
return result_meta['get_var_type']
76+
"""
77+
Return variable name
78+
"""
8279

83-
def inq_compound(self, name):
84-
pass
80+
method = "get_var_type"
81+
82+
A = None
83+
metadata = {method : name}
8584

86-
def inq_compound_field(self, name, index):
87-
pass
85+
send_array(self.socket, A, metadata)
86+
A, metadata = recv_array(self.socket)
8887

89-
def make_compound_ctype(self, varname):
90-
pass
88+
return metadata[method]
9189

9290
def get_var_rank(self, name):
93-
metadata = {'get_var_rank': name}
94-
send_array(self.sockets['REQ'], None, metadata=metadata)
95-
arr, result_meta = recv_array(self.sockets['REQ'])
96-
return int(result_meta['get_var_rank'])
91+
"""
92+
Return variable rank
93+
"""
94+
95+
method = "get_var_rank"
96+
97+
A = None
98+
metadata = {method : name}
99+
100+
send_array(self.socket, A, metadata)
101+
A, metadata = recv_array(self.socket)
102+
103+
return metadata[method]
97104

98105
def get_var_shape(self, name):
99-
logger.debug('get_var_shape')
100-
metadata = {'get_var_shape': name}
101-
send_array(self.sockets['REQ'], None, metadata=metadata)
102-
arr, result_meta = recv_array(self.sockets['REQ'])
103-
return tuple(result_meta['get_var_shape'])
106+
"""
107+
Return variable shape
108+
"""
109+
110+
method = "get_var_shape"
111+
112+
A = None
113+
metadata = {method : rank}
114+
115+
send_array(self.socket, A, metadata)
116+
A, metadata = recv_array(self.socket)
117+
118+
return metadata[method]
119+
120+
def get_var(self, name):
121+
"""
122+
Return an nd array from model library
123+
"""
124+
125+
method = "get_var"
126+
127+
A = None
128+
metadata = {method : name}
129+
130+
send_array(self.socket, A, metadata)
131+
A, metadata = recv_array(self.socket)
132+
133+
return A
134+
135+
def set_var(self, name, var):
136+
"""
137+
Set the variable name with the values of var
138+
"""
139+
140+
method = "set_var"
141+
142+
A = var
143+
metadata = {method : name}
144+
145+
send_array(self.socket, A, metadata)
146+
A, metadata = recv_array(self.socket)
104147

105148
def get_start_time(self):
106-
metadata = {'get_start_time': None}
107-
send_array(self.sockets['REQ'], None, metadata=metadata)
108-
arr, result_meta = recv_array(self.sockets['REQ'])
109-
return float(result_meta['get_start_time'])
149+
"""
150+
Return start time
151+
"""
152+
153+
method = "get_start_time"
154+
155+
A = None
156+
metadata = {method : -1}
157+
158+
send_array(self.socket, A, metadata)
159+
A, metadata = recv_array(self.socket)
160+
161+
return metadata[method]
110162

111163
def get_end_time(self):
112-
metadata = {'get_end_time': None}
113-
send_array(self.sockets['REQ'], None, metadata=metadata)
114-
arr, result_meta = recv_array(self.sockets['REQ'])
115-
return float(result_meta['get_end_time'])
164+
"""
165+
Return end time of simulation
166+
"""
167+
168+
method = "get_end_time"
169+
170+
A = None
171+
metadata = {method : -1}
172+
173+
send_array(self.socket, A, metadata)
174+
A, metadata = recv_array(self.socket)
175+
176+
return metadata[method]
116177

117178
def get_current_time(self):
118-
metadata = {'get_current_time': None}
119-
send_array(self.sockets['REQ'], None, metadata=metadata)
120-
arr, result_meta = recv_array(self.sockets['REQ'])
121-
return float(result_meta['get_current_time'])
179+
"""
180+
Return current time of simulation
181+
"""
122182

123-
def get_time_step(self):
124-
pass
183+
method = "get_current_time"
184+
185+
A = None
186+
metadata = {method : -1}
125187

126-
def get_var(self, name):
127-
metadata = {'get_var': name}
128-
send_array(self.sockets['REQ'], None, metadata=metadata)
129-
arr, result_meta = recv_array(self.sockets['REQ'])
130-
return arr
188+
send_array(self.socket, A, metadata)
189+
A, metadata = recv_array(self.socket)
131190

132-
def set_var(self, name, var):
133-
pass
191+
return metadata[method]
134192

135-
def set_var_slice(self, name, start, count, var):
136-
pass
193+
def update(self, dt):
194+
"""
195+
Advance the module with timestep dt
196+
"""
137197

138-
def set_var_index(self, name, index, var):
139-
pass
198+
method = "update"
140199

141-
def set_structure_field(self, name, id, field, value):
142-
pass
200+
A = None
201+
metadata = {method : dt}
143202

144-
def set_logger(self, logger):
145-
pass
203+
send_array(self.socket, A, metadata)
204+
A, metadata = recv_array(self.socket)
146205

147-
def __enter__(self):
206+
# TODO: Do we really need these two?
207+
def inq_compound(self, name):
208+
"""
209+
Return the number of fields of a compound type.
210+
"""
148211
pass
149212

150-
def __exit__(self, type, value, tb):
151-
pass
213+
def inq_compound_field(self, name, index):
214+
"""
215+
Lookup the type,rank and shape of a compound field
216+
"""
217+
pass

mmi/runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,11 @@ def process_incoming(model, sockets, data):
184184
# arr[S] = data
185185
# elif action['operator'] == 'add':
186186
# arr[S] += data
187+
elif "initialize" in metadata:
188+
config_file = metadata["initialize"]
189+
model.initialize(config_file)
190+
elif "finalize" in metadata:
191+
model.finalize()
187192
else:
188193
logger.warn("got unknown message {} from socket {}".format(str(metadata), sock))
189194
if sock.socket_type == zmq.REP:

mmi/tracker_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,10 @@ def key_occurrence(self, key, update=True):
2828
if key in v:
2929
result[str(v[key])] = k
3030
return result
31+
32+
def zmq_address(self, key):
33+
"""
34+
Return a ZeroMQ address to the module with the provided key.
35+
"""
36+
zmq_address = "tcp://" + self.database[key]['node'] + ":" + str(self.database[key]['ports']['REQ'])
37+
return zmq_address

0 commit comments

Comments
 (0)