Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cognite_toolkit/_cdf_tk/client/_toolkit_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .api.extraction_pipelines import ExtractionPipelinesAPI
from .api.filemetadata import FileMetadataAPI
from .api.functions import FunctionsAPI
from .api.groups import GroupsAPI
from .api.hosted_extractors import HostedExtractorsAPI
from .api.infield import InfieldAPI
from .api.instances import InstancesAPI
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self, http_client: HTTPClient, console: Console) -> None:
self.events = EventsAPI(http_client)
self.extraction_pipelines = ExtractionPipelinesAPI(http_client)
self.functions = FunctionsAPI(http_client)
self.groups = GroupsAPI(http_client)
self.hosted_extractors = HostedExtractorsAPI(http_client)
self.instances = InstancesAPI(http_client)
self.spaces = SpacesAPI(http_client)
Expand Down
56 changes: 11 additions & 45 deletions cognite_toolkit/_cdf_tk/client/api/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
https://api-docs.cognite.com/20230101/tag/Groups/operation/createGroups
"""

from collections.abc import Iterable, Sequence
from collections.abc import Sequence

from cognite_toolkit._cdf_tk.client.cdf_client import CDFResourceAPI, Endpoint, PagedResponse
from cognite_toolkit._cdf_tk.client.http_client import (
Expand Down Expand Up @@ -65,55 +65,21 @@ def delete(self, items: Sequence[InternalId]) -> None:
response = self._http_client.request_single_retries(request)
response.get_success_or_raise()

def paginate(
self,
all_groups: bool = False,
limit: int = 100,
cursor: str | None = None,
) -> PagedResponse[GroupResponse]:
"""Get a page of groups from CDF.

Args:
all_groups: Whether to return all groups (requires admin permissions).
limit: Maximum number of groups to return.
cursor: Cursor for pagination.

Returns:
PagedResponse of GroupResponse objects.
"""
return self._paginate(
cursor=cursor,
limit=limit,
params={"all": all_groups} if all_groups else None,
)

def iterate(
self,
all_groups: bool = False,
limit: int | None = None,
) -> Iterable[list[GroupResponse]]:
"""Iterate over all groups in CDF.

Args:
all_groups: Whether to return all groups (requires admin permissions).
limit: Maximum total number of groups to return.

Returns:
Iterable of lists of GroupResponse objects.
"""
return self._iterate(
limit=limit,
params={"all": all_groups} if all_groups else None,
)

def list(self, all_groups: bool = False, limit: int | None = None) -> list[GroupResponse]:
def list(self, all_groups: bool = False) -> list[GroupResponse]:
"""List all groups in CDF.

Args:
all_groups: Whether to return all groups (requires admin permissions).
limit: Maximum total number of groups to return.

