Skip to content

Commit fcf3cdb

Browse files
committed
Use bioimageio prediction pipeline as model from client and server
1 parent b1587c7 commit fcf3cdb

8 files changed

Lines changed: 200 additions & 43 deletions

File tree

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,16 @@ def bioimageio_dummy_model_filepath(data_path, tmpdir):
115115
@pytest.fixture
def bioimageio_dummy_model_bytes(data_path):
    """In-memory zip package of the dummy bioimageio model."""
    rdf_path = data_path / TEST_BIOIMAGEIO_DUMMY / "Dummy.model.yaml"
    return _bioimageio_package(rdf_path)
119+
120+
121+
@pytest.fixture
def bioimageio_dummy_param_model_bytes(data_path):
    """In-memory zip package of the dummy model with a parameterized input shape."""
    rdf_path = data_path / "dummy_param" / "Dummy.model_param.yaml"
    return _bioimageio_package(rdf_path)
125+
126+
127+
def _bioimageio_package(rdf_source):
    """Export the bioimageio resource package for *rdf_source* into a BytesIO buffer."""
    buffer = io.BytesIO()
    export_resource_package(rdf_source, output_path=buffer)
    return buffer

tests/test_rpc/test_mp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def client(log_queue):
6464
p = mp.Process(target=_srv, args=(parent, log_queue))
6565
p.start()
6666

67-
client = create_client(ITestApi, child, timeout=10)
67+
client = create_client(iface_cls=ITestApi, conn=child, timeout=10)
6868

6969
yield client
7070

@@ -108,7 +108,7 @@ def __getattr__(self, name):
108108
p = mp.Process(target=_srv, args=(parent, log_queue))
109109
p.start()
110110

111-
client = create_client(ITestApi, SlowConn(child))
111+
client = create_client(iface_cls=ITestApi, conn=SlowConn(child))
112112

113113
client.fast_compute(2, 2)
114114

@@ -121,7 +121,7 @@ def test_future_timeout(client: ITestApi, log_queue):
121121
p = mp.Process(target=_srv, args=(parent, log_queue))
122122
p.start()
123123

124-
client = create_client(ITestApi, child, timeout=0.001)
124+
client = create_client(iface_cls=ITestApi, conn=child, timeout=0.001)
125125

126126
with pytest.raises(TimeoutError):
127127
client.compute(1, 2)
@@ -256,7 +256,7 @@ def _spawn(iface_cls, srv_cls):
256256
p = mp.Process(target=_run_srv, args=(srv_cls, parent, log_queue))
257257
p.start()
258258

259-
data["client"] = client = create_client(iface_cls, child)
259+
data["client"] = client = create_client(iface_cls=iface_cls, conn=child)
260260
data["process"] = p
261261
return client
262262

tests/test_server/test_grpc/test_inference_servicer.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,25 +156,56 @@ def test_call_fails_with_unknown_model_session_id(self, grpc_stub):
156156

157157
def test_call_predict(self, grpc_stub, bioimageio_dummy_model_bytes):
    """Round-trip a tensor through Predict: dummy model returns input + 1 under spec id "output"."""
    input_spec_id, output_spec_id = "input", "output"
    model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes))
    arr = xr.DataArray(np.arange(128 * 128).reshape(1, 1, 128, 128), dims=("b", "c", "x", "y"))
    request = inference_pb2.PredictRequest(
        modelSessionId=model.id, tensors=[converters.xarray_to_pb_tensor(input_spec_id, arr)]
    )
    res = grpc_stub.Predict(request)

    grpc_stub.CloseModelSession(model)

    assert len(res.tensors) == 1
    assert res.tensors[0].specId == output_spec_id
    # expected = arr + 1 (the dummy model adds one to its input)
    assert_array_equal(arr + 1, converters.pb_tensor_to_numpy(res.tensors[0]))
168171

