From 30a8e00ec0fb97b0a557eb67166a5fbcc2e885bd Mon Sep 17 00:00:00 2001 From: Jason Walker Date: Tue, 28 Nov 2023 11:22:47 -0500 Subject: [PATCH] Add IDRT connector for duplicate contact detection --- docs/identity_resolution.rst | 308 +++++++++++++++++ docs/index.rst | 1 + parsons/__init__.py | 1 + parsons/identity_resolution/__init__.py | 3 + parsons/identity_resolution/idrt_connector.py | 320 ++++++++++++++++++ .../identity_resolution/idrt_db_adapter.py | 46 +++ setup.py | 8 +- 7 files changed, 686 insertions(+), 1 deletion(-) create mode 100644 docs/identity_resolution.rst create mode 100644 parsons/identity_resolution/__init__.py create mode 100644 parsons/identity_resolution/idrt_connector.py create mode 100644 parsons/identity_resolution/idrt_db_adapter.py diff --git a/docs/identity_resolution.rst b/docs/identity_resolution.rst new file mode 100644 index 0000000000..2cbe1fb283 --- /dev/null +++ b/docs/identity_resolution.rst @@ -0,0 +1,308 @@ +Identity Resolution +=================== + +******** +Overview +******** + +IDRT (Identity Resolution Transformer) is a library that uses neural networks to identify +duplicate contacts by their contact data, like name and phone. + +IDRT contains two main sub-packages: IDRT and IDRT.algorithm. + +* IDRT contains tools that allow you to train your own model based on existing + duplicate/distinct data. + + * IDRT training produces two models that are used by the algorithm: an *encoder* model + and a *classifier* model. + +* IDRT.algorithm provides functions to run an algorithm that use an IDRT model to + perform an efficient duplicate search on a database of contacts. + +This connector does not provide access to the model training portion of the +library. To use it, you must have trained models at hand. It does provide +Parsons integration to the algorithm portion of IDRT, allowing you to easily identify +duplicate contacts in your database. + +For more information, see: https://github.com/Jason94/identity-resolution + +============ +Installation +============ + +**Step 0: Install PyTorch** + +If you're using IDRT for the first time, you can skip to step 1. + +IDRT uses the *PyTorch* Python library to build its neural networks. +Neural networks are significantly faster if they're running on graphics cards (GPUs) +as opposed to traditional processors (CPUs). If you have a graphics card in the +computer (or cloud computation platform) where you will be running IDRT, you can +take advantage of GPU hardware by installing the **CUDA** version of PyTorch. + +Visit `the PyTorch installation webpage `_. Select `CUDA 11.8` +and your preferred operating system. It will give you a `pip` command that you can use to install the GPU +version of PyTorch. Run this command. + +When you are running the IDRT algorithm, this line will appear in the logs if it is running on CPU hardware: + +.. code-block:: + + INFO:idrt.algorithm.utils:Found device cpu + +and this line will appear if you are running on GPU hardware: + +.. code-block:: + + INFO:idrt.algorithm.utils:Found device cuda + +**Step 1: Install the IDRT Parsons connector** + +Because IDRT pulls in several large dependencies, it is not part of the standard Parsons installation. +You can install the IDRT Parsons connector by running this command: + +.. code-block:: + + pip install parsons[idr] + +If you are on newer versions of Pip (>= 20.3) your installation might take an inordinately long time. If +your install is taking a long time and you see log messages like ``INFO: This is taking longer than usual. 
+You might need to provide the dependency resolver with stricter constraints to reduce runtime. See
+https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C.``
+then you can use this workaround to install Parsons and IDRT separately:
+
+.. code-block::
+
+    pip install parsons
+    pip install idrt[algorithm]
+    pip install parsons[idr] --no-deps
+
+==========
+Quickstart
+==========
+
+These scripts will run out-of-the-box, but they assume that you have model files on hand.
+The scripts will look for model files named `"encoder.pt"` and `"classifier.pt"` in the directory
+where you run the script. If you're running these scripts on some database orchestration platforms,
+like Civis Platform, you might need to use URLs instead of file paths to serve the model
+files stored in the platform. In that case, you should change the scripts to use the `encoder_url`
+parameter for step 1 and the `encoder_url` and `classifier_url` parameters for step 2.
+
+**Example 1**
+
+This script matches all of the contacts stored in a SQL table or view. It will look for an
+environmental variable, `SCHEMA`, and another, `DATA_TABLE`, naming a contact data table/view in
+that schema. The data table must have a `primary_key` column, a `pool` column (it can be set to
+`NULL`), and a column for each of the fields that the model is trained to expect (email, etc.).
+The data must also have any pre-processing applied that the model was trained to expect. A common
+example of pre-processing is removing any parentheses, hyphens, and spaces from phone numbers.
+(See the documentation for your model for more details on individual fields.) We usually create
+SQL views that format data from your sources in a way that can be fed into IDRT, and pass those
+views into the algorithm as the `DATA_TABLE`.
+
+The script will run step 1, producing encodings of the first 10,000 rows.
+It will then run step 2, which will perform a duplicate search among those 10,000 rows.
+
+Finally, it will download all of the contacts that were determined to be duplicates and save them
+to a CSV file.
+
+.. code-block:: python
+
+    import os
+    import logging
+
+    from parsons.databases.discover_database import discover_database
+    from parsons import IDRT
+
+    logging.basicConfig()
+    logging.getLogger("idrt.algorithm.prepare_data").setLevel(logging.INFO)
+    logging.getLogger("idrt.algorithm.run_search").setLevel(logging.INFO)
+
+    SCHEMA = os.environ["SCHEMA"]
+    DATA_TABLE = os.environ["DATA_TABLE"]
+    ENCODER_PATH = os.path.join(os.getcwd(), "encoder.pt")
+    CLASSIFIER_PATH = os.path.join(os.getcwd(), "classifier.pt")
+
+    full_data_table = SCHEMA + "." + DATA_TABLE
+
+    db = discover_database()
+    idrt = IDRT(db, output_schema=SCHEMA)
+
+    idrt.step_1_encode_contacts(full_data_table, limit=10_000, encoder_path=ENCODER_PATH)
+    idrt.step_2_run_search(encoder_path=ENCODER_PATH, classifier_path=CLASSIFIER_PATH)
+
+    # For some reason the Parsons Redshift connector uploads boolean datatypes as strings,
+    # so we have to compare to the string 'True' if we're running on Redshift.
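+    # The query below assumes the structure of the idr_dups table written by step 2:
+    # pkey1/pool1 and pkey2/pool2 identify the two contacts in each evaluated pair,
+    # classification_score is the model's score for that pair, and matches records
+    # whether the score cleared the classifier threshold.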
+ duplicates = db.query( + f""" SELECT d.classification_score, c1.*, c2.* + FROM {SCHEMA}.idr_dups d + JOIN {full_data_table} c1 + ON c1.primary_key = d.pkey1 + AND c1.pool = d.pool1 + JOIN {full_data_table} c2 + ON c2.primary_key = d.pkey2 + AND c2.pool = d.pool2 + WHERE d.matches = 'True'; + """ + ) + duplicates.to_csv("duplicates.csv") + +**Example 2** + +This script extends the previous one to guarantee that it will finish matching all +of the contacts in one run of the script. After we complete step 2, we check to see +if there are any contacts that hadn't been encoded in step 1. *(Encoded contacts are stored +in the `idr_out` table and contacts that the model couldn't read are stored in the +`idr_invalid_contacts` table.)* If we find any, we repeat steps 1 & 2 until all contacts have +been processed and matched. + +.. code-block:: python + + import os + import logging + + from parsons.databases.discover_database import discover_database + from parsons.databases.database_connector import DatabaseConnector + from parsons import IDRT + + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + logging.getLogger("idrt.algorithm.prepare_data").setLevel(logging.INFO) + logging.getLogger("idrt.algorithm.run_search").setLevel(logging.INFO) + + SCHEMA = os.environ["SCHEMA"] + DATA_TABLE = os.environ["DATA_TABLE"] + ENCODER_PATH = os.path.join(os.getcwd(), "encoder.pt") + CLASSIFIER_PATH = os.path.join(os.getcwd(), "classifier.pt") + + full_data_table = SCHEMA + "." + DATA_TABLE + + db = discover_database() + idrt = IDRT(db, output_schema=SCHEMA) + + + def iterate_algorithm(db: DatabaseConnector): + idrt.step_1_encode_contacts(full_data_table, limit=300, encoder_path=ENCODER_PATH) + idrt.step_2_run_search(encoder_path=ENCODER_PATH, classifier_path=CLASSIFIER_PATH) + + # If we don't encounter any invalid contacts in the first iteration, + # the table might not exist yet. + if db.table_exists(SCHEMA + ".idr_invalid_contacts"): + remaining_contacts = db.query( + f""" + SELECT count(*) + FROM {full_data_table} c + LEFT JOIN {SCHEMA}.idr_out out + ON out.primary_key = c.primary_key + AND out.pool = c.pool + LEFT JOIN {SCHEMA}.idr_invalid_contacts inv + ON inv.primary_key = c.primary_key + AND inv.pool = c.pool + WHERE out.primary_key IS NULL + AND inv.primary_key IS NULL; + """ + ).first + else: + remaining_contacts = db.query( + f""" + SELECT count(*) + FROM {full_data_table} c + LEFT JOIN {SCHEMA}.idr_out out + ON out.primary_key = c.primary_key + AND out.pool = c.pool + WHERE out.primary_key IS NULL; + """ + ).first + + logging.info(f"{remaining_contacts} contacts remaining") + if remaining_contacts > 0: + iterate_algorithm(db) + + + iterate_algorithm(db) + + # For some reason the Parsons Redshift connector uploads boolean datatypes as strings, + # so we have to compare to the string 'True' if we're running on Redshift. + duplicates = db.query( + f""" SELECT d.classification_score, c1.*, c2.* + FROM {SCHEMA}.idr_dups d + JOIN {full_data_table} c1 + ON c1.primary_key = d.pkey1 + AND c1.pool = d.pool1 + JOIN {full_data_table} c2 + ON c2.primary_key = d.pkey2 + AND c2.pool = d.pool2 + WHERE d.matches = 'True'; + """ + ) + duplicates.to_csv("duplicates.csv") + +**Example 3** + +This script brings in the notion of *pools*. The simpler scripts above can identify +duplicates within one set of contacts. They cannot identify, for example, the contact +in your ActionKit data that best matches a given contact in your EveryAction data. 
+This kind of cross-matching can be accomplished using the source and search pools +arguments to step 2. + +The code below will run step 1 against two source tables, one containing the contact data +for EveryAction and one containing the contact data for ActionKit. These tables must +be formatted the same way as the previous ones. The EveryAction table must contain +`everyaction` in the `pool` column for all rows, and the ActionKit table must contain +`actionkit` in the `pool` column for all rows. + +.. code-block:: python + + import os + import logging + + from parsons.databases.discover_database import discover_database + from parsons import IDRT + + logging.basicConfig() + logging.getLogger("idrt.algorithm.prepare_data").setLevel(logging.INFO) + logging.getLogger("idrt.algorithm.run_search").setLevel(logging.INFO) + + SCHEMA = os.environ["SCHEMA"] + EA_DATA_TABLE = os.environ["EA_DATA_TABLE"] + AK_DATA_TABLE = os.environ["AK_DATA_TABLE"] + ENCODER_PATH = os.path.join(os.getcwd(), "encoder.pt") + CLASSIFIER_PATH = os.path.join(os.getcwd(), "classifier.pt") + + full_ea_table = SCHEMA + "." + EA_DATA_TABLE + full_ak_table = SCHEMA + "." + AK_DATA_TABLE + + db = discover_database() + idrt = IDRT(db, output_schema=SCHEMA) + + idrt.step_1_encode_contacts(full_ea_table, limit=10_000, encoder_path=ENCODER_PATH) + idrt.step_1_encode_contacts(full_ak_table, limit=10_000, encoder_path=ENCODER_PATH) + idrt.step_2_run_search( + encoder_path=ENCODER_PATH, + classifier_path=CLASSIFIER_PATH, + source_pool="everyaction", + search_pool="actionkit", + ) + + # For some reason the Parsons Redshift connector uploads boolean datatypes as strings, + # so we have to compare to the string 'True' if we're running on Redshift. + duplicates = db.query( + f""" SELECT d.classification_score, c1.*, c2.* + FROM {SCHEMA}.idr_dups d + JOIN {full_ea_table} c1 + ON c1.primary_key = d.pkey1 + AND c1.pool = d.pool1 + JOIN {full_ak_table} c2 + ON c2.primary_key = d.pkey2 + AND c2.pool = d.pool2 + WHERE d.matches = 'True'; + """ + ) + duplicates.to_csv("duplicates.csv") + +*** +API +*** + +.. 
autoclass:: parsons.IDRT
+   :inherited-members:

diff --git a/docs/index.rst b/docs/index.rst
index aed7ccca3a..a6343121a3 100755
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -204,6 +204,7 @@ Indices and tables
    github
    google
    hustle
+   identity_resolution
    mailchimp
    mobilecommons
    mobilize_america
diff --git a/parsons/__init__.py b/parsons/__init__.py
index 76393fa8ac..129e0cb136 100644
--- a/parsons/__init__.py
+++ b/parsons/__init__.py
@@ -65,6 +65,7 @@
     ("parsons.google.google_cloud_storage", "GoogleCloudStorage"),
     ("parsons.google.google_sheets", "GoogleSheets"),
     ("parsons.hustle.hustle", "Hustle"),
+    ("parsons.identity_resolution.idrt_connector", "IDRT"),
     ("parsons.mailchimp.mailchimp", "Mailchimp"),
     ("parsons.mobilecommons.mobilecommons", "MobileCommons"),
     ("parsons.mobilize_america.ma", "MobilizeAmerica"),
diff --git a/parsons/identity_resolution/__init__.py b/parsons/identity_resolution/__init__.py
new file mode 100644
index 0000000000..57ea005ef0
--- /dev/null
+++ b/parsons/identity_resolution/__init__.py
@@ -0,0 +1,3 @@
+from parsons.identity_resolution.idrt_connector import IDRT
+
+__all__ = ["IDRT"]
diff --git a/parsons/identity_resolution/idrt_connector.py b/parsons/identity_resolution/idrt_connector.py
new file mode 100644
index 0000000000..cfc9e53d1a
--- /dev/null
+++ b/parsons/identity_resolution/idrt_connector.py
@@ -0,0 +1,320 @@
+import logging
+import os
+from typing import Optional
+
+from idrt.algorithm.utils import download_model, table_from_full_path
+from idrt.algorithm import step_1_encode_contacts, step_2_run_search
+
+from parsons.databases.database_connector import DatabaseConnector
+from parsons.identity_resolution.idrt_db_adapter import ParsonsDBAdapter
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class IDRT:
+    def __init__(
+        self,
+        db: DatabaseConnector,
+        output_schema: Optional[str] = None,
+        tokens_table_name: str = "idr_tokens",
+        encodings_table_name: str = "idr_out",
+        duplicates_table_name: str = "idr_dups",
+        enable_progress_bar: bool = True,
+    ):
+        """Create a new IDRT connector.
+
+        Args:
+            db (DatabaseConnector): The database connection you want to use.
+                Must support the `upsert` method!
+            output_schema (Optional[str], optional): Schema (or database and schema)
+                where intermediate and final output tables will be stored.
+                If none is provided, will look in the `OUTPUT_SCHEMA`
+                environmental variable. Defaults to None.
+                Ex: "my_orgs_contact_data" or "dev.contacts"
+            tokens_table_name (str, optional): Name of the table
+                in `output_schema` to store tokenizations of contact data.
+                Defaults to "idr_tokens".
+            encodings_table_name (str, optional): Name of the table
+                in `output_schema` to store vector encodings of contacts.
+                Defaults to "idr_out".
+            duplicates_table_name (str, optional): Name of the table
+                in `output_schema` to store the final duplicate evaluation
+                results. Defaults to "idr_dups".
+            enable_progress_bar (bool, optional): Show the progress bar while
+                running models. Recommended to set to `False` when running in
+                non-standard terminal output environments, like Civis Platform,
+                where it can draw a new progress bar on every line for every
+                row it calculates! Defaults to `True`.
+        """
+        self.db = ParsonsDBAdapter(db)
+
+        self.schema = (
+            output_schema if output_schema is not None else os.environ["OUTPUT_SCHEMA"]
+        )
+        self.tokens_table = table_from_full_path(self.schema + "." + tokens_table_name)
+        self.encodings_table = table_from_full_path(
+            self.schema + "." + encodings_table_name
+        )
+        self.duplicates_table = table_from_full_path(
+            self.schema + "." + duplicates_table_name
+        )
+        self.invalid_table = table_from_full_path(self.schema + ".idr_invalid_contacts")
+        self.enable_progress_bar = enable_progress_bar
+
+    @staticmethod
+    def model_path(filename: str) -> str:
+        """Construct the full path to download a model when a path was not provided.
+
+        The returned path places the file in the working directory of the program.
+
+        Args:
+            filename (str): Filename for the downloaded file.
+
+        Returns:
+            str: Complete path for the download.
+        """
+        return os.path.join(os.getcwd(), filename)
+
+    def step_1_encode_contacts(
+        self,
+        data_table_name: Optional[str] = None,
+        batch_size: Optional[int] = None,
+        limit: Optional[int] = None,
+        encoder_url: Optional[str] = None,
+        encoder_path: Optional[str] = None,
+    ):
+        """Step 1 of the IDRT algorithm. Execute before step 2!
+
+        This step downloads the contact data from `data_table_name`, encodes all of the rows,
+        and uploads the vector encodings of the contact data into the database. Those vector
+        encodings are used by step 2 of the algorithm to identify candidate duplicates.
+
+        The size of the vector produced for each contact encoding is determined by the encoder
+        model used in the algorithm.
+
+        You must supply either `encoder_url` or `encoder_path` to specify where the encoder
+        model is located.
+
+        Note: For large sets of data, you will need to run step 1 multiple times to make
+        sure that all of your contacts are encoded before step 2 can produce complete
+        results. The algorithm caches results appropriately, so running step 2 against
+        incomplete data will not cause problems; its results will simply not be complete
+        until every contact has been encoded. To check progress, you can query for the
+        number of primary keys in `data_table_name` that are not present in the encodings
+        table (passed into the IDRT class constructor).
+
+        Args:
+            data_table_name (Optional[str], optional): Full SQL path of the table containing
+                formatted contact data to encode. If none is provided, will look in the
+                `DATA_TABLE` environmental variable. Defaults to None.
+            batch_size (Optional[int], optional): Number of contacts the model processes
+                at once. Raising this will increase speed, with diminishing returns
+                based on your RAM/VRAM. If your program crashes, try lowering this.
+                If none is provided, will look in the `BATCH_SIZE` environmental variable.
+                Defaults to 16.
+            limit (Optional[int], optional): Number of contacts to process. Limiting will be
+                required if you have more contacts than available RAM/VRAM. The lower the limit,
+                the more times you will need to run step 1 to encode all of your contacts and
+                match your whole database in step 2. If your program crashes, try lowering this.
+                If none is provided, will look in the `LIMIT` environmental variable.
+                Defaults to 500,000.
+            encoder_url (Optional[str], optional): URL where the encoder model can be downloaded.
+                If absent, will look in the `ENCODER_URL` environmental variable.
+                Defaults to None.
+            encoder_path (Optional[str], optional): Path to an existing encoder model file.
+                If absent, will look in the `ENCODER_PATH` environmental variable. If no
+                `encoder_path` is provided, will download from the provided `encoder_url`.
+                Should be an absolute path. Defaults to None.
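+
+        Example (illustrative only; the schema, table, and model path shown here are
+        hypothetical and should be replaced with your own):
+            >>> from parsons.databases.discover_database import discover_database
+            >>> db = discover_database()
+            >>> idrt = IDRT(db, output_schema="contact_data")
+            >>> idrt.step_1_encode_contacts(
+            ...     "contact_data.contacts_formatted",
+            ...     limit=10_000,
+            ...     encoder_path="/models/encoder.pt",
+            ... )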
+        """
+        batch_size = (
+            batch_size if batch_size is not None else int(os.getenv("BATCH_SIZE", 16))
+        )
+        data_table = (
+            table_from_full_path(data_table_name)
+            if data_table_name is not None
+            else table_from_full_path(os.environ["DATA_TABLE"])
+        )
+        limit = limit if limit is not None else int(os.getenv("LIMIT", 500_000))
+        encoder_url = (
+            encoder_url if encoder_url is not None else os.getenv("ENCODER_URL")
+        )
+        encoder_path = (
+            encoder_path if encoder_path is not None else os.getenv("ENCODER_PATH")
+        )
+
+        if encoder_path is None and encoder_url is not None:
+            encoder_path = IDRT.model_path("encoder.pt")
+            logger.debug(
+                f"Downloading encoder model from {encoder_url} to {encoder_path}"
+            )
+            download_model(encoder_url, encoder_path)
+        elif encoder_path is None:
+            raise RuntimeError("Must supply encoder_path or encoder_url")
+
+        step_1_encode_contacts(
+            self.db,
+            batch_size=batch_size,
+            data_table=data_table,
+            tokens_table=self.tokens_table,
+            output_table=self.encodings_table,
+            invalid_table=self.invalid_table,
+            limit=limit,
+            encoder_path=encoder_path,
+            enable_progress_bar=self.enable_progress_bar,
+        )
+
+    def step_2_run_search(
+        self,
+        threshold: Optional[float] = None,
+        classifier_threshold: Optional[float] = None,
+        n_closest: Optional[int] = None,
+        batch_size: Optional[int] = None,
+        search_pool: Optional[str] = None,
+        source_pool: Optional[str] = None,
+        n_trees: Optional[int] = None,
+        search_k: Optional[int] = None,
+        encoder_url: Optional[str] = None,
+        encoder_path: Optional[str] = None,
+        classifier_url: Optional[str] = None,
+        classifier_path: Optional[str] = None,
+    ):
+        """Step 2 of the IDRT algorithm. Execute after step 1!
+
+        This step uses the work done in step 1 to efficiently identify duplicates for
+        contacts that have already been encoded. It takes any pairs of vectors that are
+        within `threshold` of each other and compares them using the classifier model to
+        produce a classification score for each pair. Any pairs scoring over
+        `classifier_threshold` are considered to be duplicates.
+
+        You must supply either `encoder_url` or `encoder_path` to specify where the encoder
+        model is located, and you must supply either `classifier_url` or `classifier_path` to
+        specify where the classifier model is located.
+
+        Note: The distance used by the algorithm is the distance under the metric that the
+        model was trained with. This might not be standard Euclidean distance. For example,
+        the `idrt.CosineMetric` class treats pairs as being within the threshold if the
+        distance is *above* the threshold. When in doubt, do not supply a value for
+        `classifier_threshold`, and the algorithm will fall back on the default value stored
+        in the classifier model.
+
+        Args:
+            threshold (Optional[float], optional): Distance metric threshold to determine
+                if two contacts should be evaluated as duplicates. A threshold considering more
+                pairs (higher if Euclidean, lower if Cosine) will consider more possible duplicates
+                at the cost of performance. See note above. Defaults to None.
+            classifier_threshold (Optional[float], optional): Float between 0 and 1.
+                If the classification score is above `classifier_threshold`, consider
+                the pair a duplicate. If None, uses the training value for the model.
+                Defaults to None.
+            n_closest (Optional[int], optional): The number of closest pairs to evaluate per
+                contact. If none is provided, will look in the `N_CLOSEST` environmental variable.
+                If none is provided or found, defaults to 2.
+            batch_size (Optional[int], optional): Number of contacts the model processes
+                at once. Raising this will increase speed, with diminishing returns
+                based on your RAM/VRAM. If your program crashes, try lowering this.
+                If none is provided, will look in the `BATCH_SIZE` environmental variable.
+                Defaults to 16.
+            search_pool (Optional[str], optional): The pool of contacts to use as possible
+                duplicate candidates. If absent, will look in the `SEARCH_POOL` environmental
+                variable. If none is found, then the algorithm will use all encoded contacts as
+                duplicate candidates. Defaults to None.
+            source_pool (Optional[str], optional): The pool of contacts to search to find their
+                duplicates. If absent, will look in the `SOURCE_POOL` environmental variable.
+                If none is found, then the algorithm will search for duplicates of all encoded
+                contacts. Defaults to None.
+            n_trees (Optional[int], optional): A higher value gives more precision when finding
+                duplicate candidates. See https://github.com/spotify/annoy#full-python-api.
+                If none is provided, will look in the `N_TREES` environmental variable.
+                If none is found, defaults to 10.
+            search_k (Optional[int], optional): Configures the `search_k` value in the nearest
+                neighbor search. See https://github.com/spotify/annoy#full-python-api.
+                If none is provided, will look in the `SEARCH_K` environmental variable.
+                If none is found, defaults to -1.
+            encoder_url (Optional[str], optional): URL where the encoder model can be downloaded.
+                If absent, will look in the `ENCODER_URL` environmental variable.
+                Defaults to None.
+            encoder_path (Optional[str], optional): Path to an existing encoder model file.
+                If absent, will look in the `ENCODER_PATH` environmental variable. If no
+                `encoder_path` is provided, will download from the provided `encoder_url`.
+                Should be an absolute path. Defaults to None.
+            classifier_url (Optional[str], optional): URL where the classifier model can be
+                downloaded. If absent, will look in the `CLASSIFIER_URL` environmental variable.
+                Defaults to None.
+            classifier_path (Optional[str], optional): Path to an existing classifier model file.
+                If absent, will look in the `CLASSIFIER_PATH` environmental variable. If no
+                `classifier_path` is provided, will download from the provided `classifier_url`.
+                Should be an absolute path. Defaults to None.
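+
+        Example (illustrative; the model paths are hypothetical, and the pool names must
+        match the values in your data tables' `pool` column, as in the cross-pool example
+        in the documentation):
+            >>> idrt.step_2_run_search(
+            ...     encoder_path="/models/encoder.pt",
+            ...     classifier_path="/models/classifier.pt",
+            ...     source_pool="everyaction",
+            ...     search_pool="actionkit",
+            ... )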
+ """ + dup_candidate_table = table_from_full_path(f"{self.schema}.idr_candidates") + + if threshold is None and os.getenv("THRESHOLD"): + threshold = float(os.environ["THRESHOLD"]) + + if classifier_threshold is None and os.getenv("CLASSIFIER_THRESHOLD"): + classifier_threshold = float(os.environ["CLASSIFIER_THRESHOLD"]) + + n_closest = ( + n_closest if n_closest is not None else int(os.getenv("N_CLOSEST", 2)) + ) + n_trees = n_trees if n_trees is not None else int(os.getenv("N_TREES", 10)) + search_k = search_k if search_k is not None else int(os.getenv("SEARCH_K", -1)) + batch_size = ( + batch_size if batch_size is not None else int(os.getenv("BATCH_SIZE", 16)) + ) + search_pool = search_pool or os.getenv("SEARCH_POOL") + source_pool = source_pool or os.getenv("SOURCE_POOL") + + encoder_url = ( + encoder_url if encoder_url is not None else os.getenv("ENCODER_URL") + ) + encoder_path = ( + encoder_path if encoder_path is not None else os.getenv("ENCODER_PATH") + ) + + if encoder_path is None and encoder_url is not None: + encoder_path = IDRT.model_path("encoder.pt") + logger.debug( + f"Downloading encoder model from {encoder_url} to {encoder_path}" + ) + download_model(encoder_url, encoder_path) + elif encoder_path is None: + raise RuntimeError("Must supply encoder_path or encoder_url") + + classifier_url = ( + classifier_url + if classifier_url is not None + else os.getenv("CLASSIFIER_URL") + ) + classifier_path = ( + classifier_path + if classifier_path is not None + else os.getenv("CLASSIFIER_PATH") + ) + + if classifier_path is None and classifier_url is not None: + classifier_path = IDRT.model_path("classifier.pt") + logger.debug("Downloading classifier model") + download_model(classifier_url, classifier_path) + elif classifier_path is None: + raise RuntimeError("Must supply classifier_path or classifier_url") + + step_2_run_search( + self.db, + encoder_path=encoder_path, + classifier_path=classifier_path, + source_table=self.encodings_table, + tokens_table=self.tokens_table, + dup_candidate_table=dup_candidate_table, + dup_output_table=self.duplicates_table, + threshold=threshold, + classifier_threshold=classifier_threshold, + source_pool=source_pool, + search_pool=search_pool, + n_trees=n_trees, + n_closest=n_closest, + search_k=search_k, + batch_size=batch_size, + enable_progress_bar=self.enable_progress_bar, + ) diff --git a/parsons/identity_resolution/idrt_db_adapter.py b/parsons/identity_resolution/idrt_db_adapter.py new file mode 100644 index 0000000000..f7d91933b8 --- /dev/null +++ b/parsons/identity_resolution/idrt_db_adapter.py @@ -0,0 +1,46 @@ +from typing import Optional, Any +from idrt.algorithm.database_adapter import DatabaseAdapter, EtlTable + +from parsons import Table +from parsons.databases.database_connector import DatabaseConnector +from parsons.databases.redshift import Redshift + + +def has_method(obj: Any, method_name: str) -> bool: + return callable(getattr(obj, method_name)) + + +class ParsonsDBAdapter(DatabaseAdapter): + """Provide access to use any Parsons database connector in the IDRT algorithm.""" + + def __init__(self, db: DatabaseConnector): + """Create the database adapter. + + Args: + db (DatabaseConnector): The Parsons DatabaseConnector you wish to use. + Must support the upsert method! 
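+
+        Example (illustrative; assumes the discovered connector supports `upsert`):
+            >>> from parsons.databases.discover_database import discover_database
+            >>> adapter = ParsonsDBAdapter(discover_database())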
+ """ + if not has_method(db, "upsert"): + raise RuntimeError( + f"DatabaseConnector instance {type(db)} does not support upsert" + ) + self._db = db + + def _table_exists(self, tablename: str) -> bool: + return self._db.table_exists(tablename) + + def _execute_query(self, query: str) -> Optional[EtlTable]: + result = self._db.query(query) + if result is None: + return result + return result.to_petl() + + def _upsert(self, tablename: str, data: EtlTable, primary_key: Any): + if isinstance(self._db, Redshift): + self._db.upsert(Table(data), tablename, primary_key, vacuum=False) + else: + # We checked in the constructor that _db supports the upsert operation. + self._db.upsert(Table(data), tablename, primary_key) # type: ignore + + def _bulk_upload(self, tablename: str, data: EtlTable): + self._db.copy(Table(data), tablename, if_exists="drop") diff --git a/setup.py b/setup.py index 23239a3198..dca171e993 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,11 @@ def main(): limited_deps = os.environ.get("PARSONS_LIMITED_DEPENDENCIES", "") + # Put extras in here that are NOT included in requirements.txt, + # and can only be installed by explicit extension. + only_extras = { + "idr": ["idrt[algorithm]"], + } if limited_deps.strip().upper() in ("1", "YES", "TRUE", "ON"): install_requires = [ "petl", @@ -47,6 +52,7 @@ def main(): "targetsmart": ["xmltodict"], "twilio": ["twilio"], "zoom": ["PyJWT"], + **only_extras, } extras_require["all"] = sorted( {lib for libs in extras_require.values() for lib in libs} @@ -56,7 +62,7 @@ def main(): with open(os.path.join(THIS_DIR, "requirements.txt")) as reqs: install_requires = reqs.read().strip().split("\n") # No op for forward-compatibility - extras_require = {"all": []} + extras_require = {"all": [], **only_extras} setup( name="parsons",