from collections import deque
import warnings

import pandas as pd
from xml.etree import ElementTree as ET

from keras import backend as K
from keras.callbacks import ModelCheckpoint
BIOLOGICAL_PROCESS = 'GO:0008150'
MOLECULAR_FUNCTION = 'GO:0003674'
CELLULAR_COMPONENT = 'GO:0005575'

FUNC_DICT = {
    'cc': CELLULAR_COMPONENT,
    'mf': MOLECULAR_FUNCTION,
    'bp': BIOLOGICAL_PROCESS}

EXP_CODES = set(['EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC'])
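
# The evidence codes above are the GO annotation codes accepted as
# reliable: EXP (inferred from experiment), IDA (direct assay),
# IPI (physical interaction), IMP (mutant phenotype), IGI (genetic
# interaction), IEP (expression pattern), plus the curated codes
# TAS (traceable author statement) and IC (inferred by curator).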


def get_ipro():
    # Parse the InterPro XML release into a dict keyed by InterPro id.
    ipro = dict()
    tree = ET.parse('data/interpro.xml')
    root = tree.getroot()
    for child in root:
        if child.tag != 'interpro':
            continue
        ipro_id = child.attrib['id']
        name = child.find('name').text
        ipro[ipro_id] = {
            'id': ipro_id,
            'name': name,
            'children': list(),
            'parents': list()}
        # Compare against None explicitly: an Element with no children is
        # falsy, so a bare truth test would wrongly skip it.
        parents = child.find('parent_list')
        if parents is not None:
            for parent in parents:
                ipro[ipro_id]['parents'].append(parent.attrib['ipr_ref'])
        children = child.find('child_list')
        if children is not None:
            for ch in children:
                ipro[ipro_id]['children'].append(ch.attrib['ipr_ref'])
    return ipro
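
# Example usage (assumes data/interpro.xml is an InterPro release file;
# the accession below is illustrative):
#   ipro = get_ipro()
#   print(ipro['IPR000001']['name'], ipro['IPR000001']['parents'])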


def get_ipro_anchestors(ipro, ipro_id):
    # BFS up the parent links; the result includes ipro_id itself.
    ipro_set = set()
    q = deque()
    q.append(ipro_id)
    while len(q) > 0:
        i_id = q.popleft()
        ipro_set.add(i_id)
        for parent_id in ipro[i_id]['parents']:
            if parent_id in ipro:
                q.append(parent_id)
    return ipro_set


def get_gene_ontology(filename='go.obo'):
    # Reading Gene Ontology from OBO formatted file.
    go = dict()
    obj = None
    with open('data/' + filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line == '[Term]':
                if obj is not None:
                    go[obj['id']] = obj
                obj = dict()
                obj['is_a'] = list()
                obj['part_of'] = list()
                obj['regulates'] = list()
                obj['is_obsolete'] = False
                continue
            elif line == '[Typedef]':
                # Save the pending term before a Typedef stanza starts,
                # otherwise the last [Term] in the file is silently lost.
                if obj is not None:
                    go[obj['id']] = obj
                obj = None
            else:
                if obj is None:
                    continue
                parts = line.split(': ', 1)
                if parts[0] == 'id':
                    obj['id'] = parts[1]
                elif parts[0] == 'is_a':
                    obj['is_a'].append(parts[1].split(' ! ')[0])
                elif parts[0] == 'name':
                    obj['name'] = parts[1]
                elif parts[0] == 'is_obsolete' and parts[1] == 'true':
                    obj['is_obsolete'] = True
    if obj is not None:
        go[obj['id']] = obj
    # Materialize the keys so obsolete terms can be deleted while
    # iterating (mutating a dict view raises in Python 3).
    for go_id in list(go.keys()):
        if go[go_id]['is_obsolete']:
            del go[go_id]
    # iteritems() is Python 2 only; items() works in both versions.
    for go_id, val in go.items():
        if 'children' not in val:
            val['children'] = set()
        for p_id in val['is_a']:
            if p_id in go:
                if 'children' not in go[p_id]:
                    go[p_id]['children'] = set()
                go[p_id]['children'].add(go_id)
    return go
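
# Example usage (assumes data/go.obo is a Gene Ontology OBO release):
#   go = get_gene_ontology('go.obo')
#   print(go[BIOLOGICAL_PROCESS]['name'])  # 'biological_process'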


def get_anchestors(go, go_id):
    # BFS over is_a links; the returned set includes go_id itself.
    go_set = set()
    q = deque()
    q.append(go_id)
    while len(q) > 0:
        g_id = q.popleft()
        go_set.add(g_id)
        for parent_id in go[g_id]['is_a']:
            if parent_id in go:
                q.append(parent_id)
    return go_set


def get_parents(go, go_id):
    go_set = set()
    for parent_id in go[go_id]['is_a']:
        if parent_id in go:
            go_set.add(parent_id)
    return go_set


def get_height(go, go_id):
    # Minimum number of is_a steps from go_id up to a root term (a term
    # with no is_a parents). Parents missing from go are skipped, so the
    # sentinel value is returned if no known parent exists.
    height_min = 100000
    if len(go[go_id]['is_a']) == 0:
        height_min = 0
    else:
        for parent_id in go[go_id]['is_a']:
            if parent_id in go:
                height = get_height(go, parent_id) + 1
                if height < height_min:
                    height_min = height
    return height_min
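
# Example: a root term such as BIOLOGICAL_PROCESS has no is_a parents,
# so get_height(go, BIOLOGICAL_PROCESS) == 0; each is_a hop adds one,
# and the minimum over all parents is returned.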


def get_go_set(go, go_id):
    # BFS over children links: all descendants of go_id, including go_id.
    go_set = set()
    q = deque()
    q.append(go_id)
    while len(q) > 0:
        g_id = q.popleft()
        go_set.add(g_id)
        for ch_id in go[g_id]['children']:
            q.append(ch_id)
    return go_set
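
# Example: get_go_set(go, FUNC_DICT['bp']) returns every term in the
# biological_process subontology.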


def save_model_weights(model, filepath):
    if hasattr(model, 'flattened_layers'):
        # Support for legacy Sequential/Merge behavior.
        flattened_layers = model.flattened_layers
    else:
        flattened_layers = model.layers
    l_names = []
    w_values = []
    for layer in flattened_layers:
        layer_name = layer.name
        symbolic_weights = layer.weights
        weight_values = K.batch_get_value(symbolic_weights)
        if weight_values:
            l_names.append(layer_name)
            w_values.append(weight_values)
    df = pd.DataFrame({
        'layer_names': l_names,
        'weight_values': w_values})
    df.to_pickle(filepath)
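
# The weights are stored as a pandas pickle with one row per layer
# (columns 'layer_names' and 'weight_values'); load_model_weights below
# reads the same format back and matches rows to layers by name.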


def load_model_weights(model, filepath):
    '''Name-based weight loading.

    Layers that have no matching name are skipped.
    '''
    if hasattr(model, 'flattened_layers'):
        # Support for legacy Sequential/Merge behavior.
        flattened_layers = model.flattened_layers
    else:
        flattened_layers = model.layers
    df = pd.read_pickle(filepath)
    # Reverse index from layer name to layer (the last layer wins if
    # names collide).
    index = {}
    for layer in flattened_layers:
        if layer.name:
            index[layer.name] = layer
    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for _, row in df.iterrows():
        name = row['layer_names']
        weight_values = row['weight_values']
        if name in index:
            symbolic_weights = index[name].weights
            if len(weight_values) != len(symbolic_weights):
                raise Exception(
                    'Layer named "' + name + '" expects ' +
                    str(len(symbolic_weights)) +
                    ' weight(s), but the saved weights have ' +
                    str(len(weight_values)) + ' element(s).')
            # Set values.
            for i in range(len(weight_values)):
                weight_value_tuples.append(
                    (symbolic_weights[i], weight_values[i]))
    K.batch_set_value(weight_value_tuples)
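
# Round-trip sketch (the path is illustrative; model is any Keras model
# with uniquely named layers):
#   save_model_weights(model, 'data/model_weights.pkl')
#   load_model_weights(model, 'data/model_weights.pkl')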


def f_score(labels, preds):
    preds = K.round(preds)
    tp = K.sum(labels * preds)
    fp = K.sum(preds) - tp
    fn = K.sum(labels) - tp
    # K.epsilon() guards against division by zero when there are no
    # positive predictions or no positive labels.
    p = tp / (tp + fp + K.epsilon())
    r = tp / (tp + fn + K.epsilon())
    return 2 * p * r / (p + r + K.epsilon())
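
# f_score is the F1 measure: with precision p = tp / (tp + fp) and
# recall r = tp / (tp + fn), f = 2 * p * r / (p + r). Rounding preds
# means the metric is evaluated at a fixed 0.5 decision threshold.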


def filter_specific(go, gos):
    # Keep only the most specific terms: drop any term that is an
    # ancestor of another term in gos.
    go_set = set(gos)
    for go_id in gos:
        anchestors = get_anchestors(go, go_id)
        anchestors.discard(go_id)
        go_set -= anchestors
    return list(go_set)
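
# Example: if gos contains both a term and one of its ancestors, only
# the more specific (deeper) term survives the filter.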


def read_fasta(lines):
    seqs = list()
    info = list()
    seq = ''
    inf = ''
    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            if seq != '':
                seqs.append(seq)
                info.append(inf)
                seq = ''
            inf = line[1:]
        else:
            seq += line
    # Flush the final record; the guard avoids appending an empty
    # sequence when the input has no records.
    if seq != '':
        seqs.append(seq)
        info.append(inf)
    return info, seqs
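
# Example usage (the path is illustrative; any iterable of FASTA lines
# works, e.g. an open file handle):
#   with open('data/sequences.fasta') as f:
#       info, seqs = read_fasta(f)
#   # info[i] is the header without '>', seqs[i] the full sequence.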


class MyCheckpoint(ModelCheckpoint):

    def on_epoch_end(self, epoch, logs=None):
        # Use None instead of a mutable default argument.
        logs = logs or {}
        filepath = self.filepath.format(epoch=epoch, **logs)
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
        else:
            if self.monitor_op(current, self.best):
                if self.verbose > 0:
                    print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                          ' saving model to %s'
                          % (epoch, self.monitor, self.best,
                             current, filepath))
                self.best = current
                save_model_weights(self.model, filepath)
            else:
                if self.verbose > 0:
                    print('Epoch %05d: %s did not improve' %
                          (epoch, self.monitor))


class DataGenerator(object):

    def __init__(self, batch_size, num_outputs):
        self.batch_size = batch_size
        # num_outputs is kept for API compatibility; batches return the
        # target matrix as a whole.
        self.num_outputs = num_outputs

    def fit(self, inputs, targets):
        self.start = 0
        self.inputs = inputs
        self.targets = targets
        self.size = len(self.inputs)
        if isinstance(self.inputs, (tuple, list)):
            self.size = len(self.inputs[0])
        self.has_targets = targets is not None

    def __next__(self):
        return self.next()

    def reset(self):
        self.start = 0

    def next(self):
        if self.start < self.size:
            if self.has_targets:
                labels = self.targets[
                    self.start:(self.start + self.batch_size), :]
            if isinstance(self.inputs, (tuple, list)):
                res_inputs = []
                for inp in self.inputs:
                    res_inputs.append(
                        inp[self.start:(self.start + self.batch_size)])
            else:
                res_inputs = self.inputs[
                    self.start:(self.start + self.batch_size)]
            self.start += self.batch_size
            if self.has_targets:
                return (res_inputs, labels)
            return res_inputs
        else:
            # Wrap around so the generator can be drawn from indefinitely,
            # as Keras generator training expects.
            self.reset()
            return self.next()
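
# Usage sketch (names and Keras 1-style arguments are illustrative):
#   gen = DataGenerator(batch_size=128, num_outputs=1)
#   gen.fit(train_data, train_labels)
#   model.fit_generator(gen, samples_per_epoch=len(train_data), nb_epoch=10)
# The generator loops forever, so Keras must be told the epoch size.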


if __name__ == '__main__':
    # get_ipro_xml() is not defined anywhere in this module; get_ipro()
    # appears to be the intended entry point, so it is called instead.
    get_ipro()