11import multiprocessing as _mp
2- import os
32import pathlib
43import tempfile
54import uuid
65from concurrent .futures import Future
76from multiprocessing .connection import Connection
8- from typing import List , Optional , Tuple
7+ from typing import Dict , Iterator , List , Optional , Set , Tuple
98
10- import numpy
9+ import numpy as np
1110from bioimageio .core import load_resource_description
1211from bioimageio .core .prediction_pipeline import PredictionPipeline , create_prediction_pipeline
12+ from bioimageio .core .resource_io import nodes
13+ from bioimageio .core .resource_io .nodes import ParametrizedInputShape
1314
1415from tiktorch import log
1516from tiktorch .rpc import Shutdown
1617from tiktorch .rpc import mp as _mp_rpc
1718from tiktorch .rpc .mp import MPServer
1819
20+ from ...converters import Tensor
1921from .backend import base
2022from .rpc_interface import IRPCModelSession
2123
2224
23- class ModelSessionProcess (IRPCModelSession ):
24- def __init__ (self , model_zip : bytes , devices : List [str ]) -> None :
25- _tmp_file = tempfile .NamedTemporaryFile (suffix = ".zip" , delete = False )
26- _tmp_file .write (model_zip )
27- _tmp_file .close ()
28- model = load_resource_description (pathlib .Path (_tmp_file .name ))
29- os .unlink (_tmp_file .name )
30- self ._model : PredictionPipeline = create_prediction_pipeline (bioimageio_model = model , devices = devices )
class ShapeValidator:
    """Validates tensors against the input specs of a :class:`PredictionPipeline`.

    Supports both explicit shapes (a fixed list of sizes, one per axis) and
    parametrized shapes (``min + n * step`` for a single non-negative integer
    ``n`` shared by all axes).
    """

    def __init__(self, model: PredictionPipeline):
        """:param model: pipeline whose ``input_specs`` define the valid shapes"""
        self._model = model

    def check_tensors(self, tensors: Set[Tensor]):
        """Validate every tensor's shape against its spec.

        :raises ValueError: on the first tensor that does not match its spec.
        """
        for tensor in tensors:
            axes_with_size = self._get_axes_with_size(tensor.data.dims, tensor.data.shape)
            self._check_shape(tensor.spec_id, axes_with_size)

    def _check_shape(self, spec_id: str, shape: Dict[str, int]):
        # Dispatch on the spec's shape flavor: explicit list vs. parametrized.
        spec = self._get_input_spec(spec_id)
        if isinstance(spec.shape, list):
            self._check_shape_explicit(spec, shape)
        elif isinstance(spec.shape, ParametrizedInputShape):
            self._check_shape_parameterized(spec, shape)
        else:
            raise ValueError(f"Unexpected shape {spec.shape}")

    def _get_input_spec(self, spec_id: str) -> nodes.InputTensor:
        # Return the unique input spec whose name matches spec_id.
        self._check_spec_exists(spec_id)
        specs = [spec for spec in self._model.input_specs if spec.name == spec_id]
        assert len(specs) == 1, "ids of tensor specs should be unique"
        return specs[0]

    def _check_spec_exists(self, spec_id: str):
        spec_names = [spec.name for spec in self._model.input_specs]
        if spec_id not in spec_names:
            raise ValueError(f"Spec {spec_id} doesn't exist for specs {spec_names}")

    def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
        # Explicit shapes must match axis-for-axis, including the axis names.
        assert self._is_shape_explicit(spec)
        reference_shape = {name: size for name, size in zip(spec.axes, spec.shape)}
        if reference_shape != tensor_shape:
            raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}")

    def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
        """Check ``tensor_shape == min + n * step`` for one natural number ``n`` over all axes."""
        assert isinstance(spec.shape, ParametrizedInputShape)
        if not self._is_shape(tensor_shape.values()):
            raise ValueError(f"Invalid shape's sizes {tensor_shape}")

        min_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape.min))
        step = self._get_axes_with_size(spec.axes, tuple(spec.shape.step))
        assert min_shape.keys() == step.keys()
        if tensor_shape.keys() != min_shape.keys():
            raise ValueError(f"Incompatible axes for tensor {tensor_shape} and spec {spec}")

        tensor_shapes_arr = np.array(list(tensor_shape.values()))
        min_shape_arr = np.array(list(min_shape.values()))
        step_arr = np.array(list(step.values()))
        diff = tensor_shapes_arr - min_shape_arr
        if any(size < 0 for size in diff):
            raise ValueError(f"Tensor shape {tensor_shape} smaller than min shape {min_shape}")

        # BUGFIX: axes with step == 0 cannot grow, so they must equal the min
        # shape exactly; previously an oversized zero-step axis slipped through
        # because only the non-zero-step axes contributed to the multiplier check.
        zero_step = step_arr == 0
        if np.any(diff[zero_step] != 0):
            raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")

        non_zero_idx = np.nonzero(step_arr)
        multipliers = diff[non_zero_idx] / step_arr[non_zero_idx]
        multiplier = np.unique(multipliers)
        # BUGFIX: if every step is zero there are no multipliers at all; the
        # exact-match check above already passed, so the shape is valid
        # (previously len(multiplier) == 0 caused a spurious ValueError).
        if len(multiplier) == 0:
            return
        if len(multiplier) == 1 and self._is_natural_number(multiplier[0]):
            return
        raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")

    def _is_natural_number(self, n) -> bool:
        # "Natural" here means a non-negative integer-valued number
        # (0 is allowed: the tensor may equal the min shape exactly).
        return np.floor(n) == np.ceil(n) and n >= 0

    def _is_shape(self, shape: Iterator[int]) -> bool:
        return all(self._is_natural_number(dim) for dim in shape)

    def _get_axes_with_size(self, axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]:
        assert len(axes) == len(shape)
        return dict(zip(axes, shape))

    def _is_shape_explicit(self, spec: nodes.InputTensor) -> bool:
        return isinstance(spec.shape, list)
