diff --git a/microsetta_private_api/repo/metadata_repo/_repo.py b/microsetta_private_api/repo/metadata_repo/_repo.py index 03df06810..995cfaae5 100644 --- a/microsetta_private_api/repo/metadata_repo/_repo.py +++ b/microsetta_private_api/repo/metadata_repo/_repo.py @@ -74,7 +74,9 @@ def drop_private_columns(df): # sensitive in nature pm_remove = {c.lower() for c in df.columns if c.lower().startswith('pm_')} - remove = pm_remove | {c.lower() for c in EBI_REMOVE} + freetext_fields = {c.lower() for c in _get_freetext_fields()} + + remove = pm_remove | {c.lower() for c in EBI_REMOVE} | freetext_fields to_drop = [c for c in df.columns if c.lower() in remove] return df.drop(columns=to_drop, inplace=False) @@ -634,3 +636,24 @@ def _find_duplicates(barcodes): } return dups, error + + +def _get_freetext_fields(): + """ Retrieve a list of all free-text survey fields from the database + + Returns + ------- + list of str + The question_shortname values for all free-text survey questions + """ + with Transaction() as t: + with t.cursor() as cur: + cur.execute( + "SELECT sq.question_shortname " + "FROM ag.survey_question sq " + "INNER JOIN ag.survey_question_response_type sqrtype " + "ON sq.survey_question_id = sqrtype.survey_question_id " + "WHERE survey_response_type IN ('TEXT', 'STRING')" + ) + rows = cur.fetchall() + return [x[0] for x in rows] diff --git a/microsetta_private_api/repo/metadata_repo/tests/test_repo.py b/microsetta_private_api/repo/metadata_repo/tests/test_repo.py index f137a1419..81311ba10 100644 --- a/microsetta_private_api/repo/metadata_repo/tests/test_repo.py +++ b/microsetta_private_api/repo/metadata_repo/tests/test_repo.py @@ -17,10 +17,13 @@ _fetch_observed_survey_templates, _construct_multiselect_map, _find_best_answers, - drop_private_columns) + drop_private_columns, + _get_freetext_fields, + EBI_REMOVE) from microsetta_private_api.repo.survey_template_repo import SurveyTemplateRepo from microsetta_private_api.model.account import Account from microsetta_private_api.model.address import Address +from microsetta_private_api.repo.transaction import Transaction class MM: @@ -329,6 +332,32 @@ def test_drop_private_columns(self): obs = drop_private_columns(df) pdt.assert_frame_equal(obs, exp) + def test_drop_private_columns_freetext(self): + # This test specifically asserts that the new code to drop free-text + # fields works, even if those fields are not represented in the + # EBI_REMOVE list + + # First, assert that ALL_ROOMMATES is not in EBI_REMOVE + self.assertFalse("ALL_ROOMMATES" in EBI_REMOVE) + + # Next, assert that ALL_ROOMMATES is a free-text field + freetext_fields = _get_freetext_fields() + self.assertTrue("ALL_ROOMMATES" in freetext_fields) + + # Now, set up a test dataframe, based on the existing + # test_drop_private_columns df, but with the ALL_ROOMMATES field added + df = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], + columns=[ + 'pM_foo', + 'okay', + 'ABOUT_yourSELF_TEXT', + 'ALL_ROOMMATES']) + + # We only expect the "okay" column to remain + exp = pd.DataFrame([[2, ], [6, ]], columns=['okay']) + obs = drop_private_columns(df) + pdt.assert_frame_equal(obs, exp) + def test_build_col_name(self): tests_and_expected = [('foo', 'bar', 'foo_bar'), ('foo', 'bar baz', 'foo_bar_baz')] @@ -512,6 +541,30 @@ def test_find_best_answers(self): with self.assertRaises(KeyError): _ = obs[0]['response']['111'] + def test_get_freetext_fields(self): + with Transaction() as t: + with t.cursor() as cur: + # Grab the count for the number of free-text fields that exist + # in the database + cur.execute( + "SELECT COUNT(*) " + "FROM ag.survey_question_response_type " + "WHERE survey_response_type IN ('TEXT', 'STRING')" + ) + row = cur.fetchone() + freetext_count = row[0] + + # Use the _get_freetext_fields() function to pull the actual list + freetext_fields = _get_freetext_fields() + + # Assert that the field count matches + self.assertEqual(len(freetext_fields), freetext_count) + + # Assert that a few known free-text fields exist in the list + self.assertTrue("ABOUT_YOURSELF_TEXT" in freetext_fields) + self.assertTrue("ALL_ROOMMATES" in freetext_fields) + self.assertTrue("DIET_RESTRICTIONS" in freetext_fields) + if __name__ == '__main__': unittest.main()