Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f9fe77e
Added support for new task type Tag Detection for Workflows
ronpal Sep 24, 2025
ffa9c53
Fixed dump/load
ronpal Sep 25, 2025
931e450
separate class for filter property
ronpal Sep 25, 2025
a24148d
added task output classes
ronpal Sep 25, 2025
e70a9bd
Housekeeping
ronpal Sep 26, 2025
90c4f87
Update cognite/client/data_classes/workflows.py
ronpal Sep 26, 2025
4e72ff2
Update cognite/client/data_classes/workflows.py
ronpal Sep 26, 2025
fc78eb2
Docstring and snake_case
ronpal Sep 26, 2025
270e56f
Update cognite/client/data_classes/workflows.py
ronpal Sep 26, 2025
c144e21
Update cognite/client/data_classes/workflows.py
ronpal Sep 26, 2025
b28005d
following suggestions
ronpal Sep 26, 2025
2400f87
Merge remote-tracking branch 'refs/remotes/origin/tag-detection-workf…
ronpal Sep 26, 2025
e1d5db2
Docstring
ronpal Sep 26, 2025
c446338
Merge branch 'master' into tag-detection-workflow-task
ronpal Sep 26, 2025
ca1205e
Added new error message property to the TagDetectionJob class
ronpal Oct 9, 2025
48724ff
Fixed docstring precision
ronpal Oct 9, 2025
3f79906
Merge branch 'master' into tag-detection-workflow-task
ronpal Oct 9, 2025
322aac2
Added additional TagDetectionStatus statuses
ronpal Oct 9, 2025
1f1dadb
Added errorMessage to dump method
ronpal Oct 9, 2025
d7cf8ac
Added roundtrip load/dump for ABC dataclasses in workflows
ronpal Oct 9, 2025
2f6103d
Merge remote-tracking branch 'origin/master' into tag-detection-workf…
ronpal Oct 20, 2025
7c70f7f
added errorMessage to filePageRange dataclass
ronpal Oct 20, 2025
16d9738
Updated test
ronpal Oct 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 221 additions & 1 deletion cognite/client/data_classes/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
WriteableCogniteResource,
WriteableCogniteResourceList,
)
from cognite.client.data_classes.data_modeling import NodeId, ViewId
from cognite.client.data_classes.data_modeling.query import Query, ResultSetExpression, Select
from cognite.client.data_classes.filters import Filter
from cognite.client.data_classes.simulators.runs import (
SimulationInputOverride,
)
Expand All @@ -43,6 +45,21 @@

WorkflowStatus: TypeAlias = Literal["completed", "failed", "running", "terminated", "timed_out"]

# Last-observed status values reported for a tag detection job
# (see TagDetectionJob.status). Values are title-cased API strings.
TagDetectionStatus: TypeAlias = Literal[
    "Queued",
    "Distributing",
    "Distributed",
    "Loading Entities",
    "Loaded Entities",
    "Running",
    "Detected",
    "Annotated",
    "Collecting",
    "Completed",
    "Failed",
    "Timeout",
]


class WorkflowCore(WriteableCogniteResource["WorkflowUpsert"], ABC):
def __init__(self, external_id: str, description: str | None = None, data_set_id: int | None = None) -> None:
Expand Down Expand Up @@ -136,7 +153,7 @@ def as_write(self) -> WorkflowUpsertList:
return WorkflowUpsertList([workflow.as_write() for workflow in self.data])


ValidTaskType = Literal["function", "transformation", "cdf", "dynamic", "subworkflow", "simulation"]
ValidTaskType = Literal["function", "transformation", "cdf", "dynamic", "subworkflow", "simulation", "tagDetection"]


class WorkflowTaskParameters(CogniteObject, ABC):
Expand Down Expand Up @@ -166,6 +183,8 @@ def load_parameters(cls, data: dict) -> WorkflowTaskParameters:
return SubworkflowReferenceParameters._load(parameters)
elif type_ == "simulation":
return SimulationTaskParameters._load(parameters)
elif type_ == "tagDetection":
return TagDetectionTaskParameters._load(parameters)
else:
raise ValueError(f"Unknown task type: {type_}. Expected {ValidTaskType}")

Expand Down Expand Up @@ -513,6 +532,106 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:
}


class TagDetectionTaskEntityFilter(CogniteObject):
    """Entity search specification used to fetch DMS entities to match on.

    Args:
        view (ViewId): The view to fetch entities from.
        filters (Filter): Filter applied when selecting entities from the view.
        search_field (str): Name of the entity field to search (``searchField`` in the API).
    """

    def __init__(self, view: ViewId, filters: Filter, search_field: str) -> None:
        self.view = view
        self.filters = filters
        self.search_field = search_field

    @classmethod
    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
        view = ViewId.load(resource["view"])
        filters = Filter.load(resource["filters"])
        return cls(view=view, filters=filters, search_field=resource["searchField"])

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        # The view is dumped without its "type" discriminator, matching the API schema.
        dumped: dict[str, Any] = {"view": self.view.dump(camel_case=camel_case, include_type=False)}
        dumped["filters"] = self.filters.dump()
        dumped["searchField"] = self.search_field
        return dumped