97+
98+
99+ class ModelSessionProcess (IRPCModelSession [PredictionPipeline ]):
    def __init__(self, model: PredictionPipeline) -> None:
        """Set up the session around *model*: inference backend plus shape validation.

        :param model: prediction pipeline this session serves; stored by the
            base class (``self._model`` below is presumably set by
            ``IRPCModelSession.__init__`` — confirm in the base class).
        """
        super().__init__(model)
        # Registry for datasets created via create_dataset(); starts empty.
        self._datasets = {}
        # Backend worker that actually runs inference on the pipeline.
        self._worker = base.SessionBackend(self._model)
        # Validates incoming tensor shapes against the pipeline's input specs.
        self._shape_validator = ShapeValidator(self._model)
33105
34- def forward (self , input_tensors : numpy .ndarray ) -> Future :
35- res = self ._worker .forward (input_tensors )
106+ def forward (self , input_tensors : Set [Tensor ]) -> Future :
107+ self ._shape_validator .check_tensors (input_tensors )
108+ tensors_data = [tensor .data for tensor in input_tensors ]
109+ res = self ._worker .forward (tensors_data )
36110 return res
37111
38112 def create_dataset (self , mean , stddev ):
@@ -46,7 +120,7 @@ def shutdown(self) -> Shutdown:
46120
47121
48122def _run_model_session_process (
49- conn : Connection , model_zip : bytes , devices : List [ str ] , log_queue : Optional [_mp .Queue ] = None
123+ conn : Connection , prediction_pipeline : PredictionPipeline , log_queue : Optional [_mp .Queue ] = None
50124):
51125 try :
52126 # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
@@ -60,7 +134,7 @@ def _run_model_session_process(
60134 if log_queue :
61135 log .configure (log_queue )
62136
63- session_proc = ModelSessionProcess (model_zip , devices )
137+ session_proc = ModelSessionProcess (prediction_pipeline )
64138 srv = MPServer (session_proc , conn )
65139 srv .listen ()
66140
def start_model_session_process(
    model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
) -> Tuple[_mp.Process, IRPCModelSession]:
    """Spawn a child process running a model session for *model_zip*.

    :param model_zip: zipped bioimageio model resource as raw bytes
    :param devices: device identifiers the prediction pipeline should use
    :param log_queue: optional queue the child process logs into
    :return: the started process and an RPC client proxy talking to it
    """
    client_conn, server_conn = _mp.Pipe()
    # Build the pipeline once in the parent; it is shared with both the child
    # process (below) and the RPC client proxy (at the return statement).
    # NOTE(review): prediction_pipeline is sent through Process kwargs, i.e.
    # pickled into the child — confirm the pipeline is picklable.
    prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices)
    proc = _mp.Process(
        target=_run_model_session_process,
        name="ModelSessionProcess",
        kwargs={
            "conn": server_conn,
            "log_queue": log_queue,
            "prediction_pipeline": prediction_pipeline,
        },
    )
    proc.start()
    return proc, _mp_rpc.create_client(
        iface_cls=IRPCModelSession, api_kwargs={"model": prediction_pipeline}, conn=client_conn
    )
161+
162+
def _get_prediction_pipeline_from_model_bytes(model_zip: bytes, devices: List[str]) -> PredictionPipeline:
    """Materialize *model_zip* into a temp file, load the model and build a pipeline.

    :param model_zip: zipped bioimageio model resource as raw bytes
    :param devices: device identifiers passed through to the pipeline
    :return: prediction pipeline ready to run on *devices*
    """
    # delete=False keeps the file after the `with` block: the loader needs to
    # re-open it by path, which an open handle would forbid on Windows.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file:
        _tmp_file.write(model_zip)
        temp_file_path = pathlib.Path(_tmp_file.name)
    try:
        model = load_resource_description(temp_file_path)
    finally:
        # BUGFIX: the temp zip was never removed, leaking one file per call
        # (the pre-refactor code os.unlink'ed it right after loading).
        temp_file_path.unlink()
    return create_prediction_pipeline(bioimageio_model=model, devices=devices)
0 commit comments