
pfp_conditional_wgan_exp2.py

#!/usr/bin/env python
"""Train and evaluate a conditional WGAN for multi-task protein function prediction (experiment 2)."""
from __future__ import division
import logging
import sys
import time
import click as ck
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.layers import (
    Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D, GaussianNoise,
    Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization, Activation, Dropout)
from keras.models import Sequential, Model, load_model
from sklearn.metrics import log_loss
from sklearn.metrics import roc_curve, auc, matthews_corrcoef, mean_squared_error
from keras.layers import Lambda, LeakyReLU
from sklearn.metrics import precision_recall_curve
from keras.callbacks import TensorBoard
from skimage import measure
from utils import DataGenerator
from conditional_wgan_wrapper_exp2 import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
sys.setrecursionlimit(100000)

DATA_ROOT = 'Data/data2/'
MAXLEN = 258

@ck.command()
@ck.option(
    '--function',
    default='cc',
    help='Ontology id (mf, bp, cc)')
@ck.option(
    '--device',
    default='gpu:0',
    help='GPU or CPU device id')
@ck.option('--train', default=True, is_flag=True)
def main(function, device, train):
    global FUNCTION
    FUNCTION = function
    with tf.device('/' + device):
        params = {
            'fc_output': 1024,
            'embedding_dims': 128,
            'embedding_dropout': 0.2,
            'nb_conv': 1,
            'nb_dense': 1,
            'filter_length': 128,
            'nb_filter': 32,
            'pool_length': 64,
            'stride': 32
        }
        model(params, is_train=train)

def load_data2():
    # Load the feature and label matrices, keep only the GO terms of the selected
    # branch, and split into roughly 60/20/20 train/validation/test with fixed seeds.
    all_data_x_fn = 'data2/all_data_X.csv'
    all_data_x = pd.read_csv(all_data_x_fn, sep='\t', header=0, index_col=0)
    all_proteins_train = [p.replace('"', '') for p in all_data_x.index]
    all_data_x.index = all_proteins_train
    all_data_y_fn = 'data2/all_data_Y.csv'
    all_data_y = pd.read_csv(all_data_y_fn, sep='\t', header=0, index_col=0)
    branch = pd.read_csv('data2/' + FUNCTION + '_branches.txt', sep='\t', header=0, index_col=0)
    all_x = all_data_x.values
    branches = [p for p in branch.index.tolist() if p in all_data_y.columns.tolist()]
    branches = list(dict.fromkeys(branches))
    global mt_funcs
    mt_funcs = branches
    t = pd.DataFrame(all_data_y, columns=branches)
    all_y = t.values
    number_of_test = int(np.ceil(0.4 * len(all_x)))
    np.random.seed(0)
    index = np.random.rand(1, number_of_test)
    index_test = [int(p) for p in np.ceil(index * (len(all_x) - 1))[0]]
    index_train = [p for p in range(len(all_x)) if p not in index_test]
    train_data = all_x[index_train, :]
    test_x = all_x[index_test, :]
    train_labels = all_y[index_train, :]
    test_y = all_y[index_test, :]
    number_of_vals = int(np.ceil(0.5 * len(test_x)))
    np.random.seed(1)
    index = np.random.rand(1, number_of_vals)
    index_vals = [int(p) for p in np.ceil(index * (len(test_x) - 1))[0]]
    index_test = [p for p in range(len(test_x)) if p not in index_vals]
    val_data = test_x[index_vals, :]
    test_data = test_x[index_test, :]
    val_labels = test_y[index_vals, :]
    test_labels = test_y[index_test, :]
    print(train_labels.shape)
    return train_data, train_labels, test_data, test_labels, val_data, val_labels

def get_feature_model(params):
    # Convolutional sequence feature extractor (defined for reference; not used
    # elsewhere in this script)
    embedding_dims = params['embedding_dims']
    max_features = 8001
    model = Sequential()
    model.add(Embedding(
        max_features,
        embedding_dims,
        input_length=MAXLEN))
    model.add(SpatialDropout1D(0.4))
    for i in range(params['nb_conv']):
        model.add(Conv1D(
            padding="valid",
            strides=1,
            filters=params['nb_filter'],
            kernel_size=params['filter_length']))
        model.add(LeakyReLU())
    model.add(MaxPooling1D(strides=params['stride'], pool_size=params['pool_length']))
    model.add(Flatten())
    return model

def get_generator(params, n_classes):
    inputs = Input(shape=(MAXLEN,), dtype='float32', name='input1')
    net0 = Dense(200)(inputs)
    #net0 = BatchNormalization()(net0)
    net0 = LeakyReLU()(net0)
    net0 = Dense(200)(net0)
    #net0 = BatchNormalization()(net0)
    net0 = LeakyReLU()(net0)
    #net0 = Dense(200, activation='relu')(inputs)
    net0 = Dense(200)(net0)
    #net0 = BatchNormalization()(net0)
    net = LeakyReLU()(net0)
    # tanh output matches the labels rescaled to [-1, 1] by rescale_labels()
    output = Dense(n_classes, activation='tanh')(net)
    model = Model(inputs=inputs, outputs=output)
    return model

def get_discriminator(params, n_classes, dropout_rate=0.5):
    inputs = Input(shape=(n_classes,))
    n_inputs = GaussianNoise(0.1)(inputs)
    inputs2 = Input(shape=(MAXLEN,), dtype='int32', name='d_input2')
    x2 = Embedding(8001, 128, input_length=MAXLEN)(inputs2)
    x2 = Conv1D(filters=1, kernel_size=1, padding='valid', strides=1)(x2)
    x2 = LeakyReLU()(x2)
    x2 = Lambda(lambda x: K.squeeze(x, 2))(x2)
    #for i in range(params['nb_conv']):
    #    x2 = Conv1D(activation="relu", padding="valid", strides=1,
    #                filters=params['nb_filter'], kernel_size=params['filter_length'])(x2)
    #x2 = MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2)
    #x2 = Flatten()(x2)
    size = 40
    x = n_inputs
    x = Dropout(dropout_rate)(x)
    x = Dense(size)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    size = 40
    x2 = Dropout(dropout_rate)(x2)
    x2 = Dense(size)(x2)
    x2 = BatchNormalization()(x2)
    x2 = LeakyReLU()(x2)
    x = Concatenate(axis=1, name='merged2')([x, x2])
    layer_sizes = [30, 30, 30, 30, 30, 30]
    for size in layer_sizes:
        x = Dropout(dropout_rate)(x)
        x = Dense(size)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)
    outputs = Dense(1)(x)
    model = Model(inputs=[inputs, inputs2], outputs=outputs, name='Discriminator')
    return model

class TrainValTensorBoard(TensorBoard):
    def __init__(self, log_dir='./t_logs', **kwargs):
        # Let the original `TensorBoard` callback log the training metrics to `log_dir`
        training_log_dir = log_dir
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)
        # Log validation, discriminator and generator metrics to separate directories
        self.val_log_dir = './v_logs'
        self.dis_log_dir = './d_logs'
        self.gen_log_dir = './g_logs'

    def set_model(self, model):
        # Set up writers for the validation, discriminator and generator metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        self.dis_writer = tf.summary.FileWriter(self.dis_log_dir)
        self.gen_writer = tf.summary.FileWriter(self.gen_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pull the validation/discriminator/generator entries out of `logs`, write each
        # group with its own writer, and rename the keys so that the curves can be
        # plotted on the same figure as the training metrics.
        logs = logs or {}
        val_logs = {k.replace('val_', 'v_'): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()
        dis_logs = {k.replace('discriminator_', 'd_'): v for k, v in logs.items() if k.startswith('discriminator_')}
        for name, value in dis_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.dis_writer.add_summary(summary, epoch)
        self.dis_writer.flush()
        gen_logs = {k.replace('generator_', 'g_'): v for k, v in logs.items() if k.startswith('generator_')}
        for name, value in gen_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.gen_writer.add_summary(summary, epoch)
        self.gen_writer.flush()
        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        t_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        tr_logs = {k: v for k, v in t_logs.items() if not k.startswith('discriminator_')}
        tra_logs = {k: v for k, v in tr_logs.items() if not k.startswith('generator_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, tra_logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
        self.gen_writer.close()
        self.dis_writer.close()

def rescale_labels(labels):
    # Map labels from [0, 1] to [-1, 1] to match the generator's tanh output
    rescaled_labels = 2 * (labels - 0.5)
    return rescaled_labels


def rescale_back_labels(labels):
    # Map rescaled labels from [-1, 1] back to [0, 1]
    rescaled_back_labels = (labels / 2) + 0.5
    return rescaled_back_labels

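# Illustrative sanity check (not part of the pipeline): labels in {0, 1} map to
# {-1, 1} and back, matching the generator's tanh output range, e.g.
#   rescale_back_labels(rescale_labels(np.array([0.0, 1.0])))  ->  array([0., 1.])
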
def get_model(params, nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10):
    generator = get_generator(params, nb_classes)
    discriminator = get_discriminator(params, nb_classes)
    generator_model, discriminator_model = \
        WGAN_wrapper(generator=generator,
                     discriminator=discriminator,
                     generator_input_shape=(MAXLEN,),
                     discriminator_input_shape=(nb_classes,),
                     discriminator_input_shape2=(MAXLEN,),
                     batch_size=batch_size,
                     gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    logging.info('Compilation finished')
    return generator_model, discriminator_model

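# Note on the imported WGAN_wrapper (an assumption inferred from how the returned
# models are used in train_wgan() below, not a documented contract of that module):
#   - the generator model appears to take [sequence_features, enable_flag] and
#     produce [predicted_labels, critic_score], trained with a reconstruction loss
#     (generator_recunstruction_loss_new) plus the Wasserstein loss;
#   - the discriminator (critic) model appears to take [real_labels, generator_input,
#     conditioning_sequence] and is trained against [real, fake, gradient-penalty]
#     targets, as in the usual improved WGAN-GP setup.
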
def train_wgan(generator_model, discriminator_model, batch_size, epochs,
               x_train, y_train, x_val, y_val, generator_model_path, discriminator_model_path,
               TRAINING_RATIO=10, N_WARM_UP=0):
    BATCH_SIZE = batch_size
    N_EPOCH = epochs
    # Target vectors for the critic outputs (presumably +1 for real, -1 for fake, and a
    # dummy target for the gradient-penalty term; see the WGAN_wrapper note above).
    positive_y = np.ones((batch_size, 1), dtype=np.float32)
    zero_y = positive_y * 0
    negative_y = -positive_y
    positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32)
    dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    positive_full_enable_train = np.ones((len(x_train), 1), dtype=np.float32)
    positive_full_enable_val = np.ones((len(x_val), 1), dtype=np.float32)
    #positive_enable_train = np.ones((1, batch_size), dtype=np.float32)
    #positive_full_train_enable = np.ones((1, BATCH_SIZE * TRAINING_RATIO), dtype=np.float32)
    best_validation_loss = None
    callback = TrainValTensorBoard(write_graph=False)  #TensorBoard(log_path)
    callback.set_model(generator_model)
    for epoch in range(N_EPOCH):
        # np.random.shuffle(X_train)
        print("Epoch: ", epoch)
        print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE))
        discriminator_loss = []
        generator_loss = []
        minibatches_size = BATCH_SIZE * TRAINING_RATIO
        shuffled_indexes = np.random.permutation(x_train.shape[0])
        shuffled_indexes_2 = np.random.permutation(x_train.shape[0])
        for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))):
            batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size]
            batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size]
            x = x_train[batch_indexes]
            y = y_train[batch_indexes]
            y_2 = y_train[batch_indexes_2]
            x_2 = x_train[batch_indexes_2]
            if epoch < N_WARM_UP:
                # Warm-up phase: train the generator alone, no discriminator updates
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    logg = generator_model.train_on_batch([x_batch, positive_y], [y_batch, zero_y])
                    generator_loss.append(logg)
                    discriminator_loss.append(0)
            else:
                # Adversarial phase: TRAINING_RATIO critic updates per generator update
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    # noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32)
                    noise = x_batch
                    discriminator_loss.append(discriminator_model.train_on_batch(
                        [y_batch_2, noise, x_batch_2],
                        [positive_y, negative_y, dummy_y]))
                logg = generator_model.train_on_batch([x, positive_full_y], [y, positive_full_y])
                generator_loss.append(logg)
        predicted_y_train, _ = generator_model.predict([x_train, positive_full_enable_train], batch_size=BATCH_SIZE)
        predicted_y_val, _ = generator_model.predict([x_val, positive_full_enable_val], batch_size=BATCH_SIZE)
        train_loss = log_loss(rescale_back_labels(y_train), rescale_back_labels(predicted_y_train))
        val_loss = log_loss(y_val, rescale_back_labels(predicted_y_val))
        callback.on_epoch_end(epoch, logs={
            'discriminator_loss': (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1),
            'generator_loss': (np.sum(np.asarray(generator_loss)) if generator_loss else -1),
            'train_loss': train_loss,
            'val_loss': val_loss})
        print("train loss: {:.4f}, validation loss: {:.4f}, validation auc: {:.4f}, "
              "generator loss: {:.4f}, discriminator loss: {:.4f}".format(
                  train_loss, val_loss, compute_roc(rescale_back_labels(predicted_y_val), y_val),
                  (np.sum(np.asarray(generator_loss)) if generator_loss else -1) / x_train.shape[0],
                  (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0]))
        if best_validation_loss is None or best_validation_loss > val_loss:
            print('\nEpoch %05d: improved from %0.5f,'
                  ' saving model to %s and %s'
                  % (epoch + 1, val_loss, generator_model_path, discriminator_model_path))
            best_validation_loss = val_loss
            generator_model.save(generator_model_path, overwrite=True)
            discriminator_model.save(discriminator_model_path, overwrite=True)

def model(params, batch_size=128, nb_epoch=50, is_train=True):
    start_time = time.time()
    logging.info("Loading Data")
    train_data, train_labels, test_data, test_labels, val_data, val_labels = load_data2()
    train_labels = rescale_labels(train_labels)
    nb_classes = train_labels.shape[1]
    logging.info("Data loaded in %d sec" % (time.time() - start_time))
    logging.info("Training data size: %d" % len(train_data))
    logging.info("Validation data size: %d" % len(val_data))
    logging.info("Test data size: %d" % len(test_data))
    generator_model_path = DATA_ROOT + 'models/new_model_seq_mt_10000_1_loaded_pretrained_new_' + FUNCTION + '.h5'  #_loaded_pretrained_new
    discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_mt_' + FUNCTION + '.h5'
    logging.info('Starting training the model')
    train_generator = DataGenerator(batch_size, nb_classes)
    train_generator.fit(train_data, train_labels)
    valid_generator = DataGenerator(batch_size, nb_classes)
    valid_generator.fit(val_data, val_labels)
    test_generator = DataGenerator(batch_size, nb_classes)
    test_generator.fit(test_data, test_labels)
    if is_train:
        generator_model, discriminator_model = get_model(params, nb_classes, batch_size)
        generator_model.load_weights(DATA_ROOT + 'models/new_model_seq_mt_1_0_' + FUNCTION + '.h5')
        train_wgan(generator_model, discriminator_model, batch_size=batch_size, epochs=nb_epoch,
                   x_train=train_data, y_train=train_labels, x_val=val_data, y_val=val_labels,
                   generator_model_path=generator_model_path,
                   discriminator_model_path=discriminator_model_path)
    logging.info('Loading best model')
    model = load_model(generator_model_path,
                       custom_objects={'generator_recunstruction_loss_new': generator_recunstruction_loss_new,
                                       'wasserstein_loss': wasserstein_loss})
    logging.info('Predicting')
    preds = model.predict_generator(test_generator, steps=len(test_data) / batch_size)[0]
    preds = rescale_back_labels(preds)
    logging.info('Computing performance')
    f, p, r, t, preds_max = compute_performance(preds, test_labels)
    roc_auc = compute_roc(preds, test_labels)
    mcc = compute_mcc(preds_max, test_labels)
    aupr, _ = compute_aupr(preds, test_labels)
    m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = micro_macro_function_centric_f1(preds.T, test_labels.T)
    tpr, total_rule = tpr_score(preds_max)
    logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f))
    logging.info('ROC AUC: \t %f ' % (roc_auc, ))
    logging.info('MCC: \t %f ' % (mcc, ))
    logging.info('AUPR: \t %f ' % (aupr, ))
    logging.info('TPR from total rule: \t %f \t %f' % (tpr, total_rule))
    logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max))
    logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max))
    function_centric_performance(preds.T, test_labels.T, rescale_back_labels(train_labels).T, t)

def tpr_score(preds):
    # True-path-rule check: for every predicted term, verify that its parent terms
    # (from the ontology hierarchy file) are also predicted. Returns the average
    # number of violations per protein and the total number of hierarchy edges.
    with open("data2/mtdnn_data_svmlight/" + FUNCTION + "_hierarchy.txt", "r") as hierarchy_f:
        hiers = hierarchy_f.readlines()
    hiers.pop(0)
    tree = []
    for i in range(len(hiers)):
        sample = []
        str_sample = hiers[i].split()
        for j in range(1, len(str_sample)):
            sample.append(int(str_sample[j]))
        tree.append(sample)
    # Map each term index to the hierarchy lines (parent terms) it appears in
    my_dict = {}
    total_rule = 0
    for i in range(len(tree)):
        for j in range(len(tree[i])):
            my_dict.setdefault(str(tree[i][j] - 1), []).append(i)
            total_rule += 1
    tpr = 0
    for i in range(preds.shape[0]):
        for j in range(preds.shape[1]):
            if preds[i, j] == 1:
                parents = my_dict.get(str(j), [])
                for k in range(len(parents)):
                    if preds[i, parents[k]] == 0:
                        tpr = tpr + 1
    return tpr / preds.shape[0], total_rule

def function_centric_performance(preds, labels, labels_train, threshold):
    results = []
    av_f = 0
    preds = np.round(preds, 2)
    for i in range(preds.shape[0]):
        predictions = (preds[i, :] > threshold).astype(np.int32)
        tp = np.sum(predictions * labels[i, :])
        fp = np.sum(predictions) - tp
        fn = np.sum(labels[i, :]) - tp
        if tp > 0:
            precision = tp / (1.0 * (tp + fp))
            recall = tp / (1.0 * (tp + fn))
            f = 2 * precision * recall / (precision + recall)
        else:
            if fp == 0 and fn == 0:
                precision = 1
                recall = 1
                f = 1
            else:
                precision = 0
                recall = 0
                f = 0
        f_max = f
        av_f += f
        p_max = precision
        r_max = recall
        num_prots_train = np.sum(labels_train[i, :])
        results.append([mt_funcs[i], num_prots_train, f_max, p_max, r_max])
    print(av_f / preds.shape[0])
    results = pd.DataFrame(results)
    results.to_csv('MT_1_0_results' + FUNCTION + '.txt', sep='\t', index=False)
    predictions = (preds > threshold).astype(np.int32)
    hitmap_test = (np.dot(predictions, predictions.T) > 0).astype(np.int32)
    hitmap_train = (np.dot(labels_train, labels_train.T) > 0).astype(np.int32)
    logging.info('Heatmap scores:\t MSE:%f \t SSIM %f ' % (
        mean_squared_error(hitmap_test, hitmap_train),
        measure.compare_ssim(hitmap_test, hitmap_train)))
    #hitmap_test = pd.DataFrame(hitmap_test)
    #hitmap_test.to_csv('hitmap_1_1_' + FUNCTION + '.txt', sep='\t', index=False)
    #hitmap_train = pd.DataFrame(hitmap_train)
    #hitmap_train.to_csv('hitmap_train_' + FUNCTION + '.txt', sep='\t', index=False)
    #co_occure = np.zeros(shape=[predictions.shape[0], predictions.shape[0]])
    #mutual_ex = np.zeros(shape=[predictions.shape[0], predictions.shape[0]])
    #check_co_occur_test = np.zeros(shape=[predictions.shape[0], predictions.shape[0]])
    #for i in range(predictions.shape[0]):
    #    for j in range(predictions.shape[0]):
    #        temp = np.logical_xor(labels_train[i, :], labels_train[j, :])
    #        if np.sum(temp) == 0:
    #            co_occure[i, j] = 1
    #        temp = np.dot(labels_train[i, :], labels_train[j, :])
    #        if np.sum(temp) == 0:
    #            mutual_ex[i, j] = 1
    #        check_co_occur_test[i, j] = (sum(np.logical_xor(labels_train[i, :], labels_train[j, :])) > 0).astype(np.int32)
    #logging.info('Mutual_exclucivity_distorsion:\t %f ' % (sum(sum(np.multiply(mutual_ex, hitmap_test)))))
    #logging.info('Co_occurance_distorsion:\t %f ' % (sum(sum(np.multiply(co_occure, check_co_occur_test)))))

def micro_macro_function_centric_f1(preds, labels):
    predictions = np.round(preds)
    m_tp = 0
    m_fp = 0
    m_fn = 0
    M_pr = 0
    M_rc = 0
    for i in range(preds.shape[0]):
        tp = np.sum(predictions[i, :] * labels[i, :])
        fp = np.sum(predictions[i, :]) - tp
        fn = np.sum(labels[i, :]) - tp
        m_tp += tp
        m_fp += fp
        m_fn += fn
        if tp > 0:
            pr = 1.0 * tp / (1.0 * (tp + fp))
            rc = 1.0 * tp / (1.0 * (tp + fn))
        else:
            if fp == 0 and fn == 0:
                pr = 1
                rc = 1
            else:
                pr = 0
                rc = 0
        M_pr += pr
        M_rc += rc
    if m_tp > 0:
        m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp))
        m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn))
        m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc)
    else:
        if m_fp == 0 and m_fn == 0:
            m_pr = 1
            m_rc = 1
            m_f1 = 1
        else:
            m_pr = 0
            m_rc = 0
            m_f1 = 0
    M_pr /= preds.shape[0]
    M_rc /= preds.shape[0]
    if M_pr == 0 and M_rc == 0:
        M_f1 = 0
    else:
        M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc)
    return m_pr, m_rc, m_f1, M_pr, M_rc, M_f1

def compute_roc(preds, labels):
    # Micro-averaged ROC curve and area over all flattened predictions
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc


def compute_aupr(preds, labels):
    # Micro-averaged precision-recall curve and area over all flattened predictions
    pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten())
    pr_auc = auc(rc, pr)
    #pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten(), average='macro')
    M_pr_auc = 0
    return pr_auc, M_pr_auc


def compute_mcc(preds, labels):
    # Matthews correlation coefficient over all flattened predictions
    mcc = matthews_corrcoef(labels.flatten(), preds.flatten())
    return mcc

def compute_performance(preds, labels):
    # Protein-centric Fmax: sweep thresholds 0.01-0.99 and keep the threshold that
    # maximises the F-measure, together with its precision, recall and predictions
    preds = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        total = 0
        f = 0.0
        p = 0.0
        r = 0.0
        p_total = 0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                p += precision
                r += recall
        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
        if f_max < f:
            f_max = f
            p_max = p
            r_max = r
            t_max = threshold
            predictions_max = predictions
    return f_max, p_max, r_max, t_max, predictions_max

if __name__ == '__main__':
    main()
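
# Example invocation (illustrative; assumes the Data/data2 inputs and the pretrained
# weights referenced in model() are available):
#   python pfp_conditional_wgan_exp2.py --function cc --device gpu:0 --train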