
utils.py


# These functions are used to manipulate the first dataset from DeepGO and are
# taken from https://github.com/bio-ontology-research-group/deepgo.
from collections import deque
from xml.etree import ElementTree as ET
import warnings

import pandas as pd
from keras import backend as K
from keras.callbacks import ModelCheckpoint

BIOLOGICAL_PROCESS = 'GO:0008150'
MOLECULAR_FUNCTION = 'GO:0003674'
CELLULAR_COMPONENT = 'GO:0005575'
FUNC_DICT = {
    'cc': CELLULAR_COMPONENT,
    'mf': MOLECULAR_FUNCTION,
    'bp': BIOLOGICAL_PROCESS}
EXP_CODES = set(['EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC'])


def get_ipro():
    # Parse the InterPro XML dump into a dict keyed by InterPro id, recording
    # each entry's name plus its parent and child entries.
    ipro = dict()
    tree = ET.parse('data/interpro.xml')
    root = tree.getroot()
    for child in root:
        if child.tag != 'interpro':
            continue
        ipro_id = child.attrib['id']
        name = child.find('name').text
        ipro[ipro_id] = {
            'id': ipro_id,
            'name': name,
            'children': list(), 'parents': list()}
        # find() returns None when the element is absent; truth-testing an
        # Element is deprecated, so compare against None explicitly.
        parents = child.find('parent_list')
        if parents is not None:
            for parent in parents:
                ipro[ipro_id]['parents'].append(parent.attrib['ipr_ref'])
        children = child.find('child_list')
        if children is not None:
            for ch in children:
                ipro[ipro_id]['children'].append(ch.attrib['ipr_ref'])
    return ipro


def get_ipro_anchestors(ipro, ipro_id):
    # Breadth-first walk up the parent links: returns ipro_id together with
    # all of its ancestors.
    ipro_set = set()
    q = deque()
    q.append(ipro_id)
    while len(q) > 0:
        i_id = q.popleft()
        ipro_set.add(i_id)
        for parent_id in ipro[i_id]['parents']:
            if parent_id in ipro:
                q.append(parent_id)
    return ipro_set
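

# Usage sketch (illustrative, not part of the original module): build the
# InterPro index, then collect one entry together with all of its ancestors.
# 'IPR000719' is a placeholder id; substitute any id present in the XML dump.
def _example_ipro_usage():
    ipro = get_ipro()  # expects data/interpro.xml to exist
    return get_ipro_anchestors(ipro, 'IPR000719')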


def get_gene_ontology(filename='go.obo'):
    # Reading Gene Ontology from OBO formatted file.
    go = dict()
    obj = None
    with open('data/' + filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line == '[Term]':
                if obj is not None:
                    go[obj['id']] = obj
                obj = dict()
                obj['is_a'] = list()
                obj['part_of'] = list()
                obj['regulates'] = list()
                obj['is_obsolete'] = False
                continue
            elif line == '[Typedef]':
                obj = None
            else:
                if obj is None:
                    continue
                l = line.split(": ")
                if l[0] == 'id':
                    obj['id'] = l[1]
                elif l[0] == 'is_a':
                    obj['is_a'].append(l[1].split(' ! ')[0])
                elif l[0] == 'name':
                    obj['name'] = l[1]
                elif l[0] == 'is_obsolete' and l[1] == 'true':
                    obj['is_obsolete'] = True
    if obj is not None:
        go[obj['id']] = obj
    # Materialize the keys first: deleting from a dict while iterating over it
    # raises a RuntimeError in Python 3.
    for go_id in list(go.keys()):
        if go[go_id]['is_obsolete']:
            del go[go_id]
    # iteritems() is Python 2 only; items() works in both.
    for go_id, val in go.items():
        if 'children' not in val:
            val['children'] = set()
        for p_id in val['is_a']:
            if p_id in go:
                if 'children' not in go[p_id]:
                    go[p_id]['children'] = set()
                go[p_id]['children'].add(go_id)
    return go
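

# Usage sketch (illustrative): load the ontology and look up the root term of
# one sub-ontology via FUNC_DICT. Expects data/go.obo to exist.
def _example_load_go():
    go = get_gene_ontology('go.obo')
    bp_root = go[FUNC_DICT['bp']]
    return bp_root['name'], len(bp_root['children'])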


def get_anchestors(go, go_id):
    # Breadth-first walk up the 'is_a' links: returns go_id together with all
    # of its ancestors.
    go_set = set()
    q = deque()
    q.append(go_id)
    while len(q) > 0:
        g_id = q.popleft()
        go_set.add(g_id)
        for parent_id in go[g_id]['is_a']:
            if parent_id in go:
                q.append(parent_id)
    return go_set


def get_parents(go, go_id):
    # Direct 'is_a' parents only (no transitive ancestors).
    go_set = set()
    for parent_id in go[go_id]['is_a']:
        if parent_id in go:
            go_set.add(parent_id)
    return go_set


def get_height(go, go_id):
    # Length of the shortest 'is_a' path from go_id up to a root term.
    height_min = 100000
    if len(go[go_id]['is_a']) == 0:
        height_min = 0
    else:
        for parent_id in go[go_id]['is_a']:
            if parent_id in go:
                height = get_height(go, parent_id) + 1
                if height < height_min:
                    height_min = height
    return height_min


def get_go_set(go, go_id):
    # Breadth-first walk down the 'children' links: returns go_id together
    # with all of its descendants.
    go_set = set()
    q = deque()
    q.append(go_id)
    while len(q) > 0:
        g_id = q.popleft()
        go_set.add(g_id)
        for ch_id in go[g_id]['children']:
            q.append(ch_id)
    return go_set
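

# Usage sketch (illustrative): the traversal helpers above are complementary.
# get_anchestors walks up 'is_a' links, get_go_set walks down 'children'
# links, and get_height measures the shortest path to a root.
def _example_go_traversals(go, go_id):
    up = get_anchestors(go, go_id)    # go_id plus every ancestor
    down = get_go_set(go, go_id)      # go_id plus every descendant
    depth = get_height(go, go_id)     # shortest 'is_a' distance to a root
    return up, down, depth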


def save_model_weights(model, filepath):
    # Dump every layer's weights into a pickled DataFrame keyed by layer name.
    if hasattr(model, 'flattened_layers'):
        # Support for legacy Sequential/Merge behavior.
        flattened_layers = model.flattened_layers
    else:
        flattened_layers = model.layers
    l_names = []
    w_values = []
    for layer in flattened_layers:
        layer_name = layer.name
        symbolic_weights = layer.weights
        weight_values = K.batch_get_value(symbolic_weights)
        if weight_values:
            l_names.append(layer_name)
            w_values.append(weight_values)
    df = pd.DataFrame({
        'layer_names': l_names,
        'weight_values': w_values})
    df.to_pickle(filepath)


def load_model_weights(model, filepath):
    '''Name-based weight loading.
    Layers that have no matching name are skipped.
    '''
    if hasattr(model, 'flattened_layers'):
        # Support for legacy Sequential/Merge behavior.
        flattened_layers = model.flattened_layers
    else:
        flattened_layers = model.layers
    df = pd.read_pickle(filepath)
    # Reverse index from layer name to layer.
    index = {}
    for layer in flattened_layers:
        if layer.name:
            index[layer.name] = layer
    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for _, row in df.iterrows():
        name = row['layer_names']
        weight_values = row['weight_values']
        if name in index:
            symbolic_weights = index[name].weights
            if len(weight_values) != len(symbolic_weights):
                # Report the mismatched layer by its saved name (the original
                # referenced the stale loop variable 'layer' here).
                raise Exception('Layer named "' + name +
                                '" expects ' + str(len(symbolic_weights)) +
                                ' weight(s), but the saved weights' +
                                ' have ' + str(len(weight_values)) +
                                ' element(s).')
            # Set values.
            for i in range(len(weight_values)):
                weight_value_tuples.append(
                    (symbolic_weights[i], weight_values[i]))
    K.batch_set_value(weight_value_tuples)
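

# Usage sketch (illustrative): round-trip a model's weights through the
# pickled-DataFrame format used above. 'model' is any built Keras model; the
# path is a placeholder.
def _example_weight_roundtrip(model):
    save_model_weights(model, 'data/model_weights.pkl')
    load_model_weights(model, 'data/model_weights.pkl')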


def f_score(labels, preds):
    # F1 score on rounded predictions. K.epsilon() guards against a zero
    # denominator when a batch contains no positives (added here; the original
    # divided unguarded).
    preds = K.round(preds)
    tp = K.sum(labels * preds)
    fp = K.sum(preds) - tp
    fn = K.sum(labels) - tp
    p = tp / (tp + fp + K.epsilon())
    r = tp / (tp + fn + K.epsilon())
    return 2 * p * r / (p + r + K.epsilon())
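

# Usage sketch (illustrative): f_score operates on backend tensors, so it can
# be passed to compile() as a metric alongside the built-in ones.
def _example_compile_with_f_score(model):
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=[f_score])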


def filter_specific(go, gos):
    # Keep only the most specific terms: drop any term that is an ancestor of
    # another term in the list.
    go_set = set(gos)
    for go_id in gos:
        anchestors = get_anchestors(go, go_id)
        anchestors.discard(go_id)
        go_set -= anchestors
    return list(go_set)
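

# Usage sketch (illustrative): given a term together with one of its
# ancestors, filter_specific drops the ancestor and keeps the specific term.
# The ids are placeholders.
def _example_filter_specific(go):
    return filter_specific(go, ['GO:0006355', BIOLOGICAL_PROCESS])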


def read_fasta(lines):
    # Parse FASTA-formatted lines into parallel lists of headers and
    # concatenated sequences.
    seqs = list()
    info = list()
    seq = ''
    inf = ''
    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            if seq != '':
                seqs.append(seq)
                info.append(inf)
                seq = ''
            inf = line[1:]
        else:
            seq += line
    seqs.append(seq)
    info.append(inf)
    return info, seqs
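

# Usage sketch (illustrative): read_fasta accepts any iterable of lines, so a
# file object works directly. The path is a placeholder.
def _example_read_fasta():
    with open('data/sequences.fasta') as f:
        info, seqs = read_fasta(f)
    return dict(zip(info, seqs))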


class MyCheckpoint(ModelCheckpoint):
    # ModelCheckpoint variant that saves best-only through save_model_weights()
    # so the weights land in the pickled-DataFrame format used above.
    def on_epoch_end(self, epoch, logs={}):
        filepath = self.filepath.format(epoch=epoch, **logs)
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
        else:
            if self.monitor_op(current, self.best):
                if self.verbose > 0:
                    print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                          ' saving model to %s'
                          % (epoch, self.monitor, self.best,
                             current, filepath))
                self.best = current
                save_model_weights(self.model, filepath)
            else:
                if self.verbose > 0:
                    print('Epoch %05d: %s did not improve' %
                          (epoch, self.monitor))
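

# Usage sketch (illustrative): MyCheckpoint takes the same constructor
# arguments as ModelCheckpoint. Argument values here are placeholders.
def _example_fit_with_checkpoint(model, x, y):
    checkpoint = MyCheckpoint('data/model_weights.pkl', monitor='val_loss',
                              verbose=1)
    model.fit(x, y, validation_split=0.2, callbacks=[checkpoint])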


class DataGenerator(object):
    # Endless batch iterator over (inputs, targets): it wraps around when the
    # data is exhausted, so callers must bound the number of steps themselves.

    def __init__(self, batch_size, num_outputs):
        self.batch_size = batch_size
        self.num_outputs = num_outputs

    def fit(self, inputs, targets):
        self.start = 0
        self.inputs = inputs
        self.targets = targets
        self.size = len(self.inputs)
        if isinstance(self.inputs, (tuple, list)):
            # Multi-input models pass a list/tuple of arrays; the sample count
            # is the length of the first one.
            self.size = len(self.inputs[0])
        self.has_targets = targets is not None

    def __next__(self):
        return self.next()

    def reset(self):
        self.start = 0

    def next(self):
        if self.start < self.size:
            if self.has_targets:
                labels = self.targets[self.start:(self.start + self.batch_size), :]
            if isinstance(self.inputs, (tuple, list)):
                res_inputs = []
                for inp in self.inputs:
                    res_inputs.append(
                        inp[self.start:(self.start + self.batch_size)])
            else:
                res_inputs = self.inputs[self.start:(
                    self.start + self.batch_size)]
            self.start += self.batch_size
            if self.has_targets:
                return (res_inputs, labels)
            return res_inputs
        else:
            self.reset()
            return self.next()
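

# Usage sketch (illustrative): because the generator yields forever, pair it
# with an explicit steps_per_epoch. Argument values are placeholders.
def _example_fit_with_generator(model, inputs, targets, batch_size=128):
    generator = DataGenerator(batch_size, targets.shape[1])
    generator.fit(inputs, targets)
    steps = (generator.size + batch_size - 1) // batch_size
    model.fit_generator(generator, steps_per_epoch=steps, epochs=10)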


if __name__ == '__main__':
    # The original called an undefined get_ipro_xml(); get_ipro() is the
    # XML-parsing function defined above.
    get_ipro()