Normalise the subject/object categories and predicates given to
Toolkit.get_associations() through format_element(get_element(...)) before
matching, in both the forward and the inverse direction, and add tests that
call get_associations() with element names as well as CURIEs.

diff --git a/bmt/toolkit.py b/bmt/toolkit.py
index fed2f44..b61da7a 100644
--- a/bmt/toolkit.py
+++ b/bmt/toolkit.py
@@ -441,12 +441,26 @@ def get_associations(
             A list of elements
 
         """
-        association_elements = self.get_descendants("association")
         filtered_elements: List[str] = list()
         inverse_predicates: Optional[List[str]] = None
+        subject_categories_formatted = []
+        object_categories_formatted = []
+        predicates_formatted = []
+        association_elements = self.get_descendants("association")
+        if subject_categories:
+            for sc in subject_categories:
+                sc_formatted = format_element(self.get_element(sc))
+                subject_categories_formatted.append(sc_formatted)
+        if object_categories:
+            for oc in object_categories:
+                oc_formatted = format_element(self.get_element(oc))
+                object_categories_formatted.append(oc_formatted)
         if predicates:
+            for pred in predicates:
+                pred_formatted = format_element(self.get_element(pred))
+                predicates_formatted.append(pred_formatted)
             inverse_predicates = list()
-            for pred_curie in predicates:
+            for pred_curie in predicates_formatted:
                 predicate = self.get_element(pred_curie)
                 if predicate:
                     inverse_p = self.get_inverse(predicate.name)
@@ -454,7 +468,9 @@ def get_associations(
                         inverse_predicates.append(inverse_p)
             inverse_predicates = self._format_all_elements(elements=inverse_predicates, formatted=True)
 
-        if subject_categories or predicates or object_categories:
+
+
+        if subject_categories_formatted or predicates_formatted or object_categories_formatted:
             # This feels like a bit of a brute force approach as an implementation,
             # but we just use the list of all association names to retrieve each
             # association record for filtering against the constraints?
@@ -471,10 +487,10 @@ def get_associations(
 
                 # Try to match associations in the forward direction
                 if not(
-                        self.match_association(association, subject_categories, predicates, object_categories) or
+                        self.match_association(association, subject_categories_formatted, predicates_formatted, object_categories_formatted) or
                         (
                             match_inverses and
-                            self.match_association(association, object_categories, inverse_predicates, subject_categories)
+                            self.match_association(association, object_categories_formatted, inverse_predicates, subject_categories_formatted)
                         )
                 ):
                     continue
diff --git a/tests/unit/test_toolkit.py b/tests/unit/test_toolkit.py
index cec72ce..dd976e0 100644
--- a/tests/unit/test_toolkit.py
+++ b/tests/unit/test_toolkit.py
@@ -556,6 +556,17 @@ def test_get_associations_without_parameters(toolkit):
                 "biolink:Association",
                 "biolink:ChemicalAffectsGeneAssociation"
             ]
+        ),
+        (   # Q8 - Check "gene or gene product -- biolink:affected_by -> biolink:ChemicalEntity" (subject given by name) - still no direct match
+            [GENE_OR_GENE_PRODUCT],
+            ["biolink:affected_by"],
+            ["biolink:ChemicalEntity"],
+            False,  # match_inverses
+            [],  # as of Biolink Model release 3.5.4, there is no direct match for this set of SPO parameters
+            [
+                "biolink:Association",
+                "biolink:ChemicalAffectsGeneAssociation"
+            ]
         )
     ]
 )
@@ -581,6 +592,27 @@ def test_get_associations_with_parameters(
     assert not any([entry in associations for entry in does_not_contain])
 
 
+def test_get_associations_gene_to_chemical(toolkit):
+    associations = toolkit.get_associations(
+        subject_categories=["biolink:ChemicalEntity"],
+        predicates=["biolink:affects"],
+        object_categories=["biolink:GeneOrGeneProduct"],
+        # we don't bother testing the 'format' flag simply in confidence
+        # that the associated code is already well tested in other contexts
+        formatted=True
+    )
+    assert associations
+
+    unformatted_associations = toolkit.get_associations(
+        subject_categories=["chemical entity"],
+        predicates=["affects"],
+        object_categories=["gene or gene product"],
+        formatted=True
+    )
+
+    assert unformatted_associations
+
+
 def test_get_all_node_properties(toolkit):
     properties = toolkit.get_all_node_properties()
     assert "provided by" in properties