class TagDetectionTaskParameters(WorkflowTaskParameters):
    """
    The tag detection task parameters are used to specify a tag detection task.

    Args:
        file_instance_ids (list[NodeId] | str): List of files to detect tags in. A minimum of 1 file is expected. Can be a reference.
        entity_filters (list[TagDetectionTaskEntityFilter]): Entity search specification(s) used to fetch DMS entities to match on. Must contain between 1 and 10 filters.
        min_tokens (int | None): Each detected item must match the detected entity on at least this number of tokens. A token is a substring of consecutive letters or digits.
        partial_match (bool | None): Allow partial (fuzzy) matching of entities in the engineering diagrams. Creates a match only when it is possible to do so unambiguously.
        write_annotations (bool): Whether annotations should automatically be written for the files.

    Note:
        A Reference is an expression that allows dynamically injecting input to a task during execution.
        References can be used to reference the input of the Workflow, the output of a previous task in the Workflow,
        or the input of a previous task in the Workflow. Note that the injected value must be valid in the context of
        the property it is injected into. Example Task reference: ${myTaskExternalId.output.someKey} Example Workflow input reference: ${workflow.input.myKey}

    """

    task_type = "tagDetection"

    def __init__(
        self,
        file_instance_ids: list[NodeId] | str,
        entity_filters: list[TagDetectionTaskEntityFilter],
        min_tokens: int | None = None,
        partial_match: bool | None = None,
        write_annotations: bool = False,
    ) -> None:
        self.file_instance_ids = file_instance_ids
        self.entity_filters = entity_filters
        self.min_tokens = min_tokens
        self.partial_match = partial_match
        self.write_annotations = write_annotations

    @classmethod
    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
        tag_detection = resource[cls.task_type]
        file_instance_ids: list[NodeId] | str
        if isinstance(tag_detection["fileInstanceIds"], str):
            # A reference expression, e.g. ${workflow.input.myKey}; resolved at execution time.
            file_instance_ids = tag_detection["fileInstanceIds"]
        elif isinstance(tag_detection["fileInstanceIds"], list):
            file_instance_ids = [NodeId.load(file_instance_id) for file_instance_id in tag_detection["fileInstanceIds"]]
        else:
            raise ValueError(f"Invalid file instance ids: {tag_detection['fileInstanceIds']}")

        entity_filters: list[TagDetectionTaskEntityFilter] = [
            TagDetectionTaskEntityFilter.load(item) for item in tag_detection["entityFilters"]
        ]

        return cls(
            file_instance_ids=file_instance_ids,
            entity_filters=entity_filters,
            # These fields are optional in the API response: use .get() so a missing
            # key does not raise KeyError (matches the optional __init__ defaults).
            min_tokens=tag_detection.get("minTokens"),
            partial_match=tag_detection.get("partialMatch"),
            write_annotations=tag_detection.get("writeAnnotations", False),
        )

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        file_instance_ids: list[dict[str, str]] | str
        if isinstance(self.file_instance_ids, str):
            file_instance_ids = self.file_instance_ids
        else:
            file_instance_ids = [file_instance_id.dump(camel_case) for file_instance_id in self.file_instance_ids]

        entity_filters = [ef.dump(camel_case) for ef in self.entity_filters]

        return {
            self.task_type: {
                "fileInstanceIds": file_instance_ids,
                "entityFilters": entity_filters,
                "minTokens": self.min_tokens,
                "partialMatch": self.partial_match,
                "writeAnnotations": self.write_annotations,
            }
        }


