diff --git a/cognite/client/data_classes/workflows.py b/cognite/client/data_classes/workflows.py
index fa35416476..f485f4fcbd 100644
--- a/cognite/client/data_classes/workflows.py
+++ b/cognite/client/data_classes/workflows.py
@@ -19,7 +19,9 @@
     WriteableCogniteResource,
     WriteableCogniteResourceList,
 )
+from cognite.client.data_classes.data_modeling import NodeId, ViewId
 from cognite.client.data_classes.data_modeling.query import Query, ResultSetExpression, Select
+from cognite.client.data_classes.filters import Filter
 from cognite.client.data_classes.simulators.runs import (
     SimulationInputOverride,
 )
@@ -43,6 +45,21 @@
 WorkflowStatus: TypeAlias = Literal["completed", "failed", "running", "terminated", "timed_out"]
 
+TagDetectionStatus: TypeAlias = Literal[
+    "Queued",
+    "Distributing",
+    "Distributed",
+    "Loading Entities",
+    "Loaded Entities",
+    "Running",
+    "Detected",
+    "Annotated",
+    "Collecting",
+    "Completed",
+    "Failed",
+    "Timeout",
+]
+
 
 class WorkflowCore(WriteableCogniteResource["WorkflowUpsert"], ABC):
     def __init__(self, external_id: str, description: str | None = None, data_set_id: int | None = None) -> None:
@@ -136,7 +153,7 @@ def as_write(self) -> WorkflowUpsertList:
         return WorkflowUpsertList([workflow.as_write() for workflow in self.data])
 
 
-ValidTaskType = Literal["function", "transformation", "cdf", "dynamic", "subworkflow", "simulation"]
+ValidTaskType = Literal["function", "transformation", "cdf", "dynamic", "subworkflow", "simulation", "tagDetection"]
 
 
 class WorkflowTaskParameters(CogniteObject, ABC):
@@ -166,6 +183,8 @@ def load_parameters(cls, data: dict) -> WorkflowTaskParameters:
             return SubworkflowReferenceParameters._load(parameters)
         elif type_ == "simulation":
             return SimulationTaskParameters._load(parameters)
+        elif type_ == "tagDetection":
+            return TagDetectionTaskParameters._load(parameters)
         else:
             raise ValueError(f"Unknown task type: {type_}. Expected {ValidTaskType}")
 
@@ -513,6 +532,106 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:
         }
 
 
+class TagDetectionTaskEntityFilter(CogniteObject):
+    """
+    An entity search specification used to fetch the DMS entities to match detected tags against.
+
+    Args:
+        view (ViewId): The view the entities are fetched from.
+        filters (Filter): The filter applied when fetching entities from the view.
+        search_field (str): The entity property that detected tags are matched against.
+    """
+
+    def __init__(self, view: ViewId, filters: Filter, search_field: str) -> None:
+        self.view = view
+        self.filters = filters
+        self.search_field = search_field
+
+    @classmethod
+    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
+        return cls(
+            view=ViewId.load(resource["view"]),
+            filters=Filter.load(resource["filters"]),
+            search_field=resource["searchField"],
+        )
+
+    def dump(self, camel_case: bool = True) -> dict[str, Any]:
+        return {
+            "view": self.view.dump(camel_case=camel_case, include_type=False),
+            "filters": self.filters.dump(),
+            "searchField": self.search_field,
+        }
+
+
+class TagDetectionTaskParameters(WorkflowTaskParameters):
+    """
+    The tag detection task parameters are used to specify a tag detection task.
+
+    Args:
+        file_instance_ids (list[NodeId] | str): List of files to detect tags in. A minimum of 1 file is expected. Can be a reference.
+        entity_filters (list[TagDetectionTaskEntityFilter]): Entity search specification(s) used to fetch DMS entities to match on. Must contain between 1 and 10 filters.
+        min_tokens (int | None): Each detected item must match the detected entity on at least this number of tokens. A token is a substring of consecutive letters or digits.
+        partial_match (bool | None): Allow partial (fuzzy) matching of entities in the engineering diagrams. Creates a match only when it is possible to do so unambiguously.
+        write_annotations (bool): Whether annotations should automatically be written for the files.
+
+    Note:
+        A Reference is an expression that allows dynamically injecting input to a task during execution.
+        References can be used to reference the input of the Workflow, the output of a previous task in the Workflow,
+        or the input of a previous task in the Workflow. Note that the injected value must be valid in the context of
+        the property it is injected into. Example Task reference: ${myTaskExternalId.output.someKey} Example Workflow input reference: ${workflow.input.myKey}
+
+    """
+
+    task_type = "tagDetection"
+
+    def __init__(
+        self,
+        file_instance_ids: list[NodeId] | str,
+        entity_filters: list[TagDetectionTaskEntityFilter],
+        min_tokens: int | None = None,
+        partial_match: bool | None = None,
+        write_annotations: bool = False,
+    ) -> None:
+        self.file_instance_ids = file_instance_ids
+        self.entity_filters = entity_filters
+        self.min_tokens = min_tokens
+        self.partial_match = partial_match
+        self.write_annotations = write_annotations
+
+    @classmethod
+    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
+        tag_detection = resource[cls.task_type]
+        file_instance_ids: list[NodeId] | str
+        if isinstance(tag_detection["fileInstanceIds"], str):
+            file_instance_ids = tag_detection["fileInstanceIds"]
+        elif isinstance(tag_detection["fileInstanceIds"], list):
+            file_instance_ids = [NodeId.load(file_instance_id) for file_instance_id in tag_detection["fileInstanceIds"]]
+        else:
+            raise ValueError(f"Invalid file instance ids: {tag_detection['fileInstanceIds']}")
+
+        entity_filters: list[TagDetectionTaskEntityFilter] = [
+            TagDetectionTaskEntityFilter.load(item) for item in tag_detection["entityFilters"]
+        ]
+
+        return cls(
+            file_instance_ids=file_instance_ids,
+            entity_filters=entity_filters,
+            min_tokens=tag_detection.get("minTokens"),
+            partial_match=tag_detection.get("partialMatch"),
+            write_annotations=tag_detection.get("writeAnnotations", False),
+        )
+
+    def dump(self, camel_case: bool = True) -> dict[str, Any]:
+        file_instance_ids: list[dict[str, str]] | str
+        if isinstance(self.file_instance_ids, str):
+            file_instance_ids = self.file_instance_ids
+        else:
+            file_instance_ids = [file_instance_id.dump(camel_case) for file_instance_id in self.file_instance_ids]
+
+        entity_filters = [ef.dump(camel_case) for ef in self.entity_filters]
+
+        return {
+            self.task_type: {
+                "fileInstanceIds": file_instance_ids,
+                "entityFilters": entity_filters,
+                "minTokens": self.min_tokens,
+                "partialMatch": self.partial_match,
+                "writeAnnotations": self.write_annotations,
+            }
+        }
+
+
 class WorkflowTask(CogniteResource):
     """
     This class represents a workflow task.
@@ -615,6 +734,8 @@ def load_output(cls, data: dict) -> WorkflowTaskOutput:
             return SubworkflowTaskOutput.load(data)
         elif task_type == "simulation":
             return SimulationTaskOutput.load(data)
+        elif task_type == "tagDetection":
+            return TagDetectionTaskOutput.load(data)
         else:
             raise ValueError(f"Unknown task type: {task_type}")
 
@@ -779,6 +900,105 @@ def dump(self, camel_case: bool = False) -> dict[str, Any]:
         return {}
 
 
+class PageRange(CogniteObject):
+    """
+    A range of pages within a file.
+
+    Args:
+        begin (int): The first page in the range.
+        end (int): The last page in the range.
+    """
+
+    def __init__(self, begin: int, end: int) -> None:
+        self.begin = begin
+        self.end = end
+
+    @classmethod
+    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
+        return cls(resource["begin"], resource["end"])
+
+
+class TagDetectionJobFilePageRange(CogniteObject):
+    """
+    A file page range that is being processed or was processed by the job.
+
+    Args:
+        instance_id (NodeId): The identifier of the file instance.
+        page_range (PageRange): The range of pages that is being processed or was processed.
+        error_message (str | None): Describes why the page range failed to be processed in case of page range processing failure.
+    """
+
+    def __init__(self, instance_id: NodeId, page_range: PageRange, error_message: str | None) -> None:
+        self.instance_id = instance_id
+        self.page_range = page_range
+        self.error_message = error_message
+
+    @classmethod
+    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
+        return cls(
+            NodeId.load(resource["instanceId"]),
+            PageRange._load(resource["pageRange"]),
+            resource.get("errorMessage"),
+        )
+
+    def dump(self, camel_case: bool = True) -> dict[str, Any]:
+        return {
+            "instanceId": self.instance_id.dump(camel_case=camel_case),
+            "pageRange": self.page_range.dump(camel_case=camel_case),
+            "errorMessage": self.error_message,
+        }
+
+
+class TagDetectionJob(CogniteObject):
+    """A tag detection job.
+
+    Args:
+        job_id (int): The identifier of the tag detection job.
+        status (TagDetectionStatus): The last observed status of the job.
+        file_page_ranges (list[TagDetectionJobFilePageRange]): File page ranges that are or were processed by the job.
+        error_message (str | None): Describes the job failure reason in case of job failure.
+    """
+
+    def __init__(
+        self,
+        job_id: int,
+        status: TagDetectionStatus,
+        file_page_ranges: list[TagDetectionJobFilePageRange],
+        error_message: str | None,
+    ) -> None:
+        self.job_id = job_id
+        self.status = status
+        self.file_page_ranges = file_page_ranges
+        self.error_message = error_message
+
+    @classmethod
+    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
+        file_page_ranges = [
+            TagDetectionJobFilePageRange.load(file_page_range) for file_page_range in resource["filePageRanges"]
+        ]
+
+        return cls(resource["jobId"], resource["status"], file_page_ranges, resource.get("errorMessage"))
+
+    def dump(self, camel_case: bool = True) -> dict[str, Any]:
+        return {
+            "jobId": self.job_id,
+            "status": self.status,
+            "filePageRanges": [file_page_range.dump(camel_case) for file_page_range in self.file_page_ranges],
+            "errorMessage": self.error_message,
+        }
+
+
+class TagDetectionTaskOutput(WorkflowTaskOutput):
+    """
+    The tag detection task output is used to specify the output of a tag detection task.
+    """
+
+    task_type = "tagDetection"
+
+    def __init__(self, jobs: list[TagDetectionJob]) -> None:
+        self.jobs = jobs
+
+    @classmethod
+    def load(cls, data: dict[str, Any]) -> TagDetectionTaskOutput:
+        output = data["output"]
+        return cls([TagDetectionJob.load(job) for job in output["jobs"]])
+
+    def dump(self, camel_case: bool = False) -> dict[str, Any]:
+        return {
+            "jobs": [job.dump(camel_case) for job in self.jobs],
+        }
+
+
 class WorkflowTaskExecution(CogniteObject):
     """
     This class represents a task execution.
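For reviewers, a minimal construction sketch of the new task type. This is not part of the diff: the space, view, version, property, and external IDs ("my-space", "Asset", "v1", "name", "21PT", "detect-tags", "pnid-file-1") are illustrative placeholders, and `Prefix` stands in for any `Filter`:

```python
from cognite.client.data_classes.data_modeling import NodeId, ViewId
from cognite.client.data_classes.filters import Prefix
from cognite.client.data_classes.workflows import (
    TagDetectionTaskEntityFilter,
    TagDetectionTaskParameters,
    WorkflowTask,
)

# Entities are fetched from a hypothetical Asset view; detected text is
# matched against the "name" property.
entity_filter = TagDetectionTaskEntityFilter(
    view=ViewId(space="my-space", external_id="Asset", version="v1"),
    filters=Prefix(["my-space", "Asset/v1", "name"], "21PT"),
    search_field="name",
)

task = WorkflowTask(
    external_id="detect-tags",
    parameters=TagDetectionTaskParameters(
        # Either a list of file NodeIds or a reference such as
        # "${workflow.input.fileInstanceIds}".
        file_instance_ids=[NodeId(space="my-space", external_id="pnid-file-1")],
        entity_filters=[entity_filter],
        min_tokens=2,
        partial_match=True,
        write_annotations=False,
    ),
)

# Serializes to the API shape: {"tagDetection": {"fileInstanceIds": [...], ...}}
assert task.parameters.dump()["tagDetection"]["writeAnnotations"] is False
```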
diff --git a/tests/tests_unit/test_data_classes/test_workflows.py b/tests/tests_unit/test_data_classes/test_workflows.py
index 840a45ad58..4a39eb3356 100644
--- a/tests/tests_unit/test_data_classes/test_workflows.py
+++ b/tests/tests_unit/test_data_classes/test_workflows.py
@@ -12,6 +12,7 @@
     FunctionTaskParameters,
     SimulationInputOverride,
     SimulationTaskParameters,
+    TagDetectionTaskOutput,
     TransformationTaskOutput,
     TransformationTaskParameters,
     WorkflowDefinition,
@@ -49,6 +50,49 @@ class TestWorkflowTaskOutput:
     def test_serialization(self, output: WorkflowTaskOutput, expected: dict):
         assert output.dump(camel_case=True) == expected
 
+    @pytest.mark.parametrize(
+        ["data", "cls"],
+        [
+            (
+                {
+                    "taskType": "function",
+                    "output": {"callId": 123, "functionId": 3456, "response": {"test": 1}},
+                },
+                FunctionTaskOutput,
+            ),
+            ({"taskType": "dynamic", "output": {}}, DynamicTaskOutput),
+            ({"taskType": "cdf", "output": {"response": {"test": 1}, "statusCode": 200}}, CDFTaskOutput),
+            ({"taskType": "transformation", "output": {"jobId": 789}}, TransformationTaskOutput),
+            (
+                {
+                    "taskType": "tagDetection",
+                    "output": {
+                        "jobs": [
+                            {
+                                "jobId": 321,
+                                "status": "Completed",
+                                "filePageRanges": [
+                                    {
+                                        "instanceId": {"space": "sp", "externalId": "id", "type": "node"},
+                                        "pageRange": {"begin": 1, "end": 5},
+                                        "errorMessage": None,
+                                    }
+                                ],
+                                "errorMessage": None,
+                            }
+                        ]
+                    },
+                },
+                TagDetectionTaskOutput,
+            ),
+        ],
+        ids=["function", "dynamic", "cdf", "transformation", "tagDetection"],
+    )
+    def test_serialization_roundtrip(self, data: dict, cls: type[WorkflowTaskOutput]):
+        task_output = WorkflowTaskOutput.load_output(data)
+        assert isinstance(task_output, cls)
+        assert task_output.dump(camel_case=True) == data["output"]
+
 
 class TestWorkflowId:
     @pytest.mark.parametrize(
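A corresponding sketch of consuming the new output type. The payload is hypothetical (job ID, node ID, and error strings are made up) but mirrors the shape exercised by the roundtrip test above:

```python
from cognite.client.data_classes.workflows import TagDetectionTaskOutput, WorkflowTaskOutput

# Hypothetical API payload for a failed job, shaped like the test data above.
data = {
    "taskType": "tagDetection",
    "output": {
        "jobs": [
            {
                "jobId": 654,
                "status": "Failed",
                "filePageRanges": [
                    {
                        "instanceId": {"space": "sp", "externalId": "pnid-2", "type": "node"},
                        "pageRange": {"begin": 6, "end": 10},
                        "errorMessage": "Page range could not be processed",
                    }
                ],
                "errorMessage": "One or more page ranges failed",
            }
        ]
    },
}

# load_output dispatches on taskType, so the generic entry point
# returns the new output class.
output = WorkflowTaskOutput.load_output(data)
assert isinstance(output, TagDetectionTaskOutput)
for job in output.jobs:
    print(job.job_id, job.status, job.error_message)
```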