Skip to content

Commit 08c128c

Browse files
committed
Revert "stop-gap concurrency control solution: separate on CDF project"
This reverts commit 56de84d.
1 parent 06d6735 commit 08c128c

8 files changed

Lines changed: 23 additions & 51 deletions

File tree

cognite/client/_api/data_modeling/instances.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def __init__(self, config: ClientConfig, api_version: str | None, cognite_client
170170
super().__init__(config, api_version, cognite_client)
171171
self._AGGREGATE_LIMIT = 1000
172172
self._SEARCH_LIMIT = 1000
173+
self.__dm_semaphore = get_global_data_modeling_semaphore()
173174

174175
self._warn_on_alpha_debug_settings = FeaturePreviewWarning(
175176
api_maturity="alpha",
@@ -297,7 +298,7 @@ async def __call__(
297298
filter=filter.dump(camel_case_property=False) if isinstance(filter, Filter) else filter,
298299
other_params=other_params,
299300
headers=headers,
300-
semaphore=get_global_data_modeling_semaphore(self._config.project),
301+
semaphore=self.__dm_semaphore,
301302
):
302303
yield item
303304
return
@@ -310,7 +311,7 @@ async def __call__(
310311
filter=filter.dump(camel_case_property=False) if isinstance(filter, Filter) else filter,
311312
other_params=other_params,
312313
headers=headers,
313-
semaphore=get_global_data_modeling_semaphore(self._config.project),
314+
semaphore=self.__dm_semaphore,
314315
):
315316
yield list_cls._load_raw_api_response([raw])
316317

@@ -634,7 +635,7 @@ def _load_raw_api_response(cls, responses: list[dict[str, Any]]) -> _NodeOrEdgeL
634635
identifiers=identifiers,
635636
other_params=other_params,
636637
settings_forcing_raw_response_loading=[f"{include_typing=}"] if include_typing else None,
637-
semaphore=get_global_data_modeling_semaphore(self._config.project),
638+
semaphore=self.__dm_semaphore,
638639
)
639640

640641
return InstancesResult[T_Node, T_Edge](
@@ -722,7 +723,7 @@ async def delete(
722723
identifiers,
723724
wrap_ids=True,
724725
returns_items=True,
725-
semaphore=get_global_data_modeling_semaphore(self._config.project),
726+
semaphore=self.__dm_semaphore,
726727
),
727728
)
728729
node_ids = [NodeId.load(item) for item in deleted_instances if item["instanceType"] == "node"]
@@ -788,7 +789,7 @@ async def inspect(
788789
response = await self._post(
789790
self._RESOURCE_PATH + "/inspect",
790791
json={"items": chunk.as_dicts(), "inspectionOperations": inspect_operations},
791-
semaphore=get_global_data_modeling_semaphore(self._config.project),
792+
semaphore=self.__dm_semaphore,
792793
)
793794
items.extend(unpack_items(response))
794795

@@ -1067,7 +1068,7 @@ async def apply(
10671068
resource_cls=_NodeOrEdgeApplyResultAdapter, # type: ignore[type-var]
10681069
extra_body_fields=other_parameters,
10691070
input_resource_cls=_NodeOrEdgeApplyAdapter, # type: ignore[arg-type]
1070-
semaphore=get_global_data_modeling_semaphore(self._config.project),
1071+
semaphore=self.__dm_semaphore,
10711072
)
10721073
return InstancesApplyResult(
10731074
nodes=NodeApplyResultList([item for item in res if isinstance(item, NodeApplyResult)]),
@@ -1257,11 +1258,7 @@ async def search(
12571258
raise ValueError("nulls_first argument is not supported when sorting on instance search")
12581259
body["sort"] = [self._dump_instance_sort(s) for s in sorts]
12591260

1260-
res = await self._post(
1261-
url_path=self._RESOURCE_PATH + "/search",
1262-
json=body,
1263-
semaphore=get_global_data_modeling_semaphore(self._config.project),
1264-
)
1261+
res = await self._post(url_path=self._RESOURCE_PATH + "/search", json=body, semaphore=self.__dm_semaphore)
12651262
result = res.json()
12661263
return list_cls(
12671264
[resource_cls._load(item) for item in result["items"]], # type: ignore [misc]
@@ -1386,11 +1383,7 @@ async def aggregate(
13861383
if target_units:
13871384
body["targetUnits"] = [unit.dump(camel_case=True) for unit in target_units]
13881385

1389-
res = await self._post(
1390-
url_path=self._RESOURCE_PATH + "/aggregate",
1391-
json=body,
1392-
semaphore=get_global_data_modeling_semaphore(self._config.project),
1393-
)
1386+
res = await self._post(url_path=self._RESOURCE_PATH + "/aggregate", json=body, semaphore=self.__dm_semaphore)
13941387
result_list = InstanceAggregationResultList._load(res.json()["items"])
13951388
if group_by is not None:
13961389
return result_list
@@ -1499,11 +1492,7 @@ async def histogram(
14991492
if target_units:
15001493
body["targetUnits"] = [unit.dump(camel_case=True) for unit in target_units]
15011494

1502-
res = await self._post(
1503-
url_path=self._RESOURCE_PATH + "/aggregate",
1504-
json=body,
1505-
semaphore=get_global_data_modeling_semaphore(self._config.project),
1506-
)
1495+
res = await self._post(url_path=self._RESOURCE_PATH + "/aggregate", json=body, semaphore=self.__dm_semaphore)
15071496
if is_singleton:
15081497
return HistogramValue.load(res.json()["items"][0]["aggregates"][0])
15091498
else:
@@ -1662,10 +1651,7 @@ async def _query_or_sync(
16621651
headers = {"cdf-version": f"{self._config.api_subversion}-alpha"}
16631652

16641653
response = await self._post(
1665-
url_path=self._RESOURCE_PATH + f"/{endpoint}",
1666-
json=body,
1667-
headers=headers,
1668-
semaphore=get_global_data_modeling_semaphore(self._config.project),
1654+
url_path=self._RESOURCE_PATH + f"/{endpoint}", json=body, headers=headers, semaphore=self.__dm_semaphore
16691655
)
16701656
json_payload = response.json()
16711657
default_by_reference = query.instance_type_by_result_expression()
@@ -1865,7 +1851,7 @@ async def list(
18651851
other_params=other_params,
18661852
settings_forcing_raw_response_loading=settings_forcing_raw_response_loading,
18671853
headers=headers,
1868-
semaphore=get_global_data_modeling_semaphore(self._config.project),
1854+
semaphore=self.__dm_semaphore,
18691855
),
18701856
)
18711857

cognite/client/_api/datapoints.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async def _request_datapoints(self, payload: dict[str, Any]) -> Sequence[DataPoi
112112
f"{self.dps_client._RESOURCE_PATH}/list",
113113
json=payload,
114114
headers={"accept": "application/protobuf"},
115-
semaphore=get_global_datapoints_semaphore(self.dps_client._config.project),
115+
semaphore=get_global_datapoints_semaphore(),
116116
)
117117
).content
118118
)
@@ -388,7 +388,7 @@ async def _queue_new_subtasks(
388388
self,
389389
futures_dct: dict[asyncio.Task, list[BaseDpsFetchSubtask]],
390390
) -> None:
391-
sem = get_global_datapoints_semaphore(self.dps_client._config.project)
391+
sem = get_global_datapoints_semaphore()
392392

393393
# This may seem silly (to ask the semaphore for unused capacity), but the logic is sound:
394394
# we want to combine subtasks into requests *as late as possible* for optimal chunking.
@@ -2222,7 +2222,7 @@ async def insert_dataframe(self, df: pd.DataFrame, dropna: bool = True) -> None:
22222222
await self.insert_multiple(dps) # type: ignore[arg-type]
22232223

22242224
def _select_dps_fetch_strategy(self, queries: list[DatapointsQuery]) -> type[DpsFetchStrategy]:
2225-
semaphore = get_global_datapoints_semaphore(self._config.project)
2225+
semaphore = get_global_datapoints_semaphore()
22262226

22272227
# We decide the fetching strategy based on how many time series the user has requested VS the
22282228
# max concurrency we allow for datapoints requests. When the number of time series is small enough
@@ -2362,7 +2362,7 @@ async def _insert_datapoints(self, payload: list[dict[str, Any]]) -> None:
23622362
url_path=self.dps_client._RESOURCE_PATH,
23632363
json={"items": payload},
23642364
headers=headers,
2365-
semaphore=get_global_datapoints_semaphore(self.dps_client._config.project),
2365+
semaphore=get_global_datapoints_semaphore(),
23662366
)
23672367
for dct in payload:
23682368
dct["datapoints"].clear()

cognite/client/_http_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
)
3030
from cognite.client.response import CogniteHTTPResponse
3131
from cognite.client.utils._concurrency import get_global_semaphore
32-
from cognite.client.utils._url import extract_project_from_url
3332

3433
logger = logging.getLogger(__name__)
3534

@@ -266,7 +265,7 @@ async def _with_retry(
266265
if semaphore is None:
267266
# By default, we run with a semaphore decided by user settings of 'max_workers' in 'global_config'.
268267
# Since the user can run any number of SDK tasks concurrently, this needs to be global:
269-
semaphore = get_global_semaphore(extract_project_from_url(url))
268+
semaphore = get_global_semaphore()
270269

271270
is_auto_retryable = False
272271
retry_tracker = RetryTracker(url, self.config)

cognite/client/_sync_api/data_modeling/instances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
===============================================================================
3-
9da25b05ee67f4f54dca5c516c7b71ec
3+
f96b1d5b543aefd8aff5e1e8a645a614
44
This file is auto-generated from the Async API modules, - do not edit manually!
55
===============================================================================
66
"""

cognite/client/_sync_api/datapoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
===============================================================================
3-
8876f83ce494d0a44fc1e354625554ca
3+
8b1cd40ee0ab9406c9edf75c089f78fa
44
This file is auto-generated from the Async API modules, - do not edit manually!
55
===============================================================================
66
"""

cognite/client/utils/_concurrency.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,22 @@
1818
from cognite.client.utils._auxiliary import no_op
1919

2020

21-
# We add the 'project' argument to make sure concurrency limits are applied per project:
2221
@cache
23-
def get_global_semaphore(project: str) -> asyncio.BoundedSemaphore:
22+
def get_global_semaphore() -> asyncio.BoundedSemaphore:
2423
from cognite.client import global_config
2524

2625
return asyncio.BoundedSemaphore(global_config.max_workers)
2726

2827

2928
@cache
30-
def get_global_datapoints_semaphore(project: str) -> asyncio.BoundedSemaphore:
29+
def get_global_datapoints_semaphore() -> asyncio.BoundedSemaphore:
3130
from cognite.client import global_config
3231

3332
return asyncio.BoundedSemaphore(global_config.max_workers)
3433

3534

3635
@cache
37-
def get_global_data_modeling_semaphore(project: str) -> asyncio.BoundedSemaphore:
36+
def get_global_data_modeling_semaphore() -> asyncio.BoundedSemaphore:
3837
return asyncio.BoundedSemaphore(2)
3938

4039

cognite/client/utils/_url.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import re
4-
import warnings
54
from typing import TYPE_CHECKING, Any
65
from urllib.parse import quote
76

@@ -65,17 +64,6 @@
6564
VALID_URL_PATTERN = re.compile(r"^https?://[a-z\d.:\-]+(?:/api/v1/projects/[^/]+)?((/[^\?]+)?(\?.+)?)")
6665
VALID_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH"}
6766

68-
EXTRACT_PROJECT = re.compile(r"/api/v1/projects/([^/]+)")
69-
70-
71-
def extract_project_from_url(url: str, default: str = "") -> str:
72-
# TODO: Stop-gap solution while we await the final concurrency limit implementation
73-
try:
74-
return EXTRACT_PROJECT.search(url).group(1) # type: ignore[union-attr]
75-
except AttributeError:
76-
warnings.warn("No project found in URL", UserWarning)
77-
return default
78-
7967

8068
def resolve_url(api_client: BasicAsyncAPIClient, method: str, url_path: str) -> tuple[bool, str]:
8169
if not url_path.startswith("/"):

tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def override_semaphore(new: int, target: Literal["basic", "datapoints", "data_mo
271271

272272
# The new semaphore should now pick up the changed max_workers value:
273273
semaphore_get_fn.cache_clear()
274-
sem = semaphore_get_fn("test_project") # TODO: Stop-gap solution awaiting final concurrency limit implementation
274+
sem = semaphore_get_fn()
275275
assert new == sem._value == sem._bound_value, "Semaphore didn't update according to overridden max_workers" # type: ignore[attr-defined]
276276

277277
try:

0 commit comments

Comments
 (0)