172+
def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_model_bytes):
    """A tensor whose shape deviates from the model's explicit input spec is rejected."""
    model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes))
    bad_arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
    tensors = [converters.xarray_to_pb_tensor("input", bad_arr)]
    with pytest.raises(grpc.RpcError):
        grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=tensors))
    grpc_stub.CloseModelSession(model)
179+
180+
@pytest.mark.parametrize(
    "shape",
    # Each case violates the parameterized spec of the dummy param model:
    # wrong sizes along x/y, or a zero-sized batch/channel axis.
    # Fix: the original list contained (1, 1, 64, 32) twice; deduplicated.
    [(1, 1, 64, 32), (1, 1, 32, 64), (0, 1, 64, 64), (1, 0, 64, 64)],
)
def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
    """Predict must fail with an RpcError for shapes the parameterized spec rejects."""
    model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
    arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y"))
    input_tensors = [converters.xarray_to_pb_tensor("param", arr)]
    with pytest.raises(grpc.RpcError):
        grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
    grpc_stub.CloseModelSession(model)
188+
189+
@pytest.mark.parametrize("shape", [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)])
def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
    """Shapes compatible with the parameterized input spec are accepted by Predict."""
    model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
    arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y"))
    request = inference_pb2.PredictRequest(
        modelSessionId=model.id, tensors=[converters.xarray_to_pb_tensor("param", arr)]
    )
    grpc_stub.Predict(request)
    grpc_stub.CloseModelSession(model)
196+
169197
@pytest.mark.skip
def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_bytes):
    """Predict with the dummy TensorFlow model: output should be the negated input."""
    input_spec_id, output_spec_id = "input", "output"
    model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_tensorflow_model_bytes))
    arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
    res = grpc_stub.Predict(
        inference_pb2.PredictRequest(
            modelSessionId=model.id, tensors=[converters.xarray_to_pb_tensor(input_spec_id, arr)]
        )
    )

    grpc_stub.CloseModelSession(model)

    assert len(res.tensors) == 1
    assert res.tensors[0].specId == output_spec_id
    # expected = arr * -1
    assert_array_equal(arr * -1, converters.pb_tensor_to_numpy(res.tensors[0]))

tiktorch/rpc/mp.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from functools import wraps
66
from multiprocessing.connection import Connection
77
from threading import Event, Thread
8-
from typing import Any, Optional, Type, TypeVar
8+
from typing import Any, Dict, Optional, Type, TypeVar
99
from uuid import uuid4
1010

1111
from .exceptions import Shutdown
@@ -72,9 +72,8 @@ def __call__(self, *args, **kwargs) -> Any:
7272
return self._client._invoke(self._method_name, *args, **kwargs)
7373

7474

75-
def create_client(iface_cls: Type[T], conn: Connection, timeout=None) -> T:
75+
def create_client(iface_cls: Type[T], conn: Connection, api_kwargs: Optional[Dict[str, any]] = None, timeout=None) -> T:
7676
client = MPClient(iface_cls.__name__, conn, timeout)
77-
get_exposed_methods(iface_cls)
7877

7978
def _make_method(method):
8079
class MethodWrapper:
@@ -98,12 +97,15 @@ def __call__(self, *args, **kwargs) -> Any:
9897
return MethodWrapper()
9998

10099
class _Client(iface_cls):
101-
pass
100+
def __init__(self, kwargs: Optional[Dict]):
101+
kwargs = kwargs or {}
102+
super().__init__(**kwargs)
102103

103-
for method_name, method in get_exposed_methods(iface_cls).items():
104+
exposed_methods = get_exposed_methods(iface_cls)
105+
for method_name, method in exposed_methods.items():
104106
setattr(_Client, method_name, _make_method(method))
105107

106-
return _Client()
108+
return _Client(api_kwargs)
107109

108110

109111
class MPClient:
@@ -190,7 +192,7 @@ def _shutdown(self, exc):
190192

