12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868 |
- #!/usr/bin/env python
-
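"""Train and evaluate a conditional WGAN for Gene Ontology term prediction.

The generator maps padded protein n-gram sequences to per-term scores; the
discriminator sees a label vector together with the sequence. After training,
the best generator is reloaded and scored on the held-out test set.
"""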
- from __future__ import division
-
- import logging
- import sys
- import time
- from collections import deque
- from multiprocessing import Pool
-
- import click as ck
- import numpy as np
- import pandas as pd
- import tensorflow as tf
- from keras import backend as K
- from keras.callbacks import EarlyStopping, ModelCheckpoint
- from keras.layers import (
- Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D, AveragePooling1D, Multiply,GaussianNoise,
- Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization, Activation, Dropout)
- from keras.losses import binary_crossentropy
- from keras.models import Sequential, Model, load_model
- from keras.preprocessing import sequence
- from scipy.spatial import distance
- from sklearn.metrics import log_loss
- from sklearn.metrics import roc_curve, auc, matthews_corrcoef
- from keras.layers import Lambda, LeakyReLU, RepeatVector
- from sklearn.metrics import precision_recall_curve
- from keras.callbacks import TensorBoard
-
- from utils import (
- get_gene_ontology,
- get_go_set,
- get_anchestors,
- get_parents,
- DataGenerator,
- FUNC_DICT,
- get_height)
- from conditional_wgan_wrapper_exp1 import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new
-
# Data location and sequence-shape constants shared by the functions below.
DATA_ROOT = 'Data/data1/'
MAXLEN = 1000
REPLEN = 256
ind = 0

# Register a TensorFlow session with Keras whose GPU options have
# `allow_growth` switched on.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(message)s')
sys.setrecursionlimit(100000)
-
-
@ck.command()
@ck.option(
    '--function',
    default='bp',
    help='Ontology id (mf, bp, cc)')
@ck.option(
    '--device',
    default='gpu:0',
    help='GPU or CPU device id')
@ck.option(
    '--org',
    default=None,
    help='Organism id for filtering test set')
@ck.option(
    '--train/--no-train',
    default=True,
    help='Train a new model before evaluating (use --no-train to only '
         'evaluate the saved model)')
def main(function, device, org, train):
    """Command-line entry point: set up module globals and run ``model``.

    The function publishes the run configuration through module-level
    globals (``FUNCTION``, ``GO_ID``, ``go``, ``ORG``, ``functions``,
    ``embedds``, ``all_functions``, ``node_names``) that the other
    functions in this file read.

    Args:
        function: Ontology id, used as a key into ``FUNC_DICT`` and as part
            of the data file names.
        device: Device string appended to ``'/'`` for ``tf.device``.
        org: Optional organism id; when given, the test set is filtered by it.
        train: Whether to train before evaluating. The original option was
            a flag defaulting to True, which could never be turned off;
            ``--no-train`` now disables training while ``--train`` is still
            accepted.
    """
    global FUNCTION, GO_ID, go, ORG
    global functions, embedds, all_functions, node_names

    FUNCTION = function
    GO_ID = FUNC_DICT[FUNCTION]
    go = get_gene_ontology('go.obo')
    ORG = org

    func_df = pd.read_pickle(DATA_ROOT + FUNCTION + '.pkl')
    functions = func_df['functions'].values
    embedds = load_go_embedding(functions)
    all_functions = get_go_set(go, GO_ID)

    logging.info('Functions: %s %d', FUNCTION, len(functions))
    if ORG is not None:
        logging.info('Organism %s', ORG)

    node_names = set()
    with tf.device('/' + device):
        params = {
            'fc_output': 1024,
            'embedding_dims': 128,
            'embedding_dropout': 0.2,
            'nb_conv': 2,
            'nb_dense': 1,
            'filter_length': 128,
            'nb_filter': 32,
            'pool_length': 64,
            'stride': 32,
        }
        model(params, is_train=train)
-
-
def load_go_embedding(
        functions,
        embedding_path='data/swiss/term-graph-cellular-component-100d.emb',
        index_path='data/swiss/term-id-cellular-component.txt'):
    """Load one embedding vector per GO term in ``functions``.

    The index file maps a term (first whitespace-separated field) to an
    embedding id (third field). The embedding file maps that id (first
    field) to the vector (remaining fields). The first line of each file is
    skipped as a header, and blank lines are ignored.

    Changes relative to the original implementation:
      * both files are closed deterministically (``with``);
      * fields are taken from ``str.split`` directly, so the code no longer
        indexes a ``map`` object (which fails on Python 3);
      * vectors are parsed to floats so that they share a dtype with the
        all-ones fallback instead of being kept as strings;
      * a term whose id has no row in the embedding file gets the same
        fallback as an unknown term instead of raising ``KeyError``.

    Args:
        functions: Sequence of GO term identifiers.
        embedding_path: Path of the embedding file.
        index_path: Path of the term-to-id index file.

    Returns:
        ``np.ndarray`` with one row per entry of ``functions``. Terms
        without an embedding get a vector of 100 ones.
    """
    term_to_id = {}
    with open(index_path, 'r') as f_index:
        next(f_index, None)  # header line
        for line in f_index:
            fields = line.split()
            if fields:
                term_to_id[fields[0]] = fields[2]

    id_to_vector = {}
    with open(embedding_path, 'r') as f_embedding:
        next(f_embedding, None)  # header line
        for line in f_embedding:
            fields = line.split()
            if fields:
                id_to_vector[fields[0]] = np.array(fields[1:], dtype=float)

    embedds = []
    for term in functions:
        vector = id_to_vector.get(str(term_to_id.get(term)))
        if vector is None:
            vector = np.ones([100])
        embedds.append(vector)
    return np.array(embedds)
-
def load_data():
    """Load the train/validation/test splits for the current ``FUNCTION``.

    The training pickle is split 80/20 (in index order) into training and
    validation frames; the test pickle is optionally filtered by ``ORG``.

    Relative to the original, the never-called ``normalize_minmax`` helper,
    the discarded ``rep`` computation (which required an ``embeddings``
    column for no reason) and the debug ``print`` calls were removed.

    Returns:
        ``(train, valid, test, train_df, valid_df, test_df)`` where the
        first three are ``(data, labels)`` pairs and ``data`` holds the
        n-gram sequences padded to ``MAXLEN``.
    """
    df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
    index = df.index.values
    valid_n = int(len(df) * 0.8)
    train_df = df.loc[index[:valid_n]]
    valid_df = df.loc[index[valid_n:]]

    test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
    if ORG is not None:
        logging.info('Unfiltered test size: %d' % len(test_df))
        test_df = test_df[test_df['orgs'] == ORG]
        logging.info('Filtered test size: %d' % len(test_df))

    def reshape(values):
        # Stack a sequence of equal-length rows into a 2-D array.
        return np.hstack(values).reshape(len(values), len(values[0]))

    def get_values(data_frame):
        labels = reshape(data_frame['labels'].values)
        ngrams = sequence.pad_sequences(
            data_frame['ngrams'].values, maxlen=MAXLEN)
        return reshape(ngrams), labels

    train = get_values(train_df)
    valid = get_values(valid_df)
    test = get_values(test_df)

    return train, valid, test, train_df, valid_df, test_df
-
-
def load_data0():
    """Load the splits, keeping only proteins listed in the ``*-input.tsv`` files.

    Works like ``load_data`` (80/20 split of the training pickle, optional
    ``ORG`` filter on the test pickle), but each split is then restricted to
    the accessions that appear in the index of ``train-<FUNCTION>-input.tsv``,
    ``dev-<FUNCTION>-input.tsv`` and ``test-<FUNCTION>-input.tsv``.

    Relative to the original, the accession lists are turned into sets
    (constant-time membership instead of a list scan per protein), and the
    unused ``normalize_minmax`` helper and debug prints were removed.

    Returns:
        ``(train, valid, test, train_df, valid_df, test_df)`` where the
        first three are ``(data, labels)`` pairs of arrays.
    """
    df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
    index = df.index.values
    valid_n = int(len(df) * 0.8)
    train_df = df.loc[index[:valid_n]]
    valid_df = df.loc[index[valid_n:]]

    test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
    if ORG is not None:
        logging.info('Unfiltered test size: %d' % len(test_df))
        test_df = test_df[test_df['orgs'] == ORG]
        logging.info('Filtered test size: %d' % len(test_df))

    def reshape(values):
        # Stack a sequence of equal-length rows into a 2-D array.
        return np.hstack(values).reshape(len(values), len(values[0]))

    def get_values(data_frame, p_list):
        accessions = data_frame['accessions'].values
        all_labels = data_frame['labels'].values
        all_ngrams = reshape(sequence.pad_sequences(
            data_frame['ngrams'].values, maxlen=MAXLEN))
        wanted = set(p_list)
        keep = [i for i, acc in enumerate(accessions) if acc in wanted]
        data = [all_ngrams[i] for i in keep]
        labels = [all_labels[i] for i in keep]
        return np.asarray(data), np.asarray(labels)

    def read_accessions(split):
        frame = pd.read_csv(split + '-' + FUNCTION + '-input.tsv',
                            sep='\t', header=0, index_col=0)
        return frame.index.tolist()

    dev_l = read_accessions('dev')
    test_l = read_accessions('test')
    train_l = read_accessions('train')

    train = get_values(train_df, train_l)
    valid = get_values(valid_df, dev_l)
    test = get_values(test_df, test_l)

    return train, valid, test, train_df, valid_df, test_df
-
-
-
-
-
class TrainValTensorBoard(TensorBoard):
    """TensorBoard callback that splits epoch logs across several writers.

    Logs whose key starts with ``val_``, ``discriminator_`` or
    ``generator_`` are written (with the prefix shortened to ``v_``,
    ``d_`` or ``g_``) to ``./v_logs``, ``./d_logs`` and ``./g_logs``
    respectively. All remaining logs are handed to the parent ``TensorBoard``
    callback, which writes to ``log_dir``.
    """

    def __init__(self, log_dir='./t_logs', **kwargs):
        super(TrainValTensorBoard, self).__init__(log_dir, **kwargs)

        # Directories of the additional writers created in `set_model`.
        self.val_log_dir = './v_logs'
        self.dis_log_dir = './d_logs'
        self.gen_log_dir = './g_logs'

    def set_model(self, model):
        """Create the extra summary writers, then defer to the parent."""
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        self.dis_writer = tf.summary.FileWriter(self.dis_log_dir)
        self.gen_writer = tf.summary.FileWriter(self.gen_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    @staticmethod
    def _write_scalars(writer, logs, prefix, short_prefix, epoch):
        """Write every ``prefix``-keyed log entry to ``writer`` as a scalar."""
        for key, value in logs.items():
            if not key.startswith(prefix):
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            # float() accepts numpy scalars as well as plain Python numbers;
            # the original `value.item()` raised on the latter.
            summary_value.simple_value = float(value)
            summary_value.tag = key.replace(prefix, short_prefix)
            writer.add_summary(summary, epoch)
        writer.flush()

    def on_epoch_end(self, epoch, logs=None):
        """Route prefixed logs to their writers and the rest to the parent."""
        logs = logs or {}
        self._write_scalars(self.val_writer, logs, 'val_', 'v_', epoch)
        self._write_scalars(
            self.dis_writer, logs, 'discriminator_', 'd_', epoch)
        self._write_scalars(self.gen_writer, logs, 'generator_', 'g_', epoch)

        handled = ('val_', 'discriminator_', 'generator_')
        remaining = {k: v for k, v in logs.items()
                     if not k.startswith(handled)}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, remaining)

    def on_train_end(self, logs=None):
        """Close the parent writer and the three additional writers."""
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
        self.gen_writer.close()
        self.dis_writer.close()
-
-
def get_feature_model(params):
    """Build the sequence feature extractor as a Keras ``Sequential`` model.

    The stack is: embedding over ``MAXLEN`` tokens, spatial dropout, two
    Conv1D + LeakyReLU stages, average pooling and a final flatten.

    Args:
        params: Dict providing ``embedding_dims``, ``nb_filter``,
            ``filter_length``, ``stride`` and ``pool_length``.

    Returns:
        The (uncompiled) ``Sequential`` model.
    """
    vocabulary_size = 8001
    n_filters = params['nb_filter']
    kernel = params['filter_length']

    # BatchNormalization after the convolutions was disabled in the original
    # ("problem in loading the best model"), so it is left out here as well.
    stack = [
        Embedding(vocabulary_size, params['embedding_dims'],
                  input_length=MAXLEN),
        SpatialDropout1D(0.4),
        Conv1D(filters=n_filters, kernel_size=kernel,
               strides=1, padding="valid"),
        LeakyReLU(),
        Conv1D(filters=n_filters, kernel_size=kernel,
               strides=1, padding="valid"),
        LeakyReLU(),
        AveragePooling1D(pool_size=params['pool_length'],
                         strides=params['stride']),
        Flatten(),
    ]

    feature_model = Sequential()
    for layer in stack:
        feature_model.add(layer)
    return feature_model
-
-
-
- #def get_node_name(go_id, unique=False):
- # name = go_id.split(':')[1]
- # if not unique:
- # return name
- # if name not in node_names:
- # node_names.add(name)
- # return name
- # i = 1
- # while (name + '_' + str(i)) in node_names:
- # i += 1
- # name = name + '_' + str(i)
- # node_names.add(name)
- # return name
-
-
- #def get_layers(inputs):
- # q = deque()
- # layers = {}
- # name = get_node_name(GO_ID)
- # layers[GO_ID] = {'net': inputs}
- # for node_id in go[GO_ID]['children']:
- # if node_id in func_set:
- # q.append((node_id, inputs))
- # while len(q) > 0:
- # node_id, net = q.popleft()
- # name = get_node_name(node_id)
- # net, output = get_function_node(name, inputs)
- # if node_id not in layers:
- # layers[node_id] = {'net': net, 'output': output}
- # for n_id in go[node_id]['children']:
- # if n_id in func_set and n_id not in layers:
- # ok = True
- # for p_id in get_parents(go, n_id):
- # if p_id in func_set and p_id not in layers:
- # ok = False
- # if ok:
- # q.append((n_id, net))
-
- # for node_id in functions:
- # childs = set(go[node_id]['children']).intersection(func_set)
- # if len(childs) > 0:
- # outputs = [layers[node_id]['output']]
- # for ch_id in childs:
- # outputs.append(layers[ch_id]['output'])
- # name = get_node_name(node_id) + '_max'
- # layers[node_id]['output'] = Maximum(name=name)(outputs)
- # return layers
-
- #def get_function_node(name, inputs):
- # output_name = name + '_out'
- # output = Dense(1, name=output_name, activation='sigmoid')(inputs)
- # return output, output
-
-
def get_generator(params, n_classes):
    """Build the generator: padded sequence in, one tanh score per class out.

    Args:
        params: Hyper-parameters forwarded to ``get_feature_model``.
        n_classes: Width of the output layer.

    Returns:
        A Keras ``Model`` taking an int32 input of length ``MAXLEN``.
    """
    sequence_input = Input(shape=(MAXLEN,), dtype='int32', name='input1')
    features = get_feature_model(params)(sequence_input)

    # Hidden width noted in the original: bp(dis1 g1) 400 - cc 200 - mf 200.
    hidden = Dense(200)(features)
    hidden = BatchNormalization()(hidden)
    hidden = LeakyReLU()(hidden)

    # A DeepGO-style hierarchical output (one node per term, following the
    # DAG) was tried in the original and left disabled; a flat layer is used.
    scores = Dense(n_classes, activation='tanh')(hidden)  # 'sigmoid'
    return Model(inputs=sequence_input, outputs=scores)
-
-
def get_discriminator(params, n_classes, dropout_rate=0.5):
    """Build the discriminator over a label vector and its sequence.

    The sequence branch embeds and convolves the ``MAXLEN`` tokens; both
    branches are reduced to 100 units, concatenated, passed through six
    30-unit blocks and mapped to a single linear output.

    Args:
        params: Hyper-parameter dict (not read by the active layers).
        n_classes: Length of the label-vector input.
        dropout_rate: Dropout applied before every dense layer.

    Returns:
        A Keras ``Model`` named ``'Discriminator'`` with inputs
        ``[labels, sequence]``.
    """
    def dense_block(tensor, units):
        # Dropout -> Dense -> BatchNorm -> LeakyReLU
        tensor = Dropout(dropout_rate)(tensor)
        tensor = Dense(units)(tensor)
        tensor = BatchNormalization()(tensor)
        return LeakyReLU()(tensor)

    label_input = Input(shape=(n_classes, ))
    sequence_input = Input(shape=(MAXLEN,), dtype='int32', name='d_input2')

    # Sequence branch (built first, as in the original).
    seq = Embedding(8001, 128, input_length=MAXLEN)(sequence_input)
    seq = SpatialDropout1D(0.4)(seq)
    seq = Conv1D(filters=8, kernel_size=128, padding='valid', strides=1)(seq)
    seq = BatchNormalization()(seq)
    seq = LeakyReLU()(seq)
    seq = Flatten()(seq)
    seq = dense_block(seq, 100)

    # Label branch.
    lab = dense_block(label_input, 100)

    merged = Concatenate(axis=1, name='merged2')([lab, seq])
    for units in [30, 30, 30, 30, 30, 30]:
        merged = dense_block(merged, units)

    critic_score = Dense(1)(merged)
    return Model(inputs=[label_input, sequence_input],
                 outputs=[critic_score], name='Discriminator')
-
-
def rescale_labels(labels):
    """Map labels linearly from ``[0, 1]`` to ``[-1, 1]``."""
    centred = labels - 0.5
    return 2 * centred
-
-
def rescale_back_labels(labels):
    """Map values linearly from ``[-1, 1]`` back to ``[0, 1]``."""
    halved = labels / 2
    return halved + 0.5
-
-
def get_model(params, nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10):
    """Create the generator and discriminator and wrap them for WGAN training.

    Args:
        params: Hyper-parameters for the two networks.
        nb_classes: Number of output classes (GO terms).
        batch_size: Batch size handed to ``WGAN_wrapper``.
        GRADIENT_PENALTY_WEIGHT: Gradient-penalty weight handed to
            ``WGAN_wrapper``.

    Returns:
        ``(generator_model, discriminator_model)`` as returned by
        ``WGAN_wrapper``.
    """
    wrapper_kwargs = dict(
        generator=get_generator(params, nb_classes),
        discriminator=get_discriminator(params, nb_classes),
        generator_input_shape=(MAXLEN,),
        discriminator_input_shape=(nb_classes,),
        discriminator_input_shape2=(MAXLEN, ),
        batch_size=batch_size,
        gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT,
        embeddings=embedds,
    )
    generator_model, discriminator_model = WGAN_wrapper(**wrapper_kwargs)

    logging.info('Compilation finished')
    return generator_model, discriminator_model
-
-
def train_wgan(generator_model, discriminator_model, batch_size, epochs,
               x_train, y_train, x_val, y_val, x_test, y_test,
               generator_model_path, discriminator_model_path,
               TRAINING_RATIO=10, N_WARM_UP=20):
    """Run the WGAN training loop and checkpoint on validation loss.

    For the first ``N_WARM_UP`` epochs only the generator is trained, one
    batch at a time, with ``[y_batch, zero_y]`` as targets. Afterwards each
    chunk of ``batch_size * TRAINING_RATIO`` samples trains the discriminator
    ``TRAINING_RATIO`` times and then the generator once on the whole chunk.
    After every epoch the log loss on the training and validation sets is
    computed and both models are saved whenever the validation loss improves.

    Changes relative to the original:
      * the "improved" message reports the previous best loss and the new
        one (it used to print the new loss as the old one);
      * the TensorBoard callback is closed via ``on_train_end`` even when
        training is interrupted, so its writers are flushed;
      * values passed to the callback are numpy scalars, which the callback
        can always convert;
      * the unused test-set "enable" array was dropped.

    NOTE(review): the block structure was reconstructed from a source whose
    indentation was lost; confirm that the full-chunk generator update
    belongs to the post-warm-up branch.

    Args:
        generator_model, discriminator_model: Compiled models from
            ``get_model``.
        batch_size: Mini-batch size.
        epochs: Number of epochs.
        x_train, y_train: Training inputs and labels (``y_train`` is passed
            through ``rescale_back_labels`` before the log loss).
        x_val, y_val: Validation inputs and labels (``y_val`` is used as is).
        x_test, y_test: Accepted for interface compatibility; not used.
        generator_model_path, discriminator_model_path: Checkpoint paths.
        TRAINING_RATIO: Discriminator updates per generator update.
        N_WARM_UP: Number of generator-only epochs.
    """
    BATCH_SIZE = batch_size
    N_EPOCH = epochs
    minibatches_size = BATCH_SIZE * TRAINING_RATIO

    positive_y = np.ones((BATCH_SIZE, 1), dtype=np.float32)
    zero_y = positive_y * 0
    negative_y = -positive_y
    positive_full_y = np.ones((minibatches_size, 1), dtype=np.float32)
    dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)

    positive_full_enable_train = np.ones((len(x_train), 1), dtype=np.float32)
    positive_full_enable_val = np.ones((len(x_val), 1), dtype=np.float32)
    best_validation_loss = None

    callback = TrainValTensorBoard(write_graph=False)
    callback.set_model(generator_model)

    try:
        for epoch in range(N_EPOCH):
            print("Epoch: ", epoch)
            print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE))
            discriminator_loss = []
            generator_loss = []

            shuffled_indexes = np.random.permutation(x_train.shape[0])
            shuffled_indexes_2 = np.random.permutation(x_train.shape[0])

            for i in range(int(y_train.shape[0] // minibatches_size)):
                chunk = slice(i * minibatches_size, (i + 1) * minibatches_size)
                batch_indexes = shuffled_indexes[chunk]
                batch_indexes_2 = shuffled_indexes_2[chunk]
                x = x_train[batch_indexes]
                y = y_train[batch_indexes]
                y_2 = y_train[batch_indexes_2]
                x_2 = x_train[batch_indexes_2]

                if epoch < N_WARM_UP:
                    # Warm-up: generator only, one batch at a time.
                    for j in range(TRAINING_RATIO):
                        batch = slice(j * BATCH_SIZE, (j + 1) * BATCH_SIZE)
                        generator_loss.append(generator_model.train_on_batch(
                            [x[batch], positive_y], [y[batch], zero_y]))
                        discriminator_loss.append(0)
                else:
                    for j in range(TRAINING_RATIO):
                        batch = slice(j * BATCH_SIZE, (j + 1) * BATCH_SIZE)
                        # The generator is conditioned on the sequence itself
                        # rather than on random noise.
                        noise = x[batch]
                        discriminator_loss.append(
                            discriminator_model.train_on_batch(
                                [y_2[batch], noise, x_2[batch]],
                                [positive_y, negative_y, dummy_y]))

                    generator_loss.append(generator_model.train_on_batch(
                        [x, positive_full_y], [y, positive_full_y]))

            predicted_y_train = generator_model.predict(
                [x_train, positive_full_enable_train],
                batch_size=BATCH_SIZE)[0]
            predicted_y_val = generator_model.predict(
                [x_val, positive_full_enable_val], batch_size=BATCH_SIZE)[0]

            train_loss = np.float64(log_loss(
                rescale_back_labels(y_train),
                rescale_back_labels(predicted_y_train)))
            val_loss = np.float64(log_loss(
                y_val, rescale_back_labels(predicted_y_val)))

            # -1 marks an epoch in which no batch was processed.
            generator_loss_sum = np.float64(
                np.sum(np.asarray(generator_loss)) if generator_loss else -1)
            discriminator_loss_sum = np.float64(
                np.sum(np.asarray(discriminator_loss))
                if discriminator_loss else -1)

            callback.on_epoch_end(
                epoch,
                logs={'discriminator_loss': discriminator_loss_sum,
                      'generator_loss': generator_loss_sum,
                      'train_bce_loss': train_loss,
                      'val_bce_loss': val_loss})

            print("train bce loss: {:.4f}, validation bce loss: {:.4f}, generator loss: {:.4f}, discriminator loss: {:.4f}".format(
                train_loss, val_loss,
                generator_loss_sum / x_train.shape[0],
                discriminator_loss_sum / x_train.shape[0]))

            if best_validation_loss is None or best_validation_loss > val_loss:
                previous = ('inf' if best_validation_loss is None
                            else '%0.5f' % best_validation_loss)
                print('\nEpoch %05d: validation loss improved from %s to %0.5f,'
                      ' saving model to %s and %s'
                      % (epoch + 1, previous, val_loss,
                         generator_model_path, discriminator_model_path))

                best_validation_loss = val_loss
                generator_model.save(generator_model_path, overwrite=True)
                discriminator_model.save(
                    discriminator_model_path, overwrite=True)
    finally:
        # Flush and close the summary writers opened by `set_model`.
        callback.on_train_end()
-
-
def model(params, batch_size=64, nb_epoch=10000, is_train=True):
    """Train (optionally) and evaluate the generator on the test set.

    Loads the data, trains the WGAN when ``is_train`` is true, reloads the
    best saved generator, predicts on the test set and logs the
    protein-centric and function-centric metrics.

    Changes relative to the original:
      * prediction time is logged as ``end - start`` (it was negated);
      * ``steps`` for ``predict_generator`` is an explicit integer
        (ceiling of ``len(test_data) / batch_size``) instead of a float;
      * the loaded model is bound to ``best_model`` rather than shadowing
        this function's own name.

    Args:
        params: Hyper-parameter dict for the networks.
        batch_size: Batch size for training and prediction.
        nb_epoch: Number of training epochs.
        is_train: Train before evaluating when true; otherwise only evaluate
            the model stored at the generator checkpoint path.
    """
    nb_classes = len(functions)

    logging.info("Loading Data")
    train, val, test, train_df, valid_df, test_df = load_data()
    train_data, train_labels = train
    val_data, val_labels = val
    test_data, test_labels = test
    # Only the training labels are rescaled to [-1, 1]; validation and test
    # labels keep their original range.
    train_labels = rescale_labels(train_labels)
    logging.info("Training data size: %d" % len(train_data))
    logging.info("Validation data size: %d" % len(val_data))
    logging.info("Test data size: %d" % len(test_data))
    generator_model_path = (
        DATA_ROOT + 'models/new_model_seq_' + FUNCTION + '.h5')
    discriminator_model_path = (
        DATA_ROOT + 'models/new_model_disc_seq_' + FUNCTION + '.h5')

    logging.info('Starting training the model')
    # NOTE(review): the train/valid generators are fitted but not used below;
    # only `test_generator` feeds `predict_generator`.
    train_generator = DataGenerator(batch_size, nb_classes)
    train_generator.fit(train_data, train_labels)
    valid_generator = DataGenerator(batch_size, nb_classes)
    valid_generator.fit(val_data, val_labels)
    test_generator = DataGenerator(batch_size, nb_classes)
    test_generator.fit(test_data, test_labels)

    if is_train:
        generator_model, discriminator_model = get_model(
            params, nb_classes, batch_size)
        # To start from a pretrained model:
        # generator_model.load_weights(DATA_ROOT + 'models/trained_bce/new_model_seq_' + FUNCTION + '.h5')

        train_wgan(generator_model, discriminator_model,
                   batch_size=batch_size, epochs=nb_epoch,
                   x_train=train_data, y_train=train_labels,
                   x_val=val_data, y_val=val_labels,
                   x_test=test_data, y_test=test_labels,
                   generator_model_path=generator_model_path,
                   discriminator_model_path=discriminator_model_path)

    logging.info('Loading best model')
    best_model = load_model(
        generator_model_path,
        custom_objects={
            'generator_recunstruction_loss_new':
                generator_recunstruction_loss_new,
            'wasserstein_loss': wasserstein_loss})

    logging.info('Predicting')
    steps = int(np.ceil(len(test_data) / float(batch_size)))
    start_time = time.time()
    preds = best_model.predict_generator(test_generator, steps=steps)[0]
    end_time = time.time()
    logging.info('%f prediction time for %f protein sequences' % (end_time - start_time, len(test_data)))

    preds = rescale_back_labels(preds)
    # preds = propagate_score(preds, functions)
    logging.info('Computing performance')
    f, p, r, t, preds_max = compute_performance(preds, test_labels)
    micro_roc_auc = compute_roc(preds, test_labels)
    macro_roc_auc = compute_roc_macro(preds, test_labels)
    mcc = compute_mcc(preds_max, test_labels)
    aupr, _ = compute_aupr(preds, test_labels)
    tpr_score, total = compute_tpr_score(preds_max, functions)

    m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = \
        micro_macro_function_centric_f1(preds.T, test_labels.T)
    function_centric_performance(
        functions, preds.T, test_labels.T,
        rescale_back_labels(train_labels).T, t)

    logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f))
    logging.info('Micro ROC AUC: \t %f ' % (micro_roc_auc, ))
    logging.info('Macro ROC AUC: \t %f ' % (macro_roc_auc,))
    logging.info('MCC: \t %f ' % (mcc, ))
    logging.info('AUPR: \t %f ' % (aupr, ))
    logging.info('TPR_Score from total: \t %f \t %f ' % (tpr_score, total))
    logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max))
    logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max))
-
def function_centric_performance(functions, preds, labels, labels_train, threshold):
    """Write per-function precision/recall/F1 at ``threshold`` to a TSV file.

    Each row of ``preds`` / ``labels`` / ``labels_train`` corresponds to one
    entry of ``functions``. Predictions are rounded to two decimals and
    binarised with ``> threshold``. A function with no true positives scores
    1 when it also has no false positives and no false negatives, else 0.

    The output file ``Con_GodGanSeq_results_<FUNCTION>.txt`` holds, per
    function: the term, the sum of its training labels, its height from
    ``get_height``, F1, precision and recall.
    """
    rounded = np.round(preds, 2)
    rows = []
    for idx in range(rounded.shape[0]):
        predicted = (rounded[idx, :] > threshold).astype(np.int32)
        tp = np.sum(predicted * labels[idx, :])
        fp = np.sum(predicted) - tp
        fn = np.sum(labels[idx, :]) - tp

        if tp > 0:
            precision = tp / (1.0 * (tp + fp))
            recall = tp / (1.0 * (tp + fn))
            f_score = 2 * precision * recall / (precision + recall)
        elif fp == 0 and fn == 0:
            precision = recall = f_score = 1
        else:
            precision = recall = f_score = 0

        rows.append([
            functions[idx],
            np.sum(labels_train[idx, :]),
            get_height(go, functions[idx]),
            f_score,
            precision,
            recall,
        ])

    table = pd.DataFrame(rows)
    table.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt',
                 sep='\t', index=False)
-
-
def micro_macro_function_centric_f1(preds, labels):
    """Compute micro- and macro-averaged precision, recall and F1 over rows.

    Predictions are binarised with ``np.round``. Micro metrics pool the
    counts over all rows; macro precision/recall are the mean of the per-row
    values and macro F1 is their harmonic mean. A row (or the pooled total)
    with no true positives scores 1 when it also has no false positives and
    no false negatives, otherwise 0.

    Returns:
        ``(m_pr, m_rc, m_f1, M_pr, M_rc, M_f1)`` - micro then macro.
    """
    hard = np.round(preds)
    n_rows = preds.shape[0]

    pooled_tp = pooled_fp = pooled_fn = 0
    pr_sum = 0
    rc_sum = 0
    for row in range(n_rows):
        hits = np.sum(hard[row, :] * labels[row, :])
        false_pos = np.sum(hard[row, :]) - hits
        misses = np.sum(labels[row, :]) - hits
        pooled_tp += hits
        pooled_fp += false_pos
        pooled_fn += misses

        if hits > 0:
            row_pr = 1.0 * hits / (1.0 * (hits + false_pos))
            row_rc = 1.0 * hits / (1.0 * (hits + misses))
        elif false_pos == 0 and misses == 0:
            row_pr = row_rc = 1
        else:
            row_pr = row_rc = 0
        pr_sum += row_pr
        rc_sum += row_rc

    if pooled_tp > 0:
        micro_pr = 1.0 * pooled_tp / (1.0 * (pooled_tp + pooled_fp))
        micro_rc = 1.0 * pooled_tp / (1.0 * (pooled_tp + pooled_fn))
        micro_f1 = 2.0 * micro_pr * micro_rc / (micro_pr + micro_rc)
    elif pooled_fp == 0 and pooled_fn == 0:
        micro_pr = micro_rc = micro_f1 = 1
    else:
        micro_pr = micro_rc = micro_f1 = 0

    macro_pr = pr_sum / n_rows
    macro_rc = rc_sum / n_rows
    if macro_pr == 0 and macro_rc == 0:
        macro_f1 = 0
    else:
        macro_f1 = 2.0 * macro_pr * macro_rc / (macro_pr + macro_rc)

    return micro_pr, micro_rc, micro_f1, macro_pr, macro_rc, macro_f1
-
-
def compute_roc(preds, labels):
    """Return the ROC AUC computed over all (sample, class) pairs pooled."""
    flat_truth = labels.flatten()
    flat_scores = preds.flatten()
    false_pos_rate, true_pos_rate, _ = roc_curve(flat_truth, flat_scores)
    return auc(false_pos_rate, true_pos_rate)
-
-
def compute_roc_macro(yhat_raw, y):
    """Return the mean of the per-label ROC AUCs.

    Only labels (columns of ``y``) with at least one positive, a ROC curve
    of more than one point and a non-NaN AUC contribute. Returns ``None``
    when ``yhat_raw`` has at most one row.
    """
    if yhat_raw.shape[0] <= 1:
        return None

    per_label_auc = []
    for col in range(y.shape[1]):
        # Skip labels that never occur in the ground truth.
        if not y[:, col].sum() > 0:
            continue
        fpr, tpr, _ = roc_curve(y[:, col], yhat_raw[:, col])
        if len(fpr) <= 1 or len(tpr) <= 1:
            continue
        score = auc(fpr, tpr)
        if np.isnan(score):
            continue
        per_label_auc.append(score)

    # Macro AUC: plain average of the retained per-label scores.
    return np.mean(per_label_auc)
-
-
-
def compute_aupr(preds, labels):
    """Return the pooled area under the precision-recall curve.

    Returns:
        ``(pr_auc, M_pr_auc)``; the second value is always 0 (a macro
        variant is not computed).
    """
    flat_truth = labels.flatten()
    flat_scores = preds.flatten()
    precision, recall, _ = precision_recall_curve(flat_truth, flat_scores)
    pooled_area = auc(recall, precision)
    macro_area = 0
    return pooled_area, macro_area
-
def compute_mcc(preds, labels):
    """Return the Matthews correlation coefficient over all flattened entries."""
    flat_truth = labels.flatten()
    flat_preds = preds.flatten()
    return matthews_corrcoef(flat_truth, flat_preds)
-
-
def compute_performance(preds, labels):
    """Find the threshold that maximises the protein-centric F-measure.

    Predictions are rounded to two decimals and binarised with ``>`` at
    every threshold 0.01, 0.02, ..., 0.99. For each threshold, precision is
    averaged over the rows with at least one true positive, and recall over
    the rows that have any true positive, false positive or false negative.
    The first threshold reaching the highest F wins.

    Fix relative to the original: ``predictions_max`` is initialised, so the
    function no longer raises ``UnboundLocalError`` when no threshold gives
    a positive F; in that case an all-zero prediction matrix is returned.

    Args:
        preds: 2-D array of scores, one row per protein.
        labels: 2-D binary array of the same shape.

    Returns:
        ``(f_max, p_max, r_max, t_max, predictions_max)``.
    """
    predicts = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    predictions_max = np.zeros(predicts.shape, dtype=np.int32)

    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (predicts > threshold).astype(np.int32)
        total = 0
        p_total = 0
        p = 0.0
        r = 0.0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                # Nothing predicted and nothing annotated: row is ignored.
                continue
            total += 1
            if tp != 0:
                p_total += 1
                p += tp / (1.0 * (tp + fp))
                r += tp / (1.0 * (tp + fn))

        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
            if f_max < f:
                f_max = f
                p_max = p
                r_max = r
                t_max = threshold
                predictions_max = predictions

    return f_max, p_max, r_max, t_max, predictions_max
-
-
-
def get_anchestors(go, go_id):
    """Return ``go_id`` together with all of its ``is_a`` ancestors.

    Performs a breadth-first walk over the ``is_a`` links, ignoring parents
    that are not keys of ``go``. Each term is expanded only once; the
    original re-expanded a shared ancestor once per path leading to it.

    Args:
        go: Mapping from term id to a dict with an ``'is_a'`` list.
        go_id: Term to start from (must be a key of ``go``).

    Returns:
        Set of term ids, including ``go_id`` itself.
    """
    go_set = {go_id}
    q = deque([go_id])
    while q:
        g_id = q.popleft()
        for parent_id in go[g_id]['is_a']:
            if parent_id in go and parent_id not in go_set:
                go_set.add(parent_id)
                q.append(parent_id)
    return go_set
-
def compute_tpr_score(preds_max, functions):
    """Count hierarchy violations in binary predictions.

    A violation is a protein for which a term is predicted (value 1) while
    one of that term's ancestors among ``functions`` is not. Ancestors come
    from ``get_anchestors`` on the module-level ``go``.

    The counting is done with one matrix product instead of the original
    triple Python loop; the counted quantity is unchanged.

    Args:
        preds_max: 2-D array (proteins x functions) of binary predictions.
        functions: Sequence of term ids, one per column of ``preds_max``.

    Returns:
        ``(violations / number_of_proteins, total)`` where ``total`` is the
        number of (term, ancestor) pairs within ``functions`` (each term
        counts as its own ancestor).
    """
    column_of = {term: i for i, term in enumerate(functions)}
    n_terms = len(functions)

    # anc[i, j] == 1 when functions[j] is functions[i] or one of its ancestors.
    anc = np.zeros([n_terms, n_terms])
    total = 0
    for i, term in enumerate(functions):
        for ancestor in get_anchestors(go, term):
            j = column_of.get(ancestor)
            if j is not None:
                anc[i, j] = 1
                total += 1

    preds_max = np.asarray(preds_max)
    predicted = (preds_max == 1).astype(np.float64)
    not_predicted = (preds_max != 1).astype(np.float64)
    # (predicted . anc)[p, k] = number of predicted terms of protein p that
    # have term k as an ancestor; keep only the k that were not predicted.
    tpr_score = float(np.sum(predicted.dot(anc) * not_predicted))

    return tpr_score / preds_max.shape[0], total
-
-
def propagate_score(preds, functions):
    """Raise ancestor scores so that no term outscores its ancestors.

    For every term (in column order) and each of its ancestors among
    ``functions``, the ancestor's score is replaced by the term's score for
    the rows where the term's score is larger. Works on a copy of ``preds``.

    Returns:
        New array with the propagated scores.
    """
    position = {}
    for idx, term in enumerate(functions):
        position[term] = idx

    n_terms = len(functions)
    anc = np.zeros([n_terms, n_terms])
    for idx in range(n_terms):
        for ancestor in get_anchestors(go, functions[idx]):
            if ancestor in position:
                anc[idx, position[ancestor]] = 1

    new_preds = np.array(preds)
    for child in range(new_preds.shape[1]):
        for parent in range(anc.shape[1]):
            if anc[child, parent] != 1:
                continue
            # Rows in which the child currently scores above its ancestor.
            needs_raise = new_preds[:, child] > new_preds[:, parent]
            new_preds[needs_raise, parent] = new_preds[needs_raise, child]

    return new_preds
-
-
# Run the click command only when the file is executed as a script.
if __name__ == "__main__":
    main()
|