#!/usr/bin/env python
import click as ck
import numpy as np
import pandas as pd
from keras.models import load_model

from aaindex import INVALID_ACIDS

MAXLEN = 1002


@ck.command()
@ck.option('--in-file', '-i', help='Input FASTA file', required=True)
@ck.option('--threshold', '-t', default=0.3, help='Prediction threshold')
@ck.option('--batch-size', '-bs', default=1, help='Batch size for prediction model')
@ck.option('--include-long-seq', '-ils', is_flag=True, help='Include long sequences')
@ck.option('--onto', '-go', default='bp', help='Sub-ontology to be predicted: bp, mf, cc')
def main(in_file, threshold, batch_size, include_long_seq, onto):
    out_file = onto + '_results.txt'
    chunk_size = 1000

    global model
    global functions
    global vocab
    global gram_len

    # Build the n-gram vocabulary; index 0 is reserved for padding.
    ngram_df = pd.read_pickle('data/models/ngrams.pkl')
    vocab = {}
    for key, gram in enumerate(ngram_df['ngrams']):
        vocab[gram] = key + 1
    gram_len = len(ngram_df['ngrams'][0])
    print('Gram length:', gram_len)
    print('Vocabulary size:', len(vocab))

    # Load the pretrained model and the GO functions for the chosen sub-ontology.
    model = load_model('data/models/model_%s.h5' % onto)
    df = pd.read_pickle('data/models/%s.pkl' % onto)
    functions = df['functions']

    with open(out_file, 'w') as w:
        for ids, sequences in read_fasta(in_file, chunk_size, include_long_seq):
            data = get_data(sequences)
            results = predict(data, model, threshold, batch_size)
            for i in range(len(ids)):
                w.write(ids[i])
                w.write('\n')
                for res in results[i]:
                    w.write(res)
                    w.write('\n')


def is_ok(seq):
    """Return True if the sequence contains no ambiguous amino acids."""
    for c in seq:
        if c in INVALID_ACIDS:
            return False
    return True


def read_fasta(filename, chunk_size, include_long_seq):
    """Yield (ids, sequences) chunks of at most chunk_size valid sequences."""
    seqs = list()
    info = list()
    seq = ''
    inf = ''

    def add_seq(inf, seq):
        # Apply the same filters to every record, including the last one in the file.
        if not is_ok(seq):
            print('Ignoring sequence {} because of ambiguous AA'.format(inf))
            return
        if not include_long_seq and len(seq) > MAXLEN:
            print('Ignoring sequence {} because its length > {}'.format(inf, MAXLEN))
            return
        seqs.append(seq)
        info.append(inf)

    with open(filename) as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                if seq != '':
                    add_seq(inf, seq)
                    if len(info) == chunk_size:
                        yield (info, seqs)
                        seqs = list()
                        info = list()
                seq = ''
                inf = line[1:].split()[0]
            else:
                seq += line
    # Flush the final record and yield the last (possibly partial) chunk.
    if seq != '':
        add_seq(inf, seq)
    if info:
        yield (info, seqs)


def get_data(sequences):
    """Encode sequences as zero-padded rows of n-gram vocabulary indices."""
    n = len(sequences)
    data = np.zeros((n, MAXLEN - gram_len + 1), dtype=np.float32)
    for i in range(n):
        seq = sequences[i]
        # Sequences longer than MAXLEN are truncated.
        for j in range(min(MAXLEN, len(seq)) - gram_len + 1):
            data[i, j] = vocab[seq[j: (j + gram_len)]]
    return data


def predict(data, model, threshold, batch_size):
    """Return, per sequence, the functions whose score meets the threshold."""
    n = data.shape[0]
    result = list()
    for i in range(n):
        result.append(list())
    predictions = model.predict(data, batch_size=batch_size, verbose=1)
    for i in range(n):
        pred = (predictions[i] >= threshold).astype('int32')
        for j in range(len(functions)):
            if pred[j] == 1:
                result[i].append(functions[j] + ' with score ' + '%.2f' % predictions[i][j])
    return result


if __name__ == '__main__':
    main()
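
# Example invocation (a sketch: the script name "predict.py" and the FASTA file
# name are assumptions; the pickled n-grams and the model files under
# data/models/ must exist for the chosen sub-ontology):
#
#   python predict.py -i sequences.fasta -t 0.3 -bs 32 -go mf
#
# Output goes to <onto>_results.txt, one block per protein: the FASTA id on its
# own line, followed by one "<function> with score 0.NN" line for every
# function whose predicted score meets the threshold.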