annotator.py

#!/usr/bin/env python
import click as ck
import numpy as np
import pandas as pd
from keras.models import load_model

from aaindex import INVALID_ACIDS

# Longest amino-acid sequence the model accepts (1000 trigram positions).
MAXLEN = 1002


@ck.command()
@ck.option('--in-file', '-i', help='Input FASTA file', required=True)
@ck.option('--threshold', '-t', default=0.3, help='Prediction threshold')
@ck.option('--batch-size', '-bs', default=1, help='Batch size for prediction model')
@ck.option('--include-long-seq', '-ils', is_flag=True, help='Include long sequences')
@ck.option('--onto', '-go', default='bp', help='Sub-ontology to be predicted: bp, mf, cc')
def main(in_file, threshold, batch_size, include_long_seq, onto):
    out_file = onto + '_results.txt'
    chunk_size = 1000
    global model
    global functions
    global vocab
    global gram_len
    # Build the n-gram vocabulary; index 0 is left free as the padding value.
    ngram_df = pd.read_pickle('data/models/ngrams.pkl')
    vocab = {}
    for key, gram in enumerate(ngram_df['ngrams']):
        vocab[gram] = key + 1
    gram_len = len(ngram_df['ngrams'][0])
    print('Gram length:', gram_len)
    print('Vocabulary size:', len(vocab))
    # Load the trained model and the GO function list for the chosen sub-ontology.
    model = load_model('Data/data1/models/model_%s.h5' % onto)
    df = pd.read_pickle('Data/data1/models/%s.pkl' % onto)
    functions = df['functions']
    with open(out_file, 'w') as w:
        for ids, sequences in read_fasta(in_file, chunk_size, include_long_seq):
            data = get_data(sequences)
            results = predict(data, model, threshold, batch_size)
            for i in range(len(ids)):
                w.write(ids[i])
                w.write('\n')
                for res in results[i]:
                    w.write(res)
                    w.write('\n')


def is_ok(seq):
    # A sequence is usable only if it contains no ambiguous/invalid amino acids.
    for c in seq:
        if c in INVALID_ACIDS:
            return False
    return True


def read_fasta(filename, chunk_size, include_long_seq):
    """Yield (ids, sequences) chunks of up to chunk_size records from a FASTA file."""
    seqs = list()
    info = list()
    seq = ''
    inf = ''
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                if seq != '':
                    if not is_ok(seq):
                        print('Ignoring sequence {} because of ambiguous AA'.format(inf))
                    elif include_long_seq or len(seq) <= MAXLEN:
                        seqs.append(seq)
                        info.append(inf)
                        if len(info) == chunk_size:
                            yield (info, seqs)
                            seqs = list()
                            info = list()
                    else:
                        print('Ignoring sequence {} because its length > {}'.format(inf, MAXLEN))
                seq = ''
                inf = line[1:].split()[0]
            else:
                seq += line
    # Flush the final record, applying the same filters as above.
    if seq != '':
        if not is_ok(seq):
            print('Ignoring sequence {} because of ambiguous AA'.format(inf))
        elif include_long_seq or len(seq) <= MAXLEN:
            seqs.append(seq)
            info.append(inf)
        else:
            print('Ignoring sequence {} because its length > {}'.format(inf, MAXLEN))
    yield (info, seqs)


def get_data(sequences):
    # Encode each sequence as a row of overlapping n-gram vocabulary indices,
    # truncated to MAXLEN residues; unused tail positions stay zero (padding).
    n = len(sequences)
    data = np.zeros((n, MAXLEN - gram_len + 1), dtype=np.float32)
    for i in range(n):
        seq = sequences[i]
        for j in range(min(MAXLEN, len(seq)) - gram_len + 1):
            data[i, j] = vocab[seq[j: (j + gram_len)]]
    return data


def predict(data, model, threshold, batch_size):
    n = data.shape[0]
    result = [list() for _ in range(n)]
    predictions = model.predict(data, batch_size=batch_size, verbose=1)
    for i in range(n):
        # Report every function whose predicted score clears the threshold.
        pred = (predictions[i] >= threshold).astype('int32')
        for j in range(len(functions)):
            if pred[j] == 1:
                result[i].append(functions[j] + ' with score ' + '%.2f' % predictions[i][j])
    return result


if __name__ == '__main__':
    main()
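
Usage (for reference): a minimal example invocation built from the click options defined above. The input file name is a placeholder, and the pickled n-gram table and trained model for the chosen sub-ontology are assumed to sit under the paths hard-coded in main().

    python annotator.py -i sequences.fasta -t 0.3 -go mf

This writes one block per protein to mf_results.txt: the FASTA ID on its own line, followed by each predicted function that clears the threshold, as lines of the form '<function> with score <score>'.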