diff --git a/README.md b/README.md index a55a3c87d7..90e1247df8 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ Reading systems: | Geneways | [`indra.sources.geneways`](https://indra.readthedocs.io/en/latest/modules/sources/geneways/index.html) | https://www.ncbi.nlm.nih.gov/pubmed/15016385 | | GNBR | [`indra.sources.gnbr`](https://indra.readthedocs.io/en/latest/modules/sources/gnbr/index.html) | https://zenodo.org/record/3459420 | | SemRep | [`indra.sources.semrep`](https://indra.readthedocs.io/en/latest/modules/sources/semrep.html) | https://github.com/lhncbc/SemRep | +| INDRA-BERT | [`indra.sources.indra_bert`](https://indra.readthedocs.io/en/latest/modules/sources/indra_bert/index.html) | https://github.com/gyorilab/indra_bert | Biological pathway databases: diff --git a/doc/conf.py b/doc/conf.py index 29eab8e029..10433cee25 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -315,7 +315,7 @@ 'nltk', 'nltk.stem', 'nltk.stem.snowball', 'kappy', 'openpyxl', 'reportlab', 'reportlab.lib', 'reportlab.lib.enums', 'reportlab.lib.pagesizes', 'reportlab.platypus', 'reportlab.lib.styles', - 'reportlab.lib.units' + 'reportlab.lib.units', 'indra_bert' ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.MagicMock() diff --git a/doc/modules/sources/index.rst b/doc/modules/sources/index.rst index 411bb4872a..f8bc840597 100644 --- a/doc/modules/sources/index.rst +++ b/doc/modules/sources/index.rst @@ -22,6 +22,7 @@ Reading Systems eidos/index gnbr/index semrep + indra_bert/index Molecular Pathway Databases --------------------------- diff --git a/doc/modules/sources/indra_bert/index.rst b/doc/modules/sources/indra_bert/index.rst new file mode 100644 index 0000000000..784554e21a --- /dev/null +++ b/doc/modules/sources/indra_bert/index.rst @@ -0,0 +1,18 @@ +INDRA-BERT (:py:mod:`indra.sources.indra_bert`) +=============================================== + +.. 
automodule:: indra.sources.indra_bert + :members: + +INDRA-BERT API (:py:mod:`indra.sources.indra_bert.api`) +------------------------------------------------------- + +.. automodule:: indra.sources.indra_bert.api + :members: + +INDRA-BERT Processor (:py:mod:`indra.sources.indra_bert.processor`) +------------------------------------------------------------------- + +.. automodule:: indra.sources.indra_bert.processor + :members: + diff --git a/indra/resources/default_belief_probs.json b/indra/resources/default_belief_probs.json index 41902405da..9bf6541593 100644 --- a/indra/resources/default_belief_probs.json +++ b/indra/resources/default_belief_probs.json @@ -34,7 +34,10 @@ "creeds": 0.01, "ubibrowser": 0.01, "acsn": 0.01, - "semrep": 0.05 + "semrep": 0.05, + "wormbase": 0.01, + "indra_bert": 0.05, + "indra_gpt": 0.05 }, "rand": { "eidos": 0.3, @@ -72,6 +75,8 @@ "ubibrowser": 0.1, "acsn": 0.1, "semrep": 0.3, - "wormbase": 0.1 + "wormbase": 0.1, + "indra_bert": 0.3, + "indra_gpt": 0.3 } } diff --git a/indra/resources/source_info.json b/indra/resources/source_info.json index 30eaaca7b9..4dfe0cdcdb 100644 --- a/indra/resources/source_info.json +++ b/indra/resources/source_info.json @@ -39,6 +39,16 @@ "background-color": "#fdb462" } }, + "indra_bert": { + "name": "INDRA-BERT", + "link": "https://github.com/gyorilab/indra_bert", + "type": "reader", + "domain": "biology", + "default_style": { + "color": "white", + "background-color": "#2d9636" + } + }, "tees": { "name": "TEES", "link": "https://github.com/jbjorne/TEES", diff --git a/indra/sources/indra_bert/__init__.py b/indra/sources/indra_bert/__init__.py new file mode 100644 index 0000000000..f77d5060c0 --- /dev/null +++ b/indra/sources/indra_bert/__init__.py @@ -0,0 +1 @@ +from .api import * \ No newline at end of file diff --git a/indra/sources/indra_bert/api.py b/indra/sources/indra_bert/api.py new file mode 100644 index 0000000000..e48a7788b4 --- /dev/null +++ b/indra/sources/indra_bert/api.py @@ -0,0 +1,92 @@ 
def create_extractor(
    ner_model_path="thomaslim6793/indra_bert_ner_agent_detection",
    stmt_model_path="thomaslim6793/indra_bert_indra_stmt_classifier",
    role_model_path="thomaslim6793/indra_bert_indra_stmt_agents_role_assigner",
    mutations_model_path="thomaslim6793/indra_bert_agent_mutation_detection",
    stmt_conf_threshold=0.95
):
    """Return an IndraStructuredExtractor built from the given models.

    If the extractor cannot be constructed from the requested models, the
    failure is logged as a warning and construction is retried once with
    the default models.

    Parameters
    ----------
    ner_model_path : Optional[str]
        Location of the agent detection (NER) model, passed through to
        IndraStructuredExtractor.
    stmt_model_path : Optional[str]
        Location of the statement classifier model, passed through to
        IndraStructuredExtractor.
    role_model_path : Optional[str]
        Location of the agent role assigner model, passed through to
        IndraStructuredExtractor.
    mutations_model_path : Optional[str]
        Location of the agent mutation detection model, passed through to
        IndraStructuredExtractor.
    stmt_conf_threshold : Optional[float]
        Passed through to IndraStructuredExtractor unchanged, including on
        the fallback attempt. Presumably a minimum confidence for the
        statement classifier -- confirm against indra_bert. Default: 0.95

    Returns
    -------
    IndraStructuredExtractor
        The instantiated extractor.

    Raises
    ------
    Exception
        Whatever IndraStructuredExtractor raises if the fallback attempt
        with the default models fails as well.
    """
    # Models used when the requested ones cannot be loaded. These must be
    # kept in sync with the defaults in the signature above.
    fallback_model_paths = {
        'ner_model_path':
            "thomaslim6793/indra_bert_ner_agent_detection",
        'stmt_model_path':
            "thomaslim6793/indra_bert_indra_stmt_classifier",
        'role_model_path':
            "thomaslim6793/indra_bert_indra_stmt_agents_role_assigner",
        'mutations_model_path':
            "thomaslim6793/indra_bert_agent_mutation_detection",
    }
    requested_model_paths = {
        'ner_model_path': ner_model_path,
        'stmt_model_path': stmt_model_path,
        'role_model_path': role_model_path,
        'mutations_model_path': mutations_model_path,
    }
    try:
        ise = IndraStructuredExtractor(
            stmt_conf_threshold=stmt_conf_threshold,
            **requested_model_paths
        )
    except Exception as e:
        # Deliberate best-effort fallback. It is reported as a warning
        # because the caller ends up with different models than requested,
        # which would otherwise go unnoticed at the usual log levels.
        logger.warning("Could not load the requested INDRA-BERT models "
                       "(%s); falling back to the default models from "
                       "Hugging Face", e)
        ise = IndraStructuredExtractor(
            stmt_conf_threshold=stmt_conf_threshold,
            **fallback_model_paths
        )
    for model_name in ('ner_model', 'stmt_model', 'role_model',
                       'mutations_model'):
        logger.info("Loaded %s from: %s", model_name,
                    getattr(ise, model_name + '_local_path'))
    return ise
class IndraBertProcessor:
    """Processor extracting INDRA Statements from INDRA-BERT JSON output.

    Statements are extracted from the data upon instantiation.

    Parameters
    ----------
    data : list of dict
        INDRA Statement JSON entries to process. None is treated as empty.
    grounder : Optional[callable]
        A function called as ``grounder(agent_name, context_text)`` whose
        return value, if truthy, is merged into the agent's db_refs. If not
        given (or falsy), default_grounder_wrapper is used.

    Attributes
    ----------
    statements : list
        The Statements extracted from the data. Entries that could not be
        processed are skipped.
    """
    def __init__(self, data, grounder=None):
        self.data = data
        self.statements = []
        self.source_api = 'indra_bert'
        self.grounder = grounder if grounder else default_grounder_wrapper
        self.extract_statements()

    def extract_statement(self, entry):
        """Return a grounded Statement deserialized from a JSON entry.

        Parameters
        ----------
        entry : dict
            A single INDRA Statement JSON entry.

        Returns
        -------
        The Statement created by stmt_from_json, with its agents grounded.

        Raises
        ------
        Exception
            Any error raised while deserializing or grounding is propagated
            to the caller; extract_statements logs it and skips the entry.
        """
        # Use INDRA's built-in statement deserialization
        stmt = stmt_from_json(entry)

        # The grounder attribute may be unset after construction to skip
        # grounding altogether
        if self.grounder:
            self._apply_grounding(stmt, self._get_evidence_text(entry))

        return stmt

    @staticmethod
    def _get_evidence_text(entry):
        """Return the text of the entry's first evidence or "" if missing."""
        evidence = entry.get('evidence')
        if not evidence:
            return ""
        # The first evidence may lack a text key (or hold None); this
        # should not prevent the statement from being extracted
        return evidence[0].get('text') or ""

    def _apply_grounding(self, stmt, context_text):
        """Ground and standardize all named agents of a statement in place.

        Parameters
        ----------
        stmt :
            The Statement whose agents are grounded.
        context_text : str
            Text passed to the grounder as context along with each agent
            name.
        """
        for agent in stmt.agent_list():
            # agent_list() can contain None for missing agents
            if agent and agent.name:
                grounding_result = self.grounder(agent.name, context_text)
                if grounding_result:
                    # Update db_refs with grounding results
                    agent.db_refs.update(grounding_result)

                # Standardize the agent name
                standardize_agent_name(agent, standardize_refs=True)

    def extract_statements(self):
        """Populate self.statements from self.data.

        Entries that raise during extraction are logged once and skipped.
        """
        self.statements = []
        # `or ()` tolerates data being None
        for entry in self.data or ():
            try:
                stmt = self.extract_statement(entry)
            except Exception as e:
                logger.warning("Error processing entry: %s", e)
                logger.debug("Entry data: %s", entry)
                continue
            self.statements.append(stmt)


def default_grounder_wrapper(text, context=None):
    """Return the grounding produced by INDRA's Gilda get_grounding.

    Parameters
    ----------
    text : str
        The entity text to ground.
    context : Optional[str]
        Context text passed along to get_grounding.

    Returns
    -------
    The first element of the pair returned by get_grounding, called with
    mode='local'.
    """
    # Import here to avoid this when working in INDRA World context
    from indra.preassembler.grounding_mapper.gilda import get_grounding
    grounding, _ = get_grounding(text, context=context, mode='local')
    return grounding
'ubibrowser', 'acsn', 'wormbase'] """Database source names as they appear in the DB""" -reader_sources = ['geneways', 'tees', 'gnbr', 'semrep', 'isi', 'trips', - 'rlimsp', 'medscan', 'eidos', 'sparser', 'reach'] +reader_sources = ['geneways', 'tees', 'gnbr', 'semrep', 'indra_bert', + 'isi', 'trips', 'rlimsp', 'medscan', 'eidos', 'sparser', + 'reach'] """Reader source names as they appear in the DB""" # These are mappings where the actual INDRA source, as it appears diff --git a/setup.py b/setup.py index b8d3208414..08b00ac8b0 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ def main(): 'trips_offline': ['pykqml'], 'reach_offline': ['cython<3', 'pyjnius==1.1.4'], 'eidos_offline': ['cython<3', 'pyjnius==1.1.4'], + 'indra_bert': ['indra_bert'], 'hypothesis': ['gilda>1.0.0'], 'geneways': ['stemming', 'nltk<3.6'], 'bel': ['pybel>=0.15.0,<0.16.0'], @@ -90,6 +91,7 @@ def main(): 'indra.sources.geneways', 'indra.sources.gnbr', 'indra.sources.hprd', 'indra.sources.hypothesis', 'indra.sources.index_cards', + 'indra.sources.indra_bert', 'indra.sources.indra_db_rest', 'indra.sources.isi', 'indra.sources.minerva', 'indra.sources.ndex_cx', 'indra.sources.reach', 'indra.sources.omnipath',