diff --git a/examples/conv1d.py b/examples/conv1d.py new file mode 100644 index 0000000..c6dd552 --- /dev/null +++ b/examples/conv1d.py @@ -0,0 +1,21 @@ +import sys, os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +from manas_cafa5.model import Model +from manas_cafa5.protein import Protein +from manas_cafa5.protein_cache import ProteinCache + +file = 'data/go_terms_train_set_maxlen500_minmembers75.tsv' + +m = Model.cnn1d([],1024) +trainingset = Model.parse_trainingset(file) +graph = Protein.build_graph('data/go-basic.obo') + +protein_cache = ProteinCache('data') +loaded = m.compile(trainingset, graph, 'P68510', protein_cache) + +checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join('ch', "ckpt_{epoch}"), + save_weights_only=True, +) +loaded.fit(epochs=5, callbacks=[checkpoint_callback]) diff --git a/manas_cafa5/model.py b/manas_cafa5/model.py new file mode 100644 index 0000000..013e35f --- /dev/null +++ b/manas_cafa5/model.py @@ -0,0 +1,68 @@ +import tensorflow as tf +import numpy as np +from tensorflow.keras.preprocessing.sequence import pad_sequences +from random import sample + +class Model: + def __init__(self, model, max_protein_length): + self.model = model + self.max_protein_length = max_protein_length + + @staticmethod + def cnn1d(layers, max_protein_length=500): + model = tf.keras.Sequential() + model.add(tf.keras.layers.Conv1D(32, kernel_size=3, activation='relu', input_shape=(max_protein_length,20))) + for layer in layers: + model.add(layer) + model.add(tf.keras.layers.Dense(2, activation='softmax')) # output layer + return Model(model, max_protein_length) + + # todo: subclass tf.keras.Model + def call(self, inputs): + return self.model(inputs) + + @staticmethod + def parse_trainingset(file): + terms = {} + with open(file, 'r', encoding='utf-8') as f: + for line in f: + [go_term,associated_uniprots] = line.rstrip().split('\t') + terms[go_term] = associated_uniprots.split(',') + return terms + + def compile(self, trainingset, graph, protein_uniprot_id, protein_cache, max_distance=1): + protein = protein_cache.load(protein_uniprot_id) + children = protein.go_terms_children(graph, max_distance) + + uniprot_ids = trainingset[protein_uniprot_id] + proteins = { uid: protein_cache.load(uid) for uid in uniprot_ids } + + x_train = [ + pad_sequences( + protein.one_hot_sequence(), + maxlen=self.max_protein_length, + padding='post', + truncating='post' + ) + for uid in uniprot_ids + ] + y_train = [ uid in children for uid in uniprot_ids ] + return LoadedModel( + model=self.model, + inputs=x_train, + outputs=y_train, + ) + +class LoadedModel: + def __init__(self, model, inputs, outputs): + self.model = model + self.inputs = inputs + self.outputs = outputs + + def fit(self, **kwargs): + kwargs = kwargs.copy() + kwargs.update({ 'inputs': self.inputs, 'outputs': self.outputs }) + return self.model.fit(**kwargs) + +#def rnn1d(x_train, y_train, x_val, y_val, layers, max_protein_length=500): +# pass diff --git a/manas_cafa5/protein.py b/manas_cafa5/protein.py index e7b9b13..874764a 100644 --- a/manas_cafa5/protein.py +++ b/manas_cafa5/protein.py @@ -1,3 +1,4 @@ +from .structure import Structure, STRUCTURE_TERMS import xml.parsers.expat as xml_parser import requests import re @@ -9,10 +10,34 @@ AMINO_ACID_INDEX = { a: AMINO_ACID_LIST.find(a) for a in AMINO_ACID_LIST } class Protein: - def __init__(self, name): + def __init__(self, name, **kwargs): self.name = name self.terms = None self.sequence = None + self.structures = None + if kwargs.get('autoload') != False: + self.load_uniprot() + + def __repr__(self): + return f'' + + @staticmethod + def from_file(name, file): + protein = Protein(name, autoload=False) + protein.load_file(file) + return protein + + @staticmethod + def from_url(name, url): + protein = Protein(name, autoload=False) + protein.load_url(url) + return protein + + @staticmethod + def from_data(name, data): + protein = Protein(name, autoload=False) + protein._apply_parsed(Protein.parse_xml(data)) + return protein def load_uniprot(self): self.load_url(f'https://rest.uniprot.org/uniprotkb/{self.name}.xml') @@ -27,6 +52,10 @@ def load_file(self, file): def _apply_parsed(self, parsed): self.terms = parsed.get('terms') + self.structures = { + key: [ Structure(term) for term in self.terms.get(key) ] + for key in STRUCTURE_TERMS + } self.sequence = parsed.get('sequence') def _fetch_xml_url(self, url): @@ -41,22 +70,42 @@ def _fetch_xml_url(self, url): f'received: {r.headers["content-type"]}')) return r.text - def go_terms(self): - if self.terms == None: - self.load_uniprot() - return self.terms.get('go') + def get_structure_types(self): + return list(self.structures.keys()) - def go_terms_children(self, graph, max_distance): + def get_structures(self, structure_type): + return self.structures.get(structure_type.lower()) or [] + + def get_term_types(self): + return list(self.terms.keys()) + + def get_terms(self, term_type): + return self.terms.get(term_type.lower()) or [] + + def get_children(self, term_type, graph, max_distance): term_set = set() for dist in range(1,max_distance+1): term_set = reduce( lambda terms, term: terms.union( networkx.descendants_at_distance(graph, term['id'], dist) ), - self.go_terms(), + self.get_terms(term_type), term_set ) - return list(term_set) + return [ + { + 'type': 'go', + 'id': term_id, + 'properties': {}, + } + for term_id in term_set + ] + + def go_terms(self): + return self.get_terms('go') + + def go_terms_children(self, graph, max_distance): + return self.get_children('go', graph, max_distance) @staticmethod def build_graph(url_or_file): @@ -75,25 +124,32 @@ def one_hot_sequence(self): def parse_xml(xml_data): cursor = { - 'terms': { 'go': [] }, + 'terms': {}, 'dbref': None, 'current_name': None, 'sequence': None, } def start_element(cursor, name, attrs): + name = name.lower() atype = attrs.get('type') + atype = atype and atype.lower() dbref = cursor.get('dbref') cursor['current_name'] = name - if dbref != None and name == 'property' and atype != None: + if dbref is not None and name == 'property' and atype is not None: dbref['properties'][atype] = attrs.get('value') - if name == 'dbReference' and atype == 'GO': + if name == 'dbreference' and atype is not None: dbref = { + 'type': atype, 'id': attrs.get('id'), 'properties': {}, } cursor['dbref'] = dbref - cursor['terms']['go'].append(dbref) + terms = cursor['terms'].get(atype) + if terms is None: + terms = [] + cursor['terms'][atype] = terms + terms.append(dbref) def end_element(cursor, name): cursor['current_name'] = None diff --git a/manas_cafa5/protein_cache.py b/manas_cafa5/protein_cache.py new file mode 100644 index 0000000..8c3e50c --- /dev/null +++ b/manas_cafa5/protein_cache.py @@ -0,0 +1,17 @@ +from .protein import Protein +import os, io + +class ProteinCache: + def __init__(self, cache_dir): + self.cache_dir = cache_dir + + def load(self, name, **kwargs): + file = os.path.join(self.cache_dir, name + '.xml') + if os.path.exists(file): + return Protein.from_file(name, file) + protein = Protein(name, autoload=False) + data = protein._fetch_xml_url(f'https://rest.uniprot.org/uniprotkb/{name}.xml') + wh = io.open(file,'w') + wh.write(data) + wh.close() + return Protein.from_data(name, data) diff --git a/manas_cafa5/structure.py b/manas_cafa5/structure.py index db3ed8c..f4aeb97 100644 --- a/manas_cafa5/structure.py +++ b/manas_cafa5/structure.py @@ -5,10 +5,26 @@ from .utils import fetch_url import numpy as np +STRUCTURE_TERMS = set([ 'pdb', 'alphafolddb' ]) + class Structure: - def __init__(self): + def __init__(self, term): + self.type = term.get('type') + self.id = term.get('id') + self.properties = term.get('properties') self.structure = None + def __repr__(self): + return f'' + + def load(self): + if self.type == 'pdb': + self.load_url(f'ftp://ftp.wwpdb.org/pub/pdb/data/structures/all/pdb/pdb{self.id}.ent.gz') + elif self.type == 'alphafolddb': + self.load_url(f'https://alphafold.ebi.ac.uk/files/AF-{self.id}-F1-model_v4.pdb') + else: + raise RuntimeError(f'default endpoint for type {self.type} not available') + def load_file(self, file, id=None): data = None if file.split('.')[-1] == 'gz': @@ -34,12 +50,13 @@ def load_url(self, url, id=None): data = str(data, 'utf-8') self.structure = parser.get_structure(id, StringIO(data)) - def load_pdb(self, name): - self.load_url(f'ftp://ftp.wwpdb.org//pub/pdb/data/structures/all/pdb/pdb{name}.ent.gz') - # https://warwick.ac.uk/fac/sci/moac/people/students/peter_cock/python/protein_contact_map/ - def contact_map(self, chain_one, chain_two, threshold): - model = self.structure[0] + def contact_map(self, chain_one, chain_two, threshold, index=0): + if self.structure is None: + raise RuntimeError(f'structure for type {self.type} not loaded. call load() first') + if len(self.structure) == 0: + raise RuntimeError('no structures available') + model = self.structure[index] c1 = model[chain_one] c2 = model[chain_two] cmap = np.zeros((len(c1), len(c2)), np.float) diff --git a/setup.py b/setup.py index c68c134..04bd66d 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ 'tensorflow >=2.12.0, <3.0.0', 'networkx >=3.1, <4.0', 'obonet >=1.0.0, <2.0.0', + 'scikit-learn >=1.2.2, <2.0.0', ], description="manas ML model for cafa5 contest", url="https://github.com/manastech/cafa5",