21 changes: 21 additions & 0 deletions examples/conv1d.py
@@ -0,0 +1,21 @@
import sys, os
import tensorflow as tf
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from manas_cafa5.model import Model
from manas_cafa5.protein import Protein
from manas_cafa5.protein_cache import ProteinCache

file = 'data/go_terms_train_set_maxlen500_minmembers75.tsv'

m = Model.cnn1d([],1024)
trainingset = Model.parse_trainingset(file)
graph = Protein.build_graph('data/go-basic.obo')

protein_cache = ProteinCache('data')
loaded = m.compile(trainingset, graph, 'P68510', protein_cache)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join('ch', "ckpt_{epoch}"),
save_weights_only=True,
)
loaded.fit(epochs=5, callbacks=[checkpoint_callback])
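To resume training from the newest checkpoint, the saved weights can be restored before calling fit again. A minimal sketch, assuming the TensorFlow checkpoint format that ModelCheckpoint writes when save_weights_only=True:

# hypothetical resume step, run before a later fit
latest = tf.train.latest_checkpoint('ch')
if latest is not None:
    loaded.model.load_weights(latest)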
68 changes: 68 additions & 0 deletions manas_cafa5/model.py
@@ -0,0 +1,68 @@
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from random import sample

class Model:
def __init__(self, model, max_protein_length):
self.model = model
self.max_protein_length = max_protein_length

@staticmethod
def cnn1d(layers, max_protein_length=500):
model = tf.keras.Sequential()
        model.add(tf.keras.layers.Conv1D(32, kernel_size=3, activation='relu', input_shape=(max_protein_length,20)))
        for layer in layers:
            model.add(layer)
        model.add(tf.keras.layers.GlobalMaxPooling1D()) # collapse per-residue features so the model emits one prediction per protein
        model.add(tf.keras.layers.Dense(2, activation='softmax')) # output layer
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
return Model(model, max_protein_length)

# todo: subclass tf.keras.Model
def call(self, inputs):
return self.model(inputs)

@staticmethod
def parse_trainingset(file):
terms = {}
with open(file, 'r', encoding='utf-8') as f:
for line in f:
[go_term,associated_uniprots] = line.rstrip().split('\t')
terms[go_term] = associated_uniprots.split(',')
return terms

def compile(self, trainingset, graph, protein_uniprot_id, protein_cache, max_distance=1):
protein = protein_cache.load(protein_uniprot_id)
children = protein.go_terms_children(graph, max_distance)

uniprot_ids = trainingset[protein_uniprot_id]
proteins = { uid: protein_cache.load(uid) for uid in uniprot_ids }

        x_train = [
            # pad/truncate each (length, 20) one-hot matrix to (max_protein_length, 20)
            pad_sequences(
                [ proteins[uid].one_hot_sequence() ],
                maxlen=self.max_protein_length,
                padding='post',
                truncating='post'
            )[0]
            for uid in uniprot_ids
        ]
        child_ids = { term['id'] for term in children } # get_children returns term dicts; membership is against their ids
        y_train = [ uid in child_ids for uid in uniprot_ids ]
return LoadedModel(
model=self.model,
inputs=x_train,
outputs=y_train,
)

class LoadedModel:
def __init__(self, model, inputs, outputs):
self.model = model
self.inputs = inputs
self.outputs = outputs

    def fit(self, **kwargs):
        # keras Model.fit expects x and y, not 'inputs'/'outputs' kwargs
        x = np.asarray(self.inputs, dtype='float32')
        y = np.asarray(self.outputs, dtype='int32')
        return self.model.fit(x, y, **kwargs)

#def rnn1d(x_train, y_train, x_val, y_val, layers, max_protein_length=500):
# pass
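As a smoke test, the expected tensor shapes can be checked against random data before wiring up the real training set. A minimal sketch, with all sizes hypothetical:

# 8 proteins, padded to 500 residues, 20-way one-hot; one class index per protein
import numpy as np
from manas_cafa5.model import Model

m = Model.cnn1d([], 500)
x = np.zeros((8, 500, 20), dtype='float32')
y = np.zeros((8,), dtype='int32')
m.model.fit(x, y, epochs=1)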
80 changes: 68 additions & 12 deletions manas_cafa5/protein.py
@@ -1,3 +1,4 @@
from .structure import Structure, STRUCTURE_TERMS
import xml.parsers.expat as xml_parser
import requests
import re
@@ -9,10 +10,34 @@
AMINO_ACID_INDEX = { a: AMINO_ACID_LIST.find(a) for a in AMINO_ACID_LIST }

class Protein:
def __init__(self, name):
def __init__(self, name, **kwargs):
self.name = name
self.terms = None
self.sequence = None
self.structures = None
        if kwargs.get('autoload') is not False:
self.load_uniprot()

def __repr__(self):
return f'<manas_cafa5.Protein name={self.name}>'

@staticmethod
def from_file(name, file):
protein = Protein(name, autoload=False)
protein.load_file(file)
return protein

@staticmethod
def from_url(name, url):
protein = Protein(name, autoload=False)
protein.load_url(url)
return protein

@staticmethod
def from_data(name, data):
protein = Protein(name, autoload=False)
protein._apply_parsed(Protein.parse_xml(data))
return protein

def load_uniprot(self):
self.load_url(f'https://rest.uniprot.org/uniprotkb/{self.name}.xml')
@@ -27,6 +52,10 @@ def load_file(self, file):

def _apply_parsed(self, parsed):
self.terms = parsed.get('terms')
        self.structures = {
            key: [ Structure(term) for term in (self.terms.get(key) or []) ]
            for key in STRUCTURE_TERMS
        }
self.sequence = parsed.get('sequence')

def _fetch_xml_url(self, url):
@@ -41,22 +70,42 @@ def _fetch_xml_url(self, url):
f'received: {r.headers["content-type"]}'))
return r.text

def go_terms(self):
if self.terms == None:
self.load_uniprot()
return self.terms.get('go')
def get_structure_types(self):
return list(self.structures.keys())

def go_terms_children(self, graph, max_distance):
def get_structures(self, structure_type):
return self.structures.get(structure_type.lower()) or []

def get_term_types(self):
return list(self.terms.keys())

def get_terms(self, term_type):
return self.terms.get(term_type.lower()) or []

def get_children(self, term_type, graph, max_distance):
term_set = set()
for dist in range(1,max_distance+1):
term_set = reduce(
lambda terms, term: terms.union(
networkx.descendants_at_distance(graph, term['id'], dist)
),
self.go_terms(),
self.get_terms(term_type),
term_set
)
return list(term_set)
return [
{
'type': 'go',
'id': term_id,
'properties': {},
}
for term_id in term_set
]

def go_terms(self):
return self.get_terms('go')

def go_terms_children(self, graph, max_distance):
return self.get_children('go', graph, max_distance)

@staticmethod
def build_graph(url_or_file):
@@ -75,25 +124,32 @@ def one_hot_sequence(self):

def parse_xml(xml_data):
cursor = {
'terms': { 'go': [] },
'terms': {},
'dbref': None,
'current_name': None,
'sequence': None,
}

def start_element(cursor, name, attrs):
name = name.lower()
atype = attrs.get('type')
atype = atype and atype.lower()
dbref = cursor.get('dbref')
cursor['current_name'] = name
if dbref != None and name == 'property' and atype != None:
if dbref is not None and name == 'property' and atype is not None:
dbref['properties'][atype] = attrs.get('value')
if name == 'dbReference' and atype == 'GO':
if name == 'dbreference' and atype is not None:
dbref = {
'type': atype,
'id': attrs.get('id'),
'properties': {},
}
cursor['dbref'] = dbref
cursor['terms']['go'].append(dbref)
terms = cursor['terms'].get(atype)
if terms is None:
terms = []
cursor['terms'][atype] = terms
terms.append(dbref)

def end_element(cursor, name):
cursor['current_name'] = None
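A usage sketch of the reworked accessors. Term-type keys besides 'go' depend on which dbReference types appear in the UniProt record, so 'pdb' below is illustrative:

from manas_cafa5.protein import Protein

protein = Protein('P68510')               # autoloads from UniProt by default
print(protein.get_term_types())           # term types found in the record, e.g. ['go', 'pdb', ...]
go_terms = protein.get_terms('go')        # list of {'type', 'id', 'properties'} dicts
pdb_structures = protein.get_structures('pdb')

graph = Protein.build_graph('data/go-basic.obo')
children = protein.get_children('go', graph, max_distance=2)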
17 changes: 17 additions & 0 deletions manas_cafa5/protein_cache.py
@@ -0,0 +1,17 @@
from .protein import Protein
import os

class ProteinCache:
    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True) # ensure the cache directory exists before any writes

def load(self, name, **kwargs):
file = os.path.join(self.cache_dir, name + '.xml')
if os.path.exists(file):
return Protein.from_file(name, file)
protein = Protein(name, autoload=False)
data = protein._fetch_xml_url(f'https://rest.uniprot.org/uniprotkb/{name}.xml')
        with open(file, 'w', encoding='utf-8') as wh:
            wh.write(data)
return Protein.from_data(name, data)
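A brief usage sketch: the first call fetches P68510 from UniProt and writes data/P68510.xml; later calls read the cached file instead of hitting the network:

from manas_cafa5.protein_cache import ProteinCache

cache = ProteinCache('data')
protein = cache.load('P68510')   # network fetch on a cold cache, local read afterwards
print(protein.go_terms())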
29 changes: 23 additions & 6 deletions manas_cafa5/structure.py
@@ -5,10 +5,26 @@
from .utils import fetch_url
import numpy as np

STRUCTURE_TERMS = { 'pdb', 'alphafolddb' }

class Structure:
def __init__(self):
def __init__(self, term):
self.type = term.get('type')
self.id = term.get('id')
self.properties = term.get('properties')
self.structure = None

def __repr__(self):
return f'<manas_cafa5.Structure id={self.id} type={self.type} properties={self.properties}>'

def load(self):
if self.type == 'pdb':
self.load_url(f'ftp://ftp.wwpdb.org/pub/pdb/data/structures/all/pdb/pdb{self.id}.ent.gz')
elif self.type == 'alphafolddb':
self.load_url(f'https://alphafold.ebi.ac.uk/files/AF-{self.id}-F1-model_v4.pdb')
else:
raise RuntimeError(f'default endpoint for type {self.type} not available')

def load_file(self, file, id=None):
data = None
if file.split('.')[-1] == 'gz':
@@ -34,12 +50,13 @@ def load_url(self, url, id=None):
data = str(data, 'utf-8')
self.structure = parser.get_structure(id, StringIO(data))

def load_pdb(self, name):
self.load_url(f'ftp://ftp.wwpdb.org//pub/pdb/data/structures/all/pdb/pdb{name}.ent.gz')

# https://warwick.ac.uk/fac/sci/moac/people/students/peter_cock/python/protein_contact_map/
def contact_map(self, chain_one, chain_two, threshold):
model = self.structure[0]
def contact_map(self, chain_one, chain_two, threshold, index=0):
if self.structure is None:
raise RuntimeError(f'structure for type {self.type} not loaded. call load() first')
if len(self.structure) == 0:
raise RuntimeError('no structures available')
model = self.structure[index]
c1 = model[chain_one]
c2 = model[chain_two]
        cmap = np.zeros((len(c1), len(c2)), np.float64)
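A usage sketch tying Protein and Structure together. The chain ids and the 8.0 threshold are illustrative; real values depend on the PDB entry:

from manas_cafa5.protein import Protein

protein = Protein('P68510')
structures = protein.get_structures('alphafolddb')
if structures:
    s = structures[0]
    s.load()                                       # resolves the download URL from s.type
    cmap = s.contact_map('A', 'A', threshold=8.0)  # hypothetical chains/threshold
    print(cmap.shape)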
1 change: 1 addition & 0 deletions setup.py
@@ -15,6 +15,7 @@
'tensorflow >=2.12.0, <3.0.0',
'networkx >=3.1, <4.0',
'obonet >=1.0.0, <2.0.0',
'scikit-learn >=1.2.2, <2.0.0',
],
description="manas ML model for cafa5 contest",
url="https://github.com/manastech/cafa5",