Skip to content

Commit 9ab0396

Browse files
committed
Passing the countinference and secret successfully
1 parent 3348291 commit 9ab0396

File tree

6 files changed

+24
-8
lines changed

6 files changed

+24
-8
lines changed

inference/core/managers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def add_model(
8585
logger.debug("ModelManager - model initialisation...")
8686

8787
try:
88-
model_class = self.model_registry.get_model(resolved_identifier, api_key)
88+
model_class = self.model_registry.get_model(resolved_identifier, api_key, countinference=countinference, service_secret=service_secret)
8989
model = model_class(
9090
model_id=model_id,
9191
api_key=api_key,

inference/core/managers/decorators/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def add_model(
5656
api_key: str,
5757
model_id_alias: Optional[str] = None,
5858
endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
59+
countinference: bool = None,
60+
service_secret: str = None,
5961
):
6062
"""Adds a model to the manager.
6163
@@ -71,6 +73,8 @@ def add_model(
7173
api_key,
7274
model_id_alias=model_id_alias,
7375
endpoint_type=endpoint_type,
76+
countinference=countinference,
77+
service_secret=service_secret,
7478
)
7579

7680
async def infer_from_request(

inference/core/managers/decorators/fixed_size_cache.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def add_model(
3838
api_key: str,
3939
model_id_alias: Optional[str] = None,
4040
endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
41+
countinference: bool = None,
42+
service_secret: str = None,
4143
) -> None:
4244
"""Adds a model to the manager and evicts the least recently used if the cache is full.
4345
@@ -95,6 +97,8 @@ def add_model(
9597
api_key,
9698
model_id_alias=model_id_alias,
9799
endpoint_type=endpoint_type,
100+
countinference=countinference,
101+
service_secret=service_secret,
98102
)
99103
except Exception as error:
100104
logger.debug(

inference/core/registries/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, registry_dict) -> None:
1717
"""
1818
self.registry_dict = registry_dict
1919

20-
def get_model(self, model_type: str, model_id: str) -> Model:
20+
def get_model(self, model_type: str, model_id: str, countinference: bool = None, service_secret: str = None) -> Model:
2121
"""Returns the model class based on the given model type.
2222
2323
Args:

inference/core/registries/roboflow.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class RoboflowModelRegistry(ModelRegistry):
6868
then returns a model class based on the model type.
6969
"""
7070

71-
def get_model(self, model_id: ModelID, api_key: str) -> Model:
71+
def get_model(self, model_id: ModelID, api_key: str, countinference: bool = None, service_secret: str = None) -> Model:
7272
"""Returns the model class based on the given model id and API key.
7373
7474
Args:
@@ -81,7 +81,7 @@ def get_model(self, model_id: ModelID, api_key: str) -> Model:
8181
Raises:
8282
ModelNotRecognisedError: If the model type is not supported or found.
8383
"""
84-
model_type = get_model_type(model_id, api_key)
84+
model_type = get_model_type(model_id, api_key, countinference=countinference, service_secret=service_secret)
8585
logger.debug(f"Model type: {model_type}")
8686

8787
if model_type not in self.registry_dict:
@@ -120,6 +120,8 @@ def _check_if_api_key_has_access_to_model(
120120
def get_model_type(
121121
model_id: ModelID,
122122
api_key: Optional[str] = None,
123+
countinference: Optional[bool] = None,
124+
service_secret: Optional[str] = None,
123125
) -> Tuple[TaskType, ModelType]:
124126
"""Retrieves the model type based on the given model ID and API key.
125127
@@ -179,6 +181,8 @@ def get_model_type(
179181
api_data = get_roboflow_model_data(
180182
api_key=api_key,
181183
model_id=model_id,
184+
countinference=countinference,
185+
service_secret=service_secret,
182186
endpoint_type=ModelEndpointType.ORT,
183187
device_id=GLOBAL_DEVICE_ID,
184188
).get("ort")
@@ -187,6 +191,8 @@ def get_model_type(
187191
api_data = get_roboflow_instant_model_data(
188192
api_key=api_key,
189193
model_id=model_id,
194+
countinference=countinference,
195+
service_secret=service_secret,
190196
)
191197
project_task_type = api_data.get("taskType", "object-detection")
192198
if api_data is None:

inference/core/roboflow_api.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,14 +261,15 @@ def get_roboflow_model_data(
261261
if INTERNAL_WEIGHTS_URL_SUFFIX == "serverless":
262262
if countinference is not None:
263263
params.append(("countinference", str(countinference).lower()))
264-
if service_secret is not None:
265-
params.append(("service_secret", service_secret))
264+
if service_secret is not None:
265+
params.append(("service_secret", service_secret))
266266

267267
api_base_url = urllib.parse.urljoin(API_BASE_URL, INTERNAL_WEIGHTS_URL_SUFFIX)
268268
api_url = _add_params_to_url(
269269
url=f"{api_base_url}/{endpoint_type.value}/{model_id}",
270270
params=params,
271271
)
272+
logger.debug(f"Fetching model data from Roboflow API with URL: {api_url}.")
272273
api_data = _get_from_url(url=api_url)
273274
cache.set(
274275
api_data_cache_key,
@@ -307,14 +308,15 @@ def get_roboflow_instant_model_data(
307308
if INTERNAL_WEIGHTS_URL_SUFFIX == "serverless":
308309
if countinference is not None:
309310
params.append(("countinference", str(countinference).lower()))
310-
if service_secret is not None:
311-
params.append(("service_secret", service_secret))
311+
if service_secret is not None:
312+
params.append(("service_secret", service_secret))
312313

313314
api_base_url = urllib.parse.urljoin(API_BASE_URL, INTERNAL_WEIGHTS_URL_SUFFIX)
314315
api_url = _add_params_to_url(
315316
url=f"{api_base_url}/getWeights",
316317
params=params,
317318
)
319+
logger.debug(f"Fetching instant model data from Roboflow API with URL: {api_url}.")
318320
api_data = _get_from_url(url=api_url)
319321
cache.set(
320322
api_data_cache_key,

0 commit comments

Comments
 (0)