191193
class Message:
    """Base class for messages exchanged over the RPC connection.

    Each message carries the id of the call it belongs to, so replies can be
    matched with their originating requests.
    """

    def __init__(self, id_):
        self.id = id_
194196

195197

196198
class Signal:
@@ -200,20 +202,19 @@ def __init__(self, payload):
200202

201203
class MethodCall(Message):
    """Request message asking the server to invoke ``method_name``."""

    def __init__(self, id_, method_name, args, kwargs):
        # The call id is stored by the Message base class.
        super().__init__(id_)
        self.method_name = method_name
        self.args = args
        self.kwargs = kwargs
207209

208210

209211
class Cancellation(Message):
    """Request to abort the in-flight call with the matching id; carries no extra payload."""
212213

213214

214215
class MethodReturn(Message):
    """Response message carrying the outcome of a completed method call."""

    def __init__(self, id_, result: Result):
        # The call id is stored by the Message base class.
        super().__init__(id_)
        self.result = result
218219

219220

tiktorch/server/grpc/inference_servicer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tiktorch.proto import inference_pb2, inference_pb2_grpc
77
from tiktorch.server.data_store import IDataStore
88
from tiktorch.server.device_pool import DeviceStatus, IDevicePool
9-
from tiktorch.server.session.process import start_model_session_process
9+
from tiktorch.server.session.process import ShapeValidator, start_model_session_process
1010
from tiktorch.server.session_manager import ISession, SessionManager
1111

1212

