Skip to content

Commit

Permalink
add script for clearing AI terms flags (#24)
Browse files Browse the repository at this point in the history
* add script for clearing AI terms flags

* remove loguru

* add readme
  • Loading branch information
hsheth2 authored Feb 21, 2025
1 parent ea7d2ce commit 7bd737e
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ai_classification/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Tools for AI Classification

Docs on this feature: https://datahubproject.io/docs/automations/ai-term-suggestion/

- The `clear.py` script has a few subcommands that can be used to reset internal state around the classification feature.
168 changes: 168 additions & 0 deletions ai_classification/clear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#!/usr/bin/env -S uv --quiet run --script
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "acryl-datahub",
# "acryl-datahub-cloud",
# ]
# ///

import logging

import acryl_datahub_cloud.metadata.schema_classes as models
import click
from datahub.emitter.mce_builder import get_sys_time
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import get_default_graph
from datahub.metadata.urns import DatasetUrn

# Configure root logging at INFO so the progress messages emitted by the
# commands below are visible when the script is run from a terminal.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# A single GraphQL document holding both operations used by this script;
# callers pick one via `operation_name` when executing it:
#   - `listTermProposals`: lists PENDING action requests of type
#     TERM_ASSOCIATION, `$count` at a time, bounded by `$endTimestampMillis`.
#   - `rejectProposal`: rejects the action request identified by `$urn`.
gql = """\
query listTermProposals($endTimestampMillis: Long, $count: Int) {
listActionRequests(input: {
type: TERM_ASSOCIATION
status: PENDING
count: $count
endTimestampMillis: $endTimestampMillis
}) {
start
count
total
actionRequests {
urn
type
status
entity {
urn
}
subResourceType
subResource
params {
glossaryTermProposal {
glossaryTerm {
urn
}
}
}
origin
created {
time
}
}
}
}
mutation rejectProposal($urn: String!) {
rejectProposal(urn: $urn)
}
"""


@click.group()
def main():
    """Reset internal state of the AI glossary term classification feature."""
    # The group itself does no work; it only dispatches to the subcommands
    # registered on it with `@main.command()`. The docstring above is what
    # click shows as the description in `--help`.


def _clear_classified_flag_mcp(urn: str) -> MetadataChangeProposalWrapper:
    """Build the change proposal that resets glossary-term inference state.

    The proposal targets ``urn`` and carries an ``EntityInferenceMetadata``
    aspect whose ``glossaryTermsInference`` group has ``version`` set to 0
    and ``lastInferredAt`` set to the value returned by ``get_sys_time()``.

    Args:
        urn: URN of the entity whose flag should be cleared.

    Returns:
        The wrapped metadata change proposal; nothing is emitted here.
    """
    # TODO: Replace this with a patch.
    # NOTE(review): as written, the aspect is built with only the
    # glossaryTermsInference field populated -- confirm no other inference
    # groups on the entity need to be preserved.
    reset_terms_inference = models.InferenceGroupMetadataClass(
        lastInferredAt=get_sys_time(),
        version=0,
    )
    inference_aspect = models.EntityInferenceMetadataClass(
        glossaryTermsInference=reset_terms_inference,
    )
    return MetadataChangeProposalWrapper(entityUrn=urn, aspect=inference_aspect)


@main.command()
@click.option("--dry-run", is_flag=True, default=False)
def clear_all_proposals(dry_run: bool):
    """Reject all pending glossary term proposals whose origin is INFERRED.

    Pending TERM_ASSOCIATION action requests are listed page by page (100 at
    a time). Proposals with any other origin are left untouched. With
    --dry-run, the proposals that would be rejected are only logged.
    """
    graph = get_default_graph()
    logger.info(f"Using graph {graph}")

    log_prefix = "DRY RUN: " if dry_run else ""

    # URNs already handled. Proposals that are skipped (non-INFERRED) or not
    # actually rejected (dry run) stay PENDING, so the server can hand them
    # back on a later page when they share the cursor timestamp. Tracking
    # them prevents double counting and guarantees the loop terminates.
    seen_urns: set[str] = set()
    proposals_rejected = 0
    end_timestamp = None
    while True:
        # Results are ordered by time, newest first.
        res = graph.execute_graphql(
            gql,
            operation_name="listTermProposals",
            variables={"endTimestampMillis": end_timestamp, "count": 100},
        )["listActionRequests"]
        if res["total"] == 0 or not (proposals := res["actionRequests"]):
            break

        unseen = [p for p in proposals if p["urn"] not in seen_urns]
        if not unseen:
            # The page contained nothing new, so moving the cursor again
            # would just return the same proposals: we are done.
            break

        for proposal in unseen:
            action_request_urn = proposal["urn"]
            seen_urns.add(action_request_urn)

            # Only machine-generated proposals are cleared.
            if proposal["origin"] != "INFERRED":
                continue

            entity_urn = proposal["entity"]["urn"]
            logger.info(
                log_prefix
                + f"Rejecting proposal {action_request_urn} on "
                f"{entity_urn} {proposal['subResourceType']} {proposal['subResource']} "
                f"to add {proposal['params']['glossaryTermProposal']['glossaryTerm']['urn']}"
            )
            proposals_rejected += 1

            if not dry_run:
                graph.execute_graphql(
                    gql,
                    operation_name="rejectProposal",
                    variables={"urn": action_request_urn},
                )

        # Page backwards in time: the oldest proposal on this page becomes
        # the upper time bound of the next request.
        end_timestamp = proposals[-1]["created"]["time"]

    logger.info(log_prefix + f"Rejected {proposals_rejected} proposals")


@main.command()
@click.option("--dry-run", is_flag=True, default=False)
def clear_all_flags(dry_run: bool):
    """Clear the classification flag on every dataset that has one set.

    Datasets are selected with a search filter on `glossaryTermsVersion`
    being greater than 0. With --dry-run, the matching datasets are only
    logged and counted; nothing is emitted.
    """
    graph = get_default_graph()
    logger.info(f"Using graph {graph}")

    log_prefix = "DRY RUN: " if dry_run else ""

    # Only datasets whose glossary-terms inference version is above 0.
    flagged_filter = {
        "field": "glossaryTermsVersion",
        "values": ["0"],
        "condition": "GREATER_THAN",
    }

    cleared = 0
    with graph.make_rest_sink() as sink:
        for urn in graph.get_urns_by_filter(
            entity_types=[DatasetUrn.ENTITY_TYPE],
            batch_size=100,
            extraFilters=[flagged_filter],
        ):
            # Prefix the per-entity message too, so a dry run never reads as
            # if a flag had actually been cleared.
            logger.info(f"{log_prefix}Clearing flag for {urn}")
            if not dry_run:
                sink.emit_async(_clear_classified_flag_mcp(urn))
            # Counted in both modes so the dry-run summary reports how many
            # entities would be affected.
            cleared += 1

    logger.info(log_prefix + f"Cleared flags on {cleared} entities")


@main.command()
@click.option("--urn", required=True)
def clear_single_flag(urn: str):
    """Clear the classification flag on the single dataset given by --urn."""
    # Parse the URN purely to validate it before touching the server. This is
    # a plain call rather than an `assert`, because assert statements are
    # stripped under `python -O`, which would silently skip the validation.
    # NOTE(review): relies on DatasetUrn.from_string raising for strings that
    # are not dataset URNs -- confirm against the SDK.
    DatasetUrn.from_string(urn)

    graph = get_default_graph()
    logger.info(f"Using graph {graph}")

    graph.emit_mcp(_clear_classified_flag_mcp(urn))


# Run the click command group only when executed as a script, so importing
# this module does not trigger the CLI.
if __name__ == "__main__":
    main()

0 comments on commit 7bd737e

Please sign in to comment.