Returns:
List of GroupResponse objects.
"""
return self._list(limit=limit, params={"all": all_groups} if all_groups else None)
endpoint = self._method_endpoint_map["list"]
response = self._http_client.request_single_retries(
RequestMessage(
endpoint_url=self._make_url(endpoint.path),
method=endpoint.method,
parameters={"all": all_groups},
)
).get_success_or_raise()
return self._validate_page_response(response).items
14 changes: 8 additions & 6 deletions cognite_toolkit/_cdf_tk/client/resource_classes/group/acls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Sequence
from typing import Annotated, Any, Literal, TypeAlias

from pydantic import BeforeValidator, Field, TypeAdapter, model_serializer, model_validator
from pydantic import BeforeValidator, Field, JsonValue, TypeAdapter, model_serializer, model_validator
from pydantic_core.core_schema import FieldSerializationInfo

from cognite_toolkit._cdf_tk.client._resource_base import BaseModelObject
Expand Down Expand Up @@ -550,16 +550,16 @@ class SimulatorsAcl(Acl):
"""ACL for Simulators resources."""

acl_name: Literal["simulatorsAcl"] = Field("simulatorsAcl", exclude=True)
actions: Sequence[Literal["READ", "WRITE"]]
actions: Sequence[Literal["READ", "WRITE", "DELETE", "RUN", "MANAGE"]]
scope: AllScope | DataSetScope


class UnknownAcl(Acl):
class UnknownAcl(BaseModelObject):
"""Fallback for unknown ACL types."""

acl_name: Literal["unknownAcl"] = Field("unknownAcl", exclude=True)
acl_name: str = Field("unknownAcl", exclude=True)
actions: Sequence[str]
scope: AllScope
scope: dict[str, JsonValue]


def _get_acl_name(cls: type[Acl]) -> str | None:
Expand All @@ -578,7 +578,9 @@ def _get_acl_name(cls: type[Acl]) -> str | None:


def _handle_unknown_acl(value: Any) -> Any:
if isinstance(value, dict) and isinstance(acl_name := value[ACL_NAME], str):
if isinstance(value, Acl | UnknownAcl):
return value
if isinstance(value, dict) and isinstance(acl_name := value.get(ACL_NAME), str):
acl_class = _KNOWN_ACLS.get(acl_name)
if acl_class:
return TypeAdapter(acl_class).validate_python(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Group(BaseModelObject):
capabilities: list[GroupCapability] | None = None
metadata: dict[str, str] | None = None
attributes: GroupAttributes | None = None
source_id: str | None = None
source_id: str | None = Field(None, coerce_numbers_to_str=True)
members: list[str] | Literal["allUserAccounts"] | None = None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import Annotated, Any, Literal, TypeAlias

from pydantic import BeforeValidator, Field, TypeAdapter
from pydantic import BeforeValidator, Field, TypeAdapter, field_validator

from cognite_toolkit._cdf_tk.client._resource_base import BaseModelObject
from cognite_toolkit._cdf_tk.client.resource_classes.group._constants import SCOPE_NAME
Expand Down Expand Up @@ -72,6 +72,28 @@ class TableScope(ScopeDefinition):
scope_name: Literal["tableScope"] = Field("tableScope", exclude=True)
dbs_to_tables: dict[str, list[str]]

@field_validator("dbs_to_tables", mode="before")
@classmethod
def standardize_format(cls, value: Any) -> dict[str, list[str]]:
"""The API returns the dbsToTables field in a different format
than the documentation specifies. This validator standardizes the format to match the documentation."""
if not isinstance(value, dict):
# Let pydantic handle the type error
return value
standardized: dict[str, list[str]] = {}
for db, tables in value.items():
if isinstance(tables, list):
standardized[db] = tables
elif isinstance(tables, dict) and "tables" in tables and isinstance(tables["tables"], list):
standardized[db] = tables["tables"]
elif tables == {}:
standardized[db] = []
else:
raise ValueError(
f"Invalid format for dbsToTables: expected dict[str, list[str]] or dict[str, dict[tables: list[str]]], got {type(tables).__name__} for db '{db}'"
)
return standardized


class ExtractionPipelineScope(ScopeDefinition):
"""Scope limited to specific extraction pipelines."""
Expand Down
2 changes: 2 additions & 0 deletions cognite_toolkit/_cdf_tk/client/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from .api.filemetadata import FileMetadataAPI
from .api.function_schedules import FunctionSchedulesAPI
from .api.functions import FunctionsAPI
from .api.groups import GroupsAPI
from .api.hosted_extractor_destinations import HostedExtractorDestinationsAPI
from .api.hosted_extractor_jobs import HostedExtractorJobsAPI
from .api.hosted_extractor_mappings import HostedExtractorMappingsAPI
Expand Down Expand Up @@ -182,6 +183,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.tool.events = MagicMock(spec_set=EventsAPI)
self.tool.functions = MagicMock(spec=FunctionsAPI)
self.tool.functions.schedules = MagicMock(spec_set=FunctionSchedulesAPI)
self.tool.groups = MagicMock(spec_set=GroupsAPI)
self.tool.search_configurations = MagicMock(spec_set=SearchConfigurationsAPI)
self.tool.simulators = MagicMock(spec=SimulatorsAPI)
self.tool.simulators.models = MagicMock(spec_set=SimulatorModelsAPI)
Expand Down
21 changes: 13 additions & 8 deletions cognite_toolkit/_cdf_tk/commands/dump_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import typer
from cognite.client import data_modeling as dm
from cognite.client.data_classes import (
Group,
GroupList,
filters,
)
from cognite.client.data_classes.data_modeling import ViewId
Expand Down Expand Up @@ -46,8 +44,10 @@
from cognite_toolkit._cdf_tk.client.resource_classes.dataset import DataSetResponse
from cognite_toolkit._cdf_tk.client.resource_classes.extraction_pipeline import ExtractionPipelineResponse
from cognite_toolkit._cdf_tk.client.resource_classes.function import FunctionResponse
from cognite_toolkit._cdf_tk.client.resource_classes.group import GroupResponse
from cognite_toolkit._cdf_tk.client.resource_classes.identifiers import (
ExternalId,
NameId,
WorkflowVersionId,
)
from cognite_toolkit._cdf_tk.client.resource_classes.instance_api import TypedViewReference
Expand Down Expand Up @@ -366,16 +366,16 @@ def __iter__(
class GroupFinder(ResourceFinder[tuple[str, ...]]):
def __init__(self, client: ToolkitClient, identifier: tuple[str, ...] | None = None):
super().__init__(client, identifier)
self.groups: list[Group] | None = None
self.groups: list[GroupResponse] | None = None

def _interactive_select(self) -> tuple[str, ...]:
groups = self.client.iam.groups.list(all=True)
groups = self.client.tool.groups.list(all_groups=True)
if not groups:
raise ToolkitMissingResourceError("No groups found")
groups_by_name: dict[str, list[Group]] = defaultdict(list)
groups_by_name: dict[str, list[GroupResponse]] = defaultdict(list)
for group in groups:
groups_by_name[group.name].append(group)
selected_groups: list[list[Group]] | None = questionary.checkbox(
selected_groups: list[list[GroupResponse]] | None = questionary.checkbox(
"Which group(s) would you like to dump?",
choices=[
Choice(f"{group_name} ({len(group_list)} group{'s' if len(group_list) > 1 else ''})", value=group_list)
Expand All @@ -393,9 +393,14 @@ def __iter__(
) -> Iterator[tuple[Sequence[Hashable], Sequence[ResourceResponseProtocol] | None, ResourceCRUD, None | str]]:
self.identifier = self._selected()
if self.groups:
yield [], GroupList(self.groups), GroupCRUD.create_loader(self.client), None
yield (
[],
[group for group in self.groups if group.name in self.identifier],
GroupCRUD.create_loader(self.client),
None,
)
else:
yield list(self.identifier), None, GroupCRUD.create_loader(self.client), None
yield [NameId(name=name) for name in self.identifier], None, GroupCRUD.create_loader(self.client), None


class AgentFinder(ResourceFinder[tuple[str, ...]]):
Expand Down
Loading