#!/usr/bin/env python
"""
Train and evaluate a conditional Wasserstein GAN (with gradient penalty) that
predicts multi-label ontology annotations (mf, bp, cc) from integer-encoded
sequences of length MAXLEN.
"""
from __future__ import division

import logging
import sys
import time

import click as ck
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.layers import (
    Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D, GaussianNoise,
    Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization,
    Activation, Dropout)
from keras.layers import Lambda, LeakyReLU
from keras.models import Sequential, Model, load_model
from keras.callbacks import TensorBoard
from sklearn.metrics import log_loss
from sklearn.metrics import roc_curve, auc, matthews_corrcoef, mean_squared_error
from sklearn.metrics import precision_recall_curve
from skimage import measure

from utils import DataGenerator
from conditional_wgan_wrapper_exp2 import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new

# Allow TensorFlow to grow GPU memory on demand instead of pre-allocating it.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
sys.setrecursionlimit(100000)

DATA_ROOT = 'Data/data2/'
MAXLEN = 258


@ck.command()
@ck.option(
    '--function',
    default='cc',
    help='Ontology id (mf, bp, cc)')
@ck.option(
    '--device',
    default='gpu:0',
    help='GPU or CPU device id')
@ck.option(
    '--train',
    default=True,
    is_flag=True)
def main(function, device, train):
    global FUNCTION
    FUNCTION = function
    with tf.device('/' + device):
        params = {
            'fc_output': 1024,
            'embedding_dims': 128,
            'embedding_dropout': 0.2,
            'nb_conv': 1,
            'nb_dense': 1,
            'filter_length': 128,
            'nb_filter': 32,
            'pool_length': 64,
            'stride': 32
        }
        model(params, is_train=train)


def load_data2():
    all_data_x_fn = 'data2/all_data_X.csv'
    all_data_x = pd.read_csv(all_data_x_fn, sep='\t', header=0, index_col=0)
    all_proteins_train = [p.replace('"', '') for p in all_data_x.index]
    all_data_x.index = all_proteins_train
    all_data_y_fn = 'data2/all_data_Y.csv'
    all_data_y = pd.read_csv(all_data_y_fn, sep='\t', header=0, index_col=0)
    branch = pd.read_csv('data2/' + FUNCTION + '_branches.txt', sep='\t', header=0, index_col=0)
    all_x = all_data_x.values
    branches = [p for p in branch.index.tolist() if p in all_data_y.columns.tolist()]
    branches = list(dict.fromkeys(branches))
    global mt_funcs
    mt_funcs = branches
    t = pd.DataFrame(all_data_y, columns=branches)
    all_y = t.values

    # Random train/test split, then split the test portion into validation/test.
    number_of_test = int(np.ceil(0.4 * len(all_x)))
    np.random.seed(0)
    index = np.random.rand(1, number_of_test)
    index_test = [int(p) for p in np.ceil(index * (len(all_x) - 1))[0]]
    index_train = [p for p in range(len(all_x)) if p not in index_test]
    train_data = all_x[index_train, :]
    test_x = all_x[index_test, :]
    train_labels = all_y[index_train, :]
    test_y = all_y[index_test, :]

    number_of_vals = int(np.ceil(0.5 * len(test_x)))
    np.random.seed(1)
    index = np.random.rand(1, number_of_vals)
    index_vals = [int(p) for p in np.ceil(index * (len(test_x) - 1))[0]]
    index_test = [p for p in range(len(test_x)) if p not in index_vals]
    val_data = test_x[index_vals, :]
    test_data = test_x[index_test, :]
    val_labels = test_y[index_vals, :]
    test_labels = test_y[index_test, :]
    print(train_labels.shape)
    return train_data, train_labels, test_data, test_labels, val_data, val_labels


def get_feature_model(params):
    embedding_dims = params['embedding_dims']
    max_features = 8001
    model = Sequential()
    model.add(Embedding(
        max_features,
        embedding_dims,
        input_length=MAXLEN))
    model.add(SpatialDropout1D(0.4))
    for i in range(params['nb_conv']):
        model.add(Conv1D(
            padding="valid",
            strides=1,
            filters=params['nb_filter'],
            kernel_size=params['filter_length']))
        model.add(LeakyReLU())
        model.add(MaxPooling1D(strides=params['stride'], pool_size=params['pool_length']))
    model.add(Flatten())
    return model


def get_generator(params, n_classes):
    inputs = Input(shape=(MAXLEN,), dtype='float32', name='input1')
    net0 = Dense(200)(inputs)
    # net0 = BatchNormalization()(net0)
    net0 = LeakyReLU()(net0)
    net0 = Dense(200)(net0)
    # net0 = BatchNormalization()(net0)
    net0 = LeakyReLU()(net0)
    # net0 = Dense(200, activation='relu')(inputs)
    net0 = Dense(200)(net0)
    # net0 = BatchNormalization()(net0)
    net = LeakyReLU()(net0)
    output = Dense(n_classes, activation='tanh')(net)
    model = Model(inputs=inputs, outputs=output)
    return model


def get_discriminator(params, n_classes, dropout_rate=0.5):
    inputs = Input(shape=(n_classes,))
    n_inputs = GaussianNoise(0.1)(inputs)

    inputs2 = Input(shape=(MAXLEN,), dtype='int32', name='d_input2')
    x2 = Embedding(8001, 128, input_length=MAXLEN)(inputs2)
    x2 = Conv1D(filters=1, kernel_size=1, padding='valid', strides=1)(x2)
    x2 = LeakyReLU()(x2)
    x2 = Lambda(lambda x: K.squeeze(x, 2))(x2)
    # for i in range(params['nb_conv']):
    #     x2 = Conv1D(activation="relu", padding="valid", strides=1,
    #                 filters=params['nb_filter'], kernel_size=params['filter_length'])(x2)
    # x2 = MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2)
    # x2 = Flatten()(x2)

    size = 40
    x = n_inputs
    x = Dropout(dropout_rate)(x)
    x = Dense(size)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    size = 40
    x2 = Dropout(dropout_rate)(x2)
    x2 = Dense(size)(x2)
    x2 = BatchNormalization()(x2)
    x2 = LeakyReLU()(x2)

    x = Concatenate(axis=1, name='merged2')([x, x2])
    layer_sizes = [30, 30, 30, 30, 30, 30]
    for size in layer_sizes:
        x = Dropout(dropout_rate)(x)
        x = Dense(size)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)

    outputs = Dense(1)(x)
    model = Model(inputs=[inputs, inputs2], outputs=outputs, name='Discriminator')
    return model


class TrainValTensorBoard(TensorBoard):
    def __init__(self, log_dir='./t_logs', **kwargs):
        # Make the original `TensorBoard` log to the training subdirectory
        training_log_dir = log_dir
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)
        # Log validation, discriminator and generator metrics to separate subdirectories
        self.val_log_dir = './v_logs'
        self.dis_log_dir = './d_logs'
        self.gen_log_dir = './g_logs'

    def set_model(self, model):
        # Set up writers for the extra metric streams
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        self.dis_writer = tf.summary.FileWriter(self.dis_log_dir)
        self.gen_writer = tf.summary.FileWriter(self.gen_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with `self.val_writer`.
        # Also rename the keys so that they can be plotted on the same figure
        # with the training metrics.
        logs = logs or {}
        val_logs = {k.replace('val_', 'v_'): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        dis_logs = {k.replace('discriminator_', 'd_'): v for k, v in logs.items() if k.startswith('discriminator_')}
        for name, value in dis_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.dis_writer.add_summary(summary, epoch)
        self.dis_writer.flush()

        gen_logs = {k.replace('generator_', 'g_'): v for k, v in logs.items() if k.startswith('generator_')}
        for name, value in gen_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.gen_writer.add_summary(summary, epoch)
        self.gen_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        t_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        tr_logs = {k: v for k, v in t_logs.items() if not k.startswith('discriminator_')}
        tra_logs = {k: v for k, v in tr_logs.items() if not k.startswith('generator_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, tra_logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
        self.gen_writer.close()
        self.dis_writer.close()


def rescale_labels(labels):
    # Map labels from [0, 1] to [-1, 1] to match the generator's tanh output.
    rescaled_labels = 2 * (labels - 0.5)
    return rescaled_labels


def rescale_back_labels(labels):
    # Map labels from [-1, 1] back to [0, 1].
    rescaled_back_labels = (labels / 2) + 0.5
    return rescaled_back_labels


def get_model(params, nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10):
    generator = get_generator(params, nb_classes)
    discriminator = get_discriminator(params, nb_classes)
    generator_model, discriminator_model = \
        WGAN_wrapper(generator=generator,
                     discriminator=discriminator,
                     generator_input_shape=(MAXLEN,),
                     discriminator_input_shape=(nb_classes,),
                     discriminator_input_shape2=(MAXLEN,),
                     batch_size=batch_size,
                     gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    logging.info('Compilation finished')
    return generator_model, discriminator_model


def train_wgan(generator_model, discriminator_model, batch_size, epochs,
               x_train, y_train, x_val, y_val,
               generator_model_path, discriminator_model_path,
               TRAINING_RATIO=10, N_WARM_UP=0):
    BATCH_SIZE = batch_size
    N_EPOCH = epochs
    positive_y = np.ones((batch_size, 1), dtype=np.float32)
    zero_y = positive_y * 0
    negative_y = -positive_y
    positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32)
    dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    positive_full_enable_train = np.ones((len(x_train), 1), dtype=np.float32)
    positive_full_enable_val = np.ones((len(x_val), 1), dtype=np.float32)
    # positive_enable_train = np.ones((1, batch_size), dtype=np.float32)
    # positive_full_train_enable = np.ones((1, BATCH_SIZE * TRAINING_RATIO), dtype=np.float32)

    best_validation_loss = None
    callback = TrainValTensorBoard(write_graph=False)  # TensorBoard(log_path)
    callback.set_model(generator_model)

    for epoch in range(N_EPOCH):
        # np.random.shuffle(X_train)
        print("Epoch: ", epoch)
        print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE))
        discriminator_loss = []
        generator_loss = []
        minibatches_size = BATCH_SIZE * TRAINING_RATIO
        shuffled_indexes = np.random.permutation(x_train.shape[0])
        shuffled_indexes_2 = np.random.permutation(x_train.shape[0])
        for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))):
            batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size]
            batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size]
            x = x_train[batch_indexes]
            y = y_train[batch_indexes]
            y_2 = y_train[batch_indexes_2]
            x_2 = x_train[batch_indexes_2]
            if epoch < N_WARM_UP:
                # Warm-up phase: update only the generator.
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    logg = generator_model.train_on_batch([x_batch, positive_y], [y_batch, zero_y])
                    generator_loss.append(logg)
                    discriminator_loss.append(0)
            else:
                # Adversarial phase: TRAINING_RATIO discriminator updates per generator update.
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    # noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32)
                    noise = x_batch
                    discriminator_loss.append(discriminator_model.train_on_batch(
                        [y_batch_2, noise, x_batch_2],
                        [positive_y, negative_y, dummy_y]))
                logg = generator_model.train_on_batch([x, positive_full_y], [y, positive_full_y])
                generator_loss.append(logg)

        predicted_y_train, _ = generator_model.predict(
            [x_train, positive_full_enable_train], batch_size=BATCH_SIZE)
        predicted_y_val, _ = generator_model.predict(
            [x_val, positive_full_enable_val], batch_size=BATCH_SIZE)
        train_loss = log_loss(rescale_back_labels(y_train), rescale_back_labels(predicted_y_train))
        val_loss = log_loss(y_val, rescale_back_labels(predicted_y_val))
        callback.on_epoch_end(epoch, logs={
            'discriminator_loss': (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1),
            'generator_loss': (np.sum(np.asarray(generator_loss)) if generator_loss else -1),
            'train_loss': train_loss,
            'val_loss': val_loss})
        print("train loss: {:.4f}, validation loss: {:.4f}, validation auc: {:.4f}, "
              "generator loss: {:.4f}, discriminator loss: {:.4f}".format(
                  train_loss, val_loss,
                  compute_roc(rescale_back_labels(predicted_y_val), y_val),
                  (np.sum(np.asarray(generator_loss)) if generator_loss else -1) / x_train.shape[0],
                  (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0]))

        if best_validation_loss is None or best_validation_loss > val_loss:
            print('\nEpoch %05d: improved from %0.5f,'
                  ' saving model to %s and %s'
                  % (epoch + 1, val_loss, generator_model_path, discriminator_model_path))
            best_validation_loss = val_loss
            generator_model.save(generator_model_path, overwrite=True)
            discriminator_model.save(discriminator_model_path, overwrite=True)


def model(params, batch_size=128, nb_epoch=50, is_train=True):
    start_time = time.time()
    logging.info("Loading Data")
    train_data, train_labels, test_data, test_labels, val_data, val_labels = load_data2()
    train_labels = rescale_labels(train_labels)
    nb_classes = train_labels.shape[1]
    logging.info("Data loaded in %d sec" % (time.time() - start_time))
    logging.info("Training data size: %d" % len(train_data))
    logging.info("Validation data size: %d" % len(val_data))
    logging.info("Test data size: %d" % len(test_data))
    generator_model_path = DATA_ROOT + 'models/new_model_seq_mt_10000_1_loaded_pretrained_new_' + FUNCTION + '.h5'  # _loaded_pretrained_new
    discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_mt_' + FUNCTION + '.h5'
    logging.info('Starting training the model')

    train_generator = DataGenerator(batch_size, nb_classes)
    train_generator.fit(train_data, train_labels)
    valid_generator = DataGenerator(batch_size, nb_classes)
    valid_generator.fit(val_data, val_labels)
    test_generator = DataGenerator(batch_size, nb_classes)
    test_generator.fit(test_data, test_labels)

    if is_train:
        generator_model, discriminator_model = get_model(params, nb_classes, batch_size)
        generator_model.load_weights(DATA_ROOT + 'models/new_model_seq_mt_1_0_' + FUNCTION + '.h5')
        train_wgan(generator_model, discriminator_model,
                   batch_size=batch_size, epochs=nb_epoch,
                   x_train=train_data, y_train=train_labels,
                   x_val=val_data, y_val=val_labels,
                   generator_model_path=generator_model_path,
                   discriminator_model_path=discriminator_model_path)

    logging.info('Loading best model')
    model = load_model(
        generator_model_path,
        custom_objects={'generator_recunstruction_loss_new': generator_recunstruction_loss_new,
                        'wasserstein_loss': wasserstein_loss})

    logging.info('Predicting')
    preds = model.predict_generator(test_generator, steps=len(test_data) / batch_size)[0]
    preds = rescale_back_labels(preds)

    logging.info('Computing performance')
    f, p, r, t, preds_max = compute_performance(preds, test_labels)
    roc_auc = compute_roc(preds, test_labels)
    mcc = compute_mcc(preds_max, test_labels)
    aupr, _ = compute_aupr(preds, test_labels)
    m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = micro_macro_function_centric_f1(preds.T, test_labels.T)
    tpr, total_rule = tpr_score(preds_max)
    logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f))
    logging.info('ROC AUC: \t %f ' % (roc_auc, ))
    logging.info('MCC: \t %f ' % (mcc, ))
    logging.info('AUPR: \t %f ' % (aupr, ))
    logging.info('TPR from total rule: \t %f \t %f' % (tpr, total_rule))
    logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max))
    logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max))
    function_centric_performance(preds.T, test_labels.T, rescale_back_labels(train_labels).T, t)


def tpr_score(preds):
    hierarchy_f = open("data2/mtdnn_data_svmlight/" + FUNCTION + "_hierarchy.txt", "r")
    hiers = hierarchy_f.readlines()
    hiers.pop(0)
    tree = []
    for i in range(len(hiers)):
        sample = []
        str_sample = hiers[i].split()
        for j in range(1, len(str_sample)):
            sample.append(int(str_sample[j]))
        tree.append(sample)
    my_dict = {}
    for i in range(preds.shape[0]):
        my_dict[str(i)] = []
    total_rule = 0
    for i in range(len(tree)):
        for j in range(len(tree[i])):
            my_dict[str(tree[i][j] - 1)].append(i)
            total_rule += 1
    tpr = 0
    for i in range(preds.shape[0]):
        for j in range(preds.shape[1]):
            if preds[i, j] == 1:
                parents = my_dict[str(j)]
                for k in range(len(parents)):
                    if preds[i, parents[k]] == 0:
                        tpr = tpr + 1
    return tpr / preds.shape[0], total_rule


def function_centric_performance(preds, labels, labels_train, threshold):
    results = []
    av_f = 0
    preds = np.round(preds, 2)
    for i in range(preds.shape[0]):
        predictions = (preds[i, :] > threshold).astype(np.int32)
        tp = np.sum(predictions * labels[i, :])
        fp = np.sum(predictions) - tp
        fn = np.sum(labels[i, :]) - tp
        if tp > 0:
            precision = tp / (1.0 * (tp + fp))
            recall = tp / (1.0 * (tp + fn))
            f = 2 * precision * recall / (precision + recall)
        else:
            if fp == 0 and fn == 0:
                precision = 1
                recall = 1
                f = 1
            else:
                precision = 0
                recall = 0
                f = 0
        f_max = f
        av_f += f
        p_max = precision
        r_max = recall
        num_prots_train = np.sum(labels_train[i, :])
        results.append([mt_funcs[i], num_prots_train, f_max, p_max, r_max])
    print(av_f / preds.shape[0])
    results = pd.DataFrame(results)
    results.to_csv('MT_1_0_results' + FUNCTION + '.txt', sep='\t', index=False)

    predictions = (preds > threshold).astype(np.int32)
    hitmap_test = (np.dot(predictions, predictions.T) > 0).astype(np.int32)
    hitmap_train = (np.dot(labels_train, labels_train.T) > 0).astype(np.int32)
    logging.info('Heatmap scores:\t MSE:%f \t SSIM %f ' % (
        mean_squared_error(hitmap_test, hitmap_train),
        measure.compare_ssim(hitmap_test, hitmap_train)))
    # hitmap_test = pd.DataFrame(hitmap_test)
    # hitmap_test.to_csv('hitmap_1_1_' + FUNCTION + '.txt', sep='\t', index=False)
    # hitmap_train = pd.DataFrame(hitmap_train)
    # hitmap_train.to_csv('hitmap_train_' + FUNCTION + '.txt', sep='\t', index=False)
    # co_occure = np.zeros(shape=[predictions.shape[0], predictions.shape[0]])
    # mutual_ex = np.zeros(shape=[predictions.shape[0], predictions.shape[0]])
    # check_co_occur_test = np.zeros(shape=[predictions.shape[0], predictions.shape[0]])
    # for i in range(predictions.shape[0]):
    #     for j in range(predictions.shape[0]):
    #         temp = np.logical_xor(labels_train[i, :], labels_train[j, :])
    #         if np.sum(temp) == 0:
    #             co_occure[i, j] = 1
    #         temp = np.dot(labels_train[i, :], labels_train[j, :])
    #         if np.sum(temp) == 0:
    #             mutual_ex[i, j] = 1
    #         check_co_occur_test[i, j] = (sum(np.logical_xor(labels_train[i, :], labels_train[j, :])) > 0).astype(np.int32)
    # logging.info('Mutual_exclucivity_distorsion:\t %f ' % (sum(sum(np.multiply(mutual_ex, hitmap_test)))))
    # logging.info('Co_occurance_distorsion:\t %f ' % (sum(sum(np.multiply(co_occure, check_co_occur_test)))))


def micro_macro_function_centric_f1(preds, labels):
    predictions = np.round(preds)
    m_tp = 0
    m_fp = 0
    m_fn = 0
    M_pr = 0
    M_rc = 0
    for i in range(preds.shape[0]):
        tp = np.sum(predictions[i, :] * labels[i, :])
        fp = np.sum(predictions[i, :]) - tp
        fn = np.sum(labels[i, :]) - tp
        m_tp += tp
        m_fp += fp
        m_fn += fn
        if tp > 0:
            pr = 1.0 * tp / (1.0 * (tp + fp))
            rc = 1.0 * tp / (1.0 * (tp + fn))
        else:
            if fp == 0 and fn == 0:
                pr = 1
                rc = 1
            else:
                pr = 0
                rc = 0
        M_pr += pr
        M_rc += rc
    if m_tp > 0:
        m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp))
        m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn))
        m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc)
    else:
        if m_fp == 0 and m_fn == 0:
            m_pr = 1
            m_rc = 1
            m_f1 = 1
        else:
            m_pr = 0
            m_rc = 0
            m_f1 = 0
    M_pr /= preds.shape[0]
    M_rc /= preds.shape[0]
    if M_pr == 0 and M_rc == 0:
        M_f1 = 0
    else:
        M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc)
    return m_pr, m_rc, m_f1, M_pr, M_rc, M_f1


def compute_roc(preds, labels):
    # Micro-averaged ROC AUC over the flattened prediction/label matrices
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc


def compute_aupr(preds, labels):
    # Micro-averaged area under the precision-recall curve
    pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten())
    pr_auc = auc(rc, pr)
    # pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten(), average='macro')
    M_pr_auc = 0
    return pr_auc, M_pr_auc


def compute_mcc(preds, labels):
    # Matthews correlation coefficient over the flattened binary predictions
    mcc = matthews_corrcoef(labels.flatten(), preds.flatten())
    return mcc


def compute_performance(preds, labels):
    preds = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        total = 0
        f = 0.0
        p = 0.0
        r = 0.0
        p_total = 0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                p += precision
                r += recall
        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
            if f_max < f:
                f_max = f
                p_max = p
                r_max = r
                t_max = threshold
                predictions_max = predictions
    return f_max, p_max, r_max, t_max, predictions_max


if __name__ == '__main__':
    main()
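# ---------------------------------------------------------------------------
# Example invocation (a sketch, not part of the original pipeline). The script
# name below is a placeholder; the command assumes the Data/data2/ directory
# layout, utils.DataGenerator, and conditional_wgan_wrapper_exp2 are importable,
# and that the pretrained generator weights referenced in model() exist:
#
#   python train_wgan_exp2.py --function cc --device gpu:0
#
# --function selects the ontology branch file data2/<function>_branches.txt
# (mf, bp, or cc); --device is the TensorFlow device string ('gpu:0' or
# 'cpu:0'); training is enabled by default via the --train flag.
# ---------------------------------------------------------------------------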