
# utils.py


from collections import deque
import warnings

import pandas as pd
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from xml.etree import ElementTree as ET

# Root terms of the three Gene Ontology namespaces.
BIOLOGICAL_PROCESS = 'GO:0008150'
MOLECULAR_FUNCTION = 'GO:0003674'
CELLULAR_COMPONENT = 'GO:0005575'
FUNC_DICT = {
    'cc': CELLULAR_COMPONENT,
    'mf': MOLECULAR_FUNCTION,
    'bp': BIOLOGICAL_PROCESS}
# Experimental (and curator-assigned) GO evidence codes.
EXP_CODES = set(['EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC'])
def get_ipro():
    """Parse data/interpro.xml into a dict keyed by InterPro accession."""
    ipro = dict()
    tree = ET.parse('data/interpro.xml')
    root = tree.getroot()
    for child in root:
        if child.tag != 'interpro':
            continue
        ipro_id = child.attrib['id']
        name = child.find('name').text
        ipro[ipro_id] = {
            'id': ipro_id,
            'name': name,
            'children': list(), 'parents': list()}
        # find() returns None when the tag is absent; truth-testing an
        # Element is deprecated, so compare against None explicitly.
        parents = child.find('parent_list')
        if parents is not None:
            for parent in parents:
                ipro[ipro_id]['parents'].append(parent.attrib['ipr_ref'])
        children = child.find('child_list')
        if children is not None:
            for ch in children:
                ipro[ipro_id]['children'].append(ch.attrib['ipr_ref'])
    return ipro
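# Usage sketch (illustrative accession; assumes data/interpro.xml is an
# InterPro XML release with <interpro id=...> entries):
#
#     ipro = get_ipro()
#     entry = ipro['IPR000001']
#     print(entry['name'], entry['parents'], entry['children'])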
def get_ipro_anchestors(ipro, ipro_id):
    """BFS up the 'parents' links; the result includes ipro_id itself."""
    ipro_set = set()
    q = deque()
    q.append(ipro_id)
    while len(q) > 0:
        i_id = q.popleft()
        ipro_set.add(i_id)
        for parent_id in ipro[i_id]['parents']:
            if parent_id in ipro:
                q.append(parent_id)
    return ipro_set
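# Usage sketch: the returned set contains ipro_id itself plus every entry
# reachable through 'parents' links (accession below is illustrative):
#
#     ipro = get_ipro()
#     ancestors = get_ipro_anchestors(ipro, 'IPR000003')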
def get_gene_ontology(filename='go.obo'):
    # Read the Gene Ontology from an OBO-formatted file.
    go = dict()
    obj = None
    with open('data/' + filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line == '[Term]':
                if obj is not None:
                    go[obj['id']] = obj
                obj = dict()
                obj['is_a'] = list()
                obj['part_of'] = list()
                obj['regulates'] = list()
                obj['is_obsolete'] = False
                continue
            elif line == '[Typedef]':
                # Relationship definitions are not terms; stop collecting.
                obj = None
            else:
                if obj is None:
                    continue
                l = line.split(': ')
                if l[0] == 'id':
                    obj['id'] = l[1]
                elif l[0] == 'is_a':
                    obj['is_a'].append(l[1].split(' ! ')[0])
                elif l[0] == 'name':
                    obj['name'] = l[1]
                elif l[0] == 'is_obsolete' and l[1] == 'true':
                    obj['is_obsolete'] = True
    if obj is not None:
        go[obj['id']] = obj
    # Copy the keys before deleting, so the dict is not mutated mid-iteration.
    for go_id in list(go.keys()):
        if go[go_id]['is_obsolete']:
            del go[go_id]
    for go_id, val in go.items():
        if 'children' not in val:
            val['children'] = set()
        for p_id in val['is_a']:
            if p_id in go:
                if 'children' not in go[p_id]:
                    go[p_id]['children'] = set()
                go[p_id]['children'].add(go_id)
    return go
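# Usage sketch (assumes data/go.obo is a standard OBO release): only
# non-obsolete [Term] stanzas are kept, each linked to its 'is_a' children.
#
#     go = get_gene_ontology('go.obo')
#     print(go[BIOLOGICAL_PROCESS]['name'])   # 'biological_process'
#     print(len(go[BIOLOGICAL_PROCESS]['children']))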
def get_anchestors(go, go_id):
    """BFS up the 'is_a' edges; the result includes go_id itself."""
    go_set = set()
    q = deque()
    q.append(go_id)
    while len(q) > 0:
        g_id = q.popleft()
        go_set.add(g_id)
        for parent_id in go[g_id]['is_a']:
            if parent_id in go:
                q.append(parent_id)
    return go_set
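# Usage sketch: for any term, the ancestor set reaches its namespace root,
# e.g. (GO:0008152, 'metabolic process', is an illustrative BP term):
#
#     go = get_gene_ontology()
#     assert BIOLOGICAL_PROCESS in get_anchestors(go, 'GO:0008152')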
def get_parents(go, go_id):
    go_set = set()
    for parent_id in go[go_id]['is_a']:
        if parent_id in go:
            go_set.add(parent_id)
    return go_set
def get_height(go, go_id):
    # Minimal distance to a root term, computed recursively over 'is_a'.
    height_min = 100000
    if len(go[go_id]['is_a']) == 0:
        height_min = 0
    else:
        for parent_id in go[go_id]['is_a']:
            if parent_id in go:
                height = get_height(go, parent_id) + 1
                if height < height_min:
                    height_min = height
    return height_min
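# The recursive version above revisits shared ancestors on every call; the
# memoized variant below is a hypothetical helper (not in the original) that
# caches heights, and treats terms whose parents are all missing as roots:
def get_height_cached(go, go_id, cache=None):
    if cache is None:
        cache = {}
    if go_id in cache:
        return cache[go_id]
    parents = [p_id for p_id in go[go_id]['is_a'] if p_id in go]
    if not parents:
        cache[go_id] = 0
    else:
        cache[go_id] = 1 + min(
            get_height_cached(go, p_id, cache) for p_id in parents)
    return cache[go_id]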
def get_go_set(go, go_id):
    """BFS down the 'children' edges: go_id plus all of its descendants."""
    go_set = set()
    q = deque()
    q.append(go_id)
    while len(q) > 0:
        g_id = q.popleft()
        go_set.add(g_id)
        for ch_id in go[g_id]['children']:
            q.append(ch_id)
    return go_set
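# Usage sketch: passing a namespace root from FUNC_DICT yields every term of
# that namespace, since get_go_set collects go_id plus all descendants:
#
#     go = get_gene_ontology()
#     mf_terms = get_go_set(go, FUNC_DICT['mf'])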
def save_model_weights(model, filepath):
    """Name-based weight saving into a pickled pandas DataFrame."""
    if hasattr(model, 'flattened_layers'):
        # Support for legacy Sequential/Merge behavior.
        flattened_layers = model.flattened_layers
    else:
        flattened_layers = model.layers
    l_names = []
    w_values = []
    for layer in flattened_layers:
        layer_name = layer.name
        symbolic_weights = layer.weights
        weight_values = K.batch_get_value(symbolic_weights)
        if weight_values:
            l_names.append(layer_name)
            w_values.append(weight_values)
    df = pd.DataFrame({
        'layer_names': l_names,
        'weight_values': w_values})
    df.to_pickle(filepath)
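# Usage sketch (hypothetical model and path): weights are stored keyed by
# layer name, so any architecture with matching names can reload them:
#
#     save_model_weights(model, 'data/model_weights.pkl')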
def load_model_weights(model, filepath):
    '''Name-based weight loading.
    Layers that have no matching name are skipped.
    '''
    if hasattr(model, 'flattened_layers'):
        # Support for legacy Sequential/Merge behavior.
        flattened_layers = model.flattened_layers
    else:
        flattened_layers = model.layers
    df = pd.read_pickle(filepath)
    # Reverse index from layer name to layer.
    index = {}
    for layer in flattened_layers:
        if layer.name:
            index[layer.name] = layer
    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for row in df.iterrows():
        row = row[1]
        name = row['layer_names']
        weight_values = row['weight_values']
        if name in index:
            symbolic_weights = index[name].weights
            if len(weight_values) != len(symbolic_weights):
                raise Exception('Layer named "' + name +
                                '" expects ' + str(len(symbolic_weights)) +
                                ' weight(s), but the saved weights' +
                                ' have ' + str(len(weight_values)) +
                                ' element(s).')
            # Set values.
            for i in range(len(weight_values)):
                weight_value_tuples.append(
                    (symbolic_weights[i], weight_values[i]))
    K.batch_set_value(weight_value_tuples)
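# Usage sketch (hypothetical path): the counterpart of save_model_weights;
# layers absent from the pickle keep their current weights:
#
#     load_model_weights(model, 'data/model_weights.pkl')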
def f_score(labels, preds):
    preds = K.round(preds)
    tp = K.sum(labels * preds)
    fp = K.sum(preds) - tp
    fn = K.sum(labels) - tp
    # K.epsilon() guards against division by zero when there are no
    # positive predictions or no positive labels.
    p = tp / (tp + fp + K.epsilon())
    r = tp / (tp + fn + K.epsilon())
    return 2 * p * r / (p + r + K.epsilon())
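# Worked example (illustrative numbers): with labels [1, 1, 0, 0] and rounded
# preds [1, 0, 1, 0], tp = 1, fp = 1, fn = 1, so p = r = 0.5 and
# F1 = 2 * 0.5 * 0.5 / (0.5 + 0.5) = 0.5.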
def filter_specific(go, gos):
    # Drop every GO id that is an ancestor of another id in the list,
    # keeping only the most specific annotations.
    go_set = set(gos)
    for go_id in gos:
        anchestors = get_anchestors(go, go_id)
        anchestors.discard(go_id)
        go_set -= anchestors
    return list(go_set)
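# Usage sketch (illustrative ids): GO:0008150 is an ancestor of GO:0008152,
# so only the more specific term survives:
#
#     filter_specific(go, ['GO:0008150', 'GO:0008152'])   # -> ['GO:0008152']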
def read_fasta(lines):
    seqs = list()
    info = list()
    seq = ''
    inf = ''
    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            # A new header closes the previous record, if any.
            if seq != '':
                seqs.append(seq)
                info.append(inf)
                seq = ''
            inf = line[1:]
        else:
            seq += line
    seqs.append(seq)
    info.append(inf)
    return info, seqs
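# Usage sketch: read_fasta() takes an iterable of lines, not a path
# (the filename below is hypothetical):
#
#     with open('data/sequences.fasta') as f:
#         info, seqs = read_fasta(f)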
class MyCheckpoint(ModelCheckpoint):

    def on_epoch_end(self, epoch, logs={}):
        filepath = self.filepath.format(epoch=epoch, **logs)
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
        else:
            if self.monitor_op(current, self.best):
                if self.verbose > 0:
                    print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                          ' saving model to %s'
                          % (epoch, self.monitor, self.best,
                             current, filepath))
                self.best = current
                save_model_weights(self.model, filepath)
            else:
                if self.verbose > 0:
                    print('Epoch %05d: %s did not improve' %
                          (epoch, self.monitor))
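# Usage sketch (hypothetical monitor and path): a drop-in replacement for
# ModelCheckpoint that saves improvements via save_model_weights:
#
#     ckpt = MyCheckpoint(filepath='data/weights.pkl', monitor='val_loss',
#                         verbose=1)
#     model.fit(x, y, callbacks=[ckpt])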
class DataGenerator(object):

    def __init__(self, batch_size, num_outputs):
        self.batch_size = batch_size
        self.num_outputs = num_outputs

    def fit(self, inputs, targets):
        self.start = 0
        self.inputs = inputs
        self.targets = targets
        self.size = len(self.inputs)
        if isinstance(self.inputs, (tuple, list)):
            # With multiple input arrays, the sample count comes from the first.
            self.size = len(self.inputs[0])
        self.has_targets = targets is not None

    def __next__(self):
        return self.next()

    def reset(self):
        self.start = 0

    def next(self):
        if self.start < self.size:
            if self.has_targets:
                labels = self.targets[
                    self.start:(self.start + self.batch_size), :]
            if isinstance(self.inputs, (tuple, list)):
                res_inputs = []
                for inp in self.inputs:
                    res_inputs.append(
                        inp[self.start:(self.start + self.batch_size)])
            else:
                res_inputs = self.inputs[self.start:(
                    self.start + self.batch_size)]
            self.start += self.batch_size
            if self.has_targets:
                return (res_inputs, labels)
            return res_inputs
        else:
            # Wrap around so the generator can be iterated indefinitely.
            self.reset()
            return self.next()
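# Usage sketch (illustrative shapes, Keras 1.x generator API): the generator
# cycles forever, yielding batch_size rows at a time:
#
#     gen = DataGenerator(batch_size=128, num_outputs=1)
#     gen.fit(train_inputs, train_labels)
#     model.fit_generator(gen, samples_per_epoch=len(train_labels), nb_epoch=10)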
if __name__ == '__main__':
    get_ipro()