@@ -36,7 +36,7 @@ def CreateModelSession(
3636
lease.terminate()
3737
raise
3838

39-
session = self.__session_manager.create_session()
39+
session = self.__session_manager.create_session(client)
4040
session.on_close(lease.terminate)
4141
session.on_close(client.shutdown)
4242

@@ -76,6 +76,8 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De
7676
def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse:
7777
session = self._getModelSession(context, request.modelSessionId)
7878
tensors = set([converters.pb_tensor_to_tensor(tensor) for tensor in request.tensors])
79+
shape_validator = ShapeValidator(session.client.model)
80+
shape_validator.check_tensors(tensors)
7981
res = session.client.forward(tensors)
8082
output_spec_ids = [spec.name for spec in session.client.model.output_specs]
8183
assert len(output_spec_ids) == len(res)

tiktorch/server/session/process.py

Lines changed: 107 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,112 @@
11
import multiprocessing as _mp
2-
import os
32
import pathlib
43
import tempfile
54
import uuid
65
from concurrent.futures import Future
76
from multiprocessing.connection import Connection
8-
from typing import List, Optional, Tuple
7+
from typing import Dict, Iterator, List, Optional, Set, Tuple
98

10-
import numpy
9+
import numpy as np
1110
from bioimageio.core import load_resource_description
1211
from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
12+
from bioimageio.core.resource_io import nodes
13+
from bioimageio.core.resource_io.nodes import ParametrizedInputShape
1314

1415
from tiktorch import log
1516
from tiktorch.rpc import Shutdown
1617
from tiktorch.rpc import mp as _mp_rpc
1718
from tiktorch.rpc.mp import MPServer
1819

20+
from ...converters import Tensor
1921
from .backend import base
2022
from .rpc_interface import IRPCModelSession
2123

2224

23-
class ModelSessionProcess(IRPCModelSession):
24-
def __init__(self, model_zip: bytes, devices: List[str]) -> None:
25-
_tmp_file = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
26-
_tmp_file.write(model_zip)
27-
_tmp_file.close()
28-
model = load_resource_description(pathlib.Path(_tmp_file.name))
29-
os.unlink(_tmp_file.name)
30-
self._model: PredictionPipeline = create_prediction_pipeline(bioimageio_model=model, devices=devices)
25+
class ShapeValidator:
    """Validate input tensor shapes against the model's declared input specs.

    Supports explicit shapes (a fixed list of per-axis sizes) and parameterized
    shapes (``min + k * step`` per axis, for a single natural number ``k``).
    Raises ``ValueError`` on any mismatch.
    """

    def __init__(self, model: PredictionPipeline):
        self._model = model

    def check_tensors(self, tensors: Set[Tensor]):
        """Check every tensor against its spec; raises ValueError on the first mismatch."""
        for tensor in tensors:
            axes_with_size = self._get_axes_with_size(tensor.data.dims, tensor.data.shape)
            self._check_shape(tensor.spec_id, axes_with_size)

    def _check_shape(self, spec_id: str, shape: Dict[str, int]):
        # Dispatch on the spec's shape kind: explicit list vs. parameterized.
        spec = self._get_input_spec(spec_id)
        if isinstance(spec.shape, list):
            self._check_shape_explicit(spec, shape)
        elif isinstance(spec.shape, ParametrizedInputShape):
            self._check_shape_parameterized(spec, shape)
        else:
            raise ValueError(f"Unexpected shape {spec.shape}")

    def _get_input_spec(self, spec_id: str) -> nodes.InputTensor:
        """Return the input spec named *spec_id*; raises ValueError if it does not exist."""
        self._check_spec_exists(spec_id)
        specs = [spec for spec in self._model.input_specs if spec.name == spec_id]
        assert len(specs) == 1, "ids of tensor specs should be unique"
        return specs[0]

    def _check_spec_exists(self, spec_id: str):
        spec_names = [spec.name for spec in self._model.input_specs]
        if spec_id not in spec_names:
            raise ValueError(f"Spec {spec_id} doesn't exist for specs {spec_names}")

    def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
        """Exact comparison against a fixed list of per-axis sizes."""
        assert self._is_shape_explicit(spec)
        reference_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape))
        if reference_shape != tensor_shape:
            raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}")

    def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
        assert isinstance(spec.shape, ParametrizedInputShape)
        min_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape.min))
        step = self._get_axes_with_size(spec.axes, tuple(spec.shape.step))
        assert min_shape.keys() == step.keys()
        self._validate_parameterized_shape(tensor_shape, min_shape, step, spec)

    def _validate_parameterized_shape(
        self, tensor_shape: Dict[str, int], min_shape: Dict[str, int], step: Dict[str, int], spec=None
    ):
        """Check that ``tensor_shape == min_shape + k * step`` for one natural ``k``."""
        if not self._is_shape(tensor_shape.values()):
            raise ValueError(f"Invalid shape's sizes {tensor_shape}")

        if tensor_shape.keys() != min_shape.keys():
            raise ValueError(f"Incompatible axes for tensor {tensor_shape} and spec {spec}")

        # Fix: align sizes by axis *name*. Dict key views compare equal regardless
        # of order, so zipping .values() positionally could pair the wrong axes
        # when the tensor's dim order differs from the spec's axis order.
        axes = list(min_shape)
        tensor_arr = np.array([tensor_shape[axis] for axis in axes])
        min_shape_arr = np.array([min_shape[axis] for axis in axes])
        step_arr = np.array([step[axis] for axis in axes])

        diff = tensor_arr - min_shape_arr
        if any(size < 0 for size in diff):
            raise ValueError(f"Tensor shape {tensor_shape} smaller than min shape {min_shape}")

        # Fix: axes with step == 0 are fixed — their size must equal min exactly.
        # Previously any size >= min passed on such axes.
        zero_step = step_arr == 0
        if np.any(diff[zero_step] != 0):
            raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")

        multipliers = diff[~zero_step] / step_arr[~zero_step]
        multiplier = np.unique(multipliers)
        if len(multiplier) == 0:
            # Fix: all steps are zero — the shape equals the min (checked above),
            # which is valid; previously this degenerate case raised.
            return
        if len(multiplier) == 1 and self._is_natural_number(multiplier[0]):
            return
        raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")

    def _is_natural_number(self, n) -> bool:
        # Integral and non-negative; k may be 0 (the min shape itself is valid).
        return np.floor(n) == np.ceil(n) and n >= 0

    def _is_shape(self, shape: Iterable[int]) -> bool:
        # Fix: annotation was Iterator[int]; dict value views are iterable, not iterators.
        return all(self._is_natural_number(dim) for dim in shape)

    def _get_axes_with_size(self, axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]:
        assert len(axes) == len(shape)
        return dict(zip(axes, shape))

    def _is_shape_explicit(self, spec: nodes.InputTensor) -> bool:
        return isinstance(spec.shape, list)