class WorkflowTask(CogniteResource):
"""
This class represents a workflow task.
Expand Down Expand Up @@ -615,6 +734,8 @@ def load_output(cls, data: dict) -> WorkflowTaskOutput:
return SubworkflowTaskOutput.load(data)
elif task_type == "simulation":
return SimulationTaskOutput.load(data)
elif task_type == "tagDetection":
return TagDetectionTaskOutput.load(data)
else:
raise ValueError(f"Unknown task type: {task_type}")

Expand Down Expand Up @@ -779,6 +900,105 @@ def dump(self, camel_case: bool = False) -> dict[str, Any]:
return {}


class PageRange(CogniteObject):
    """A range of pages within a file.

    Args:
        begin (int): First page number of the range.
        end (int): Last page number of the range.
    """

    def __init__(self, begin: int, end: int) -> None:
        self.begin = begin
        self.end = end

    @classmethod
    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
        return cls(begin=resource["begin"], end=resource["end"])


class TagDetectionJobFilePageRange(CogniteObject):
    """
    A file page range that is being processed or was processed by the job.

    Args:
        instance_id (NodeId): The identifier of the file instance.
        page_range (PageRange): The page range of the file covered by this entry.
        error_message (str | None): Describes why the page range failed to be processed in case of page range processing failure.
    """

    # Renamed from camelCase ``instanceId`` to follow the snake_case attribute
    # convention used by the rest of this module.
    def __init__(self, instance_id: NodeId, page_range: PageRange, error_message: str | None) -> None:
        self.instance_id = instance_id
        self.page_range = page_range
        self.error_message = error_message

    @classmethod
    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
        return cls(
            NodeId.load(resource["instanceId"]),
            PageRange._load(resource["pageRange"]),
            # errorMessage is only present on failure.
            resource.get("errorMessage"),
        )

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        return {
            "instanceId": self.instance_id.dump(camel_case=camel_case),
            "pageRange": self.page_range.dump(camel_case=camel_case),
            "errorMessage": self.error_message,
        }


class TagDetectionJob(CogniteObject):
    """A tag detection job.

    Args:
        job_id (int): The identifier of the tag detection job.
        status (TagDetectionStatus): The last observed status of the job.
        file_page_ranges (list[TagDetectionJobFilePageRange]): File page ranges that are or were processed by the job.
        error_message (str | None): Describes the job failure reason in case of job failure.
    """

    # Parameters renamed from camelCase (jobId, filePageRanges, errorMessage) to
    # snake_case, matching the naming convention of the rest of this module.
    def __init__(
        self,
        job_id: int,
        status: TagDetectionStatus,
        file_page_ranges: list[TagDetectionJobFilePageRange],
        error_message: str | None,
    ) -> None:
        self.job_id = job_id
        self.status = status
        self.file_page_ranges = file_page_ranges
        self.error_message = error_message

    @classmethod
    def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> Self:
        file_page_ranges = [
            TagDetectionJobFilePageRange.load(file_page_range) for file_page_range in resource["filePageRanges"]
        ]
        # errorMessage is only present on failure.
        return cls(resource["jobId"], resource["status"], file_page_ranges, resource.get("errorMessage"))

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        return {
            "jobId": self.job_id,
            "status": self.status,
            "filePageRanges": [file_page_range.dump(camel_case) for file_page_range in self.file_page_ranges],
            "errorMessage": self.error_message,
        }


class TagDetectionTaskOutput(WorkflowTaskOutput):
    """
    The tag detection task output is used to specify the output of a tag detection task.

    Args:
        jobs (list[TagDetectionJob]): The tag detection jobs spawned by the task.
    """

    # Dispatch key used by WorkflowTaskOutput.load_output; mirrors
    # TagDetectionTaskParameters.task_type.
    task_type = "tagDetection"

    def __init__(self, jobs: list[TagDetectionJob]) -> None:
        self.jobs = jobs

    @classmethod
    def load(cls, data: dict[str, Any]) -> TagDetectionTaskOutput:
        output = data["output"]
        return cls([TagDetectionJob.load(job) for job in output["jobs"]])

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        # Default camel_case=True for consistency with the other task output
        # classes and so that dump() roundtrips the API payload.
        return {
            "jobs": [job.dump(camel_case) for job in self.jobs],
        }


class WorkflowTaskExecution(CogniteObject):
"""
This class represents a task execution.
Expand Down
44 changes: 44 additions & 0 deletions tests/tests_unit/test_data_classes/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
FunctionTaskParameters,
SimulationInputOverride,
SimulationTaskParameters,
TagDetectionTaskOutput,
TransformationTaskOutput,
TransformationTaskParameters,
WorkflowDefinition,
Expand Down Expand Up @@ -49,6 +50,49 @@ class TestWorkflowTaskOutput:
def test_serialization(self, output: WorkflowTaskOutput, expected: dict):
assert output.dump(camel_case=True) == expected

# Roundtrip test: load_output must dispatch on "taskType" to the expected
# concrete WorkflowTaskOutput subclass, and dump(camel_case=True) must
# reproduce the original "output" payload exactly.
@pytest.mark.parametrize(
    ["data", "cls"],
    [
        (
            {
                "taskType": "function",
                "output": {"callId": 123, "functionId": 3456, "response": {"test": 1}},
            },
            FunctionTaskOutput,
        ),
        ({"taskType": "dynamic", "output": {}}, DynamicTaskOutput),
        ({"taskType": "cdf", "output": {"response": {"test": 1}, "statusCode": 200}}, CDFTaskOutput),
        ({"taskType": "transformation", "output": {"jobId": 789}}, TransformationTaskOutput),
        (
            {
                "taskType": "tagDetection",
                "output": {
                    "jobs": [
                        {
                            "jobId": 321,
                            "status": "Completed",
                            "filePageRanges": [
                                {
                                    "instanceId": {"space": "sp", "externalId": "id", "type": "node"},
                                    "pageRange": {"begin": 1, "end": 5},
                                    "errorMessage": None,
                                }
                            ],
                            "errorMessage": None,
                        }
                    ]
                },
            },
            TagDetectionTaskOutput,
        ),
    ],
    ids=["function", "dynamic", "cdf", "transformation", "tagDetection"],
)
def test_serialization_roundtrip(self, data: dict, cls: type[WorkflowTaskOutput]) -> None:
    task_output = WorkflowTaskOutput.load_output(data)
    # The loaded object must be the subclass matching the payload's taskType.
    assert isinstance(task_output, cls)
    # Dumping back with camelCase keys must yield the original output dict.
    assert task_output.dump(camel_case=True) == data["output"]


class TestWorkflowId:
@pytest.mark.parametrize(
Expand Down