97+
98+
99+
class ModelSessionProcess(IRPCModelSession[PredictionPipeline]):
100+
def __init__(self, model: PredictionPipeline) -> None:
101+
super().__init__(model)
31102
self._datasets = {}
32103
self._worker = base.SessionBackend(self._model)
104+
self._shape_validator = ShapeValidator(self._model)
33105

34-
def forward(self, input_tensors: numpy.ndarray) -> Future:
35-
res = self._worker.forward(input_tensors)
106+
def forward(self, input_tensors: Set[Tensor]) -> Future:
107+
self._shape_validator.check_tensors(input_tensors)
108+
tensors_data = [tensor.data for tensor in input_tensors]
109+
res = self._worker.forward(tensors_data)
36110
return res
37111

38112
def create_dataset(self, mean, stddev):
@@ -46,7 +120,7 @@ def shutdown(self) -> Shutdown:
46120

47121

48122
def _run_model_session_process(
49-
conn: Connection, model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
123+
conn: Connection, prediction_pipeline: PredictionPipeline, log_queue: Optional[_mp.Queue] = None
50124
):
51125
try:
52126
# from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
@@ -60,7 +134,7 @@ def _run_model_session_process(
60134
if log_queue:
61135
log.configure(log_queue)
62136

63-
session_proc = ModelSessionProcess(model_zip, devices)
137+
session_proc = ModelSessionProcess(prediction_pipeline)
64138
srv = MPServer(session_proc, conn)
65139
srv.listen()
66140

@@ -69,10 +143,26 @@ def start_model_session_process(
69143
model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
70144
) -> Tuple[_mp.Process, IRPCModelSession]:
71145
client_conn, server_conn = _mp.Pipe()
146+
prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices)
72147
proc = _mp.Process(
73148
target=_run_model_session_process,
74149
name="ModelSessionProcess",
75-
kwargs={"conn": server_conn, "devices": devices, "log_queue": log_queue, "model_zip": model_zip},
150+
kwargs={
151+
"conn": server_conn,
152+
"log_queue": log_queue,
153+
"prediction_pipeline": prediction_pipeline,
154+
},
76155
)
77156
proc.start()
78-
return proc, _mp_rpc.create_client(IRPCModelSession, client_conn)
157+
# here create the prediction pipeline, share it to the model session class and the client
158+
return proc, _mp_rpc.create_client(
159+
iface_cls=IRPCModelSession, api_kwargs={"model": prediction_pipeline}, conn=client_conn
160+
)
161+
162+
163+
def _get_prediction_pipeline_from_model_bytes(model_zip: bytes, devices: List[str]) -> PredictionPipeline:
    """Materialize *model_zip* to a temp file and build a prediction pipeline from it.

    ``delete=False`` is required so the file can be reopened by name while the
    handle is closed (Windows-compatible); we therefore remove it ourselves.
    """
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file:
        _tmp_file.write(model_zip)
        temp_file_path = pathlib.Path(_tmp_file.name)
    try:
        model = load_resource_description(temp_file_path)
    finally:
        # Fix: the temp file was never removed, leaking one zip per model
        # session (the pre-refactor code unlinked it after loading).
        temp_file_path.unlink()
    return create_prediction_pipeline(bioimageio_model=model, devices=devices)

0 commit comments

Comments
 (0)