conditional_new_nn_hirearchical_seq_mt.py 26KB

#!/usr/bin/env python
"""Conditional WGAN for hierarchical protein function prediction.

Trains a generator that maps sequence-derived features to Gene Ontology
(GO) label scores, with a discriminator conditioned on the input sequence
(see conditional_wgan_wrapper.WGAN_wrapper), and reports protein-centric
and function-centric evaluation metrics.

Example:
    python conditional_new_nn_hirearchical_seq_mt.py --function bp --device gpu:0
"""
from __future__ import division
import logging
import sys
import time
from collections import deque
from multiprocessing import Pool

import click as ck
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import (
    Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D,
    Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization,
    Activation, Dropout, Lambda)
from keras.losses import binary_crossentropy
from keras.models import Sequential, Model, load_model
from keras.preprocessing import sequence
from scipy.spatial import distance
from sklearn.metrics import log_loss
from sklearn.metrics import roc_curve, auc, matthews_corrcoef
from sklearn.metrics import precision_recall_curve
from utils import (
    get_gene_ontology,
    get_go_set,
    get_anchestors,
    get_parents,
    DataGenerator,
    FUNC_DICT,
    get_height,
    get_ipro)
from conditional_wgan_wrapper import (
    WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new)

# Let TensorFlow grow GPU memory on demand instead of reserving it all.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
sys.setrecursionlimit(100000)

DATA_ROOT = 'data/swiss/'
MAXLEN = 258  # 1000
REPLEN = 256
ind = 0
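

# Command-line entry point. Note that `--train` is declared as a flag but
# also defaults to True, so as written it is always True.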
@ck.command()
@ck.option(
    '--function',
    default='bp',
    help='Ontology id (mf, bp, cc)')
@ck.option(
    '--device',
    default='gpu:0',
    help='GPU or CPU device id')
@ck.option(
    '--org',
    default=None,
    help='Organism id for filtering test set')
@ck.option('--train', default=True, is_flag=True)
@ck.option('--param', default=0, help='Param index 0-7')
def main(function, device, org, train, param):
    global FUNCTION
    FUNCTION = function
    global GO_ID
    GO_ID = FUNC_DICT[FUNCTION]
    global go
    go = get_gene_ontology('go.obo')
    global ORG
    ORG = org
    func_df = pd.read_pickle(DATA_ROOT + FUNCTION + '.pkl')
    global functions
    functions = func_df['functions'].values
    global func_set
    func_set = set(functions)
    global all_functions
    all_functions = get_go_set(go, GO_ID)
    logging.info('Functions: %s %d' % (FUNCTION, len(functions)))
    if ORG is not None:
        logging.info('Organism %s' % ORG)
    global go_indexes
    go_indexes = dict()
    for ind, go_id in enumerate(functions):
        go_indexes[go_id] = ind
    global node_names
    node_names = set()
    with tf.device('/' + device):
        params = {
            'fc_output': 1024,
            'learning_rate': 0.001,
            'embedding_dims': 128,
            'embedding_dropout': 0.2,
            'nb_conv': 1,
            'nb_dense': 1,
            'filter_length': 128,
            'nb_filter': 32,
            'pool_length': 64,
            'stride': 32
        }
        model(params, is_train=train)
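

# Load the tab-separated dataset from data2/: protein feature vectors
# (all_data_X.csv) and per-branch GO labels (all_data_Y.csv), then split off
# a random ~20% test set that also serves as the validation set.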
def load_data2():
    all_data_x_fn = 'data2/all_data_X.csv'
    all_data_x = pd.read_csv(all_data_x_fn, sep='\t', header=0, index_col=0)
    all_proteins_train = [p.replace('"', '') for p in all_data_x.index]
    all_data_x.index = all_proteins_train
    all_data_y_fn = 'data2/all_data_Y.csv'
    all_data_y = pd.read_csv(all_data_y_fn, sep='\t', header=0, index_col=0)
    branch = pd.read_csv('data2/' + FUNCTION + '_branches.txt', sep='\t', header=0, index_col=0)
    all_x = all_data_x.values
    branches = [p for p in branch.index.tolist() if p in all_data_y.columns.tolist()]
    t = pd.DataFrame(all_data_y, columns=branches)
    all_y = t.values
    # Sample the test indexes without replacement so they are unique and in
    # range (the previous ceil-of-random scheme could duplicate indexes or
    # index one past the end of the array).
    number_of_test = int(np.ceil(0.2 * len(all_x)))
    index_test = np.random.choice(len(all_x), number_of_test, replace=False)
    index_train = [p for p in range(len(all_x)) if p not in set(index_test)]
    train_data = all_x[index_train, :]
    test_data = all_x[index_test, :]
    train_labels = all_y[index_train, :]
    test_labels = all_y[index_test, :]
    val_data = test_data
    val_labels = test_labels
    print(train_labels.shape)
    print(test_labels.shape)
    return train_data, train_labels, test_data, test_labels, val_data, val_labels
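

# Load the pickled train/test data frames from DATA_ROOT, split the training
# frame 80/20 into train/validation, and flatten padded ngram index
# sequences into fixed-width arrays.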
def load_data():
    df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
    n = len(df)
    index = df.index.values
    valid_n = int(n * 0.8)
    train_df = df.loc[index[:valid_n]]
    valid_df = df.loc[index[valid_n:]]
    test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
    print(test_df['orgs'])
    if ORG is not None:
        logging.info('Unfiltered test size: %d' % len(test_df))
        test_df = test_df[test_df['orgs'] == ORG]
        logging.info('Filtered test size: %d' % len(test_df))
    # Filter by type
    # org_df = pd.read_pickle('data/prokaryotes.pkl')
    # orgs = org_df['orgs']
    # test_df = test_df[test_df['orgs'].isin(orgs)]

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    def normalize_minmax(values):
        mn = np.min(values)
        mx = np.max(values)
        if mx - mn != 0.0:
            return (values - mn) / (mx - mn)
        return values - mn

    def get_values(data_frame):
        print(data_frame['labels'].values.shape)
        labels = reshape(data_frame['labels'].values)
        ngrams = sequence.pad_sequences(
            data_frame['ngrams'].values, maxlen=MAXLEN)
        ngrams = reshape(ngrams)
        rep = reshape(data_frame['embeddings'].values)
        data = ngrams
        return data, labels

    train = get_values(train_df)
    valid = get_values(valid_df)
    test = get_values(test_df)
    return train, valid, test, train_df, valid_df, test_df
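

# Sequence feature extractor: an embedding over ngram indexes followed by
# Conv1D/MaxPooling blocks. Defined here but not used by the current
# generator, which works directly on the input vector.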
def get_feature_model(params):
    embedding_dims = params['embedding_dims']
    max_features = 8001
    model = Sequential()
    model.add(Embedding(
        max_features,
        embedding_dims,
        input_length=MAXLEN))
    model.add(SpatialDropout1D(0.4))
    for i in range(params['nb_conv']):
        model.add(Conv1D(
            activation="relu",
            padding="valid",
            strides=1,
            filters=params['nb_filter'],
            kernel_size=params['filter_length']))
    model.add(MaxPooling1D(strides=params['stride'], pool_size=params['pool_length']))
    model.add(Flatten())
    return model
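

# Generator: a small fully connected network mapping a MAXLEN-dimensional
# input to per-class sigmoid scores, i.e. multi-label GO predictions.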
def get_generator(params, n_classes):
    inputs = Input(shape=(MAXLEN,), dtype='float32', name='input1')
    # feature_model = get_feature_model(params)(inputs)
    net0 = Dense(150, activation='relu')(inputs)
    net0 = Dense(150, activation='relu')(net0)
    # net0 = Dense(50, activation='relu')(net0)
    net = Dense(70, activation='relu')(net0)
    output = Dense(n_classes, activation='sigmoid')(net)
    model = Model(inputs=inputs, outputs=output)
    return model
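

# Discriminator: scores a (label vector, sequence) pair. The sequence branch
# embeds ngram indexes and collapses the channel axis with a width-1
# convolution; both branches are projected to 40 units, concatenated, and
# passed through dense/batch-norm/ReLU blocks to a single linear output
# (no sigmoid, as expected by the Wasserstein loss).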
def get_discriminator(params, n_classes, dropout_rate=0.5):
    inputs = Input(shape=(n_classes,))
    inputs2 = Input(shape=(MAXLEN,), dtype='int32', name='d_input2')
    x2 = Embedding(8001, 128, input_length=MAXLEN)(inputs2)
    x2 = Conv1D(filters=1, kernel_size=1, padding='valid', activation='relu', strides=1)(x2)
    x2 = Lambda(lambda x: K.squeeze(x, 2))(x2)
    # for i in range(params['nb_conv']):
    #     x2 = Conv1D(activation="relu", padding="valid", strides=1,
    #                 filters=params['nb_filter'], kernel_size=params['filter_length'])(x2)
    # x2 = MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2)
    # x2 = Flatten()(x2)
    size = 40
    x = inputs
    x = Dropout(dropout_rate)(x)
    x = Dense(size)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x2 = Dropout(dropout_rate)(x2)
    x2 = Dense(size)(x2)
    x2 = BatchNormalization()(x2)
    x2 = Activation('relu')(x2)
    x = Concatenate(axis=1, name='merged2')([x, x2])
    layer_sizes = [80, 40, 30]
    for size in layer_sizes:
        x = Dropout(dropout_rate)(x)
        x = Dense(size)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    outputs = Dense(1)(x)
    model = Model(inputs=[inputs, inputs2], outputs=outputs, name='Discriminator')
    return model
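

# Wire the generator and discriminator into trainable WGAN models using the
# project's WGAN_wrapper with gradient penalty (WGAN-GP).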
def get_model(params, nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10):
    generator = get_generator(params, nb_classes)
    discriminator = get_discriminator(params, nb_classes)
    generator_model, discriminator_model = WGAN_wrapper(
        generator=generator,
        discriminator=discriminator,
        generator_input_shape=(MAXLEN,),
        discriminator_input_shape=(nb_classes,),
        discriminator_input_shape2=(MAXLEN,),
        batch_size=batch_size,
        gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    logging.info('Compilation finished')
    return generator_model, discriminator_model
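

# WGAN-GP training loop: each epoch runs TRAINING_RATIO discriminator updates
# per generator update (pure supervised warm-up for the first N_WARM_UP
# epochs), and the weights with the best validation log-loss are checkpointed.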
def train_wgan(generator_model, discriminator_model, batch_size, epochs,
               x_train, y_train, x_val, y_val, generator_model_path,
               discriminator_model_path, TRAINING_RATIO=10, N_WARM_UP=0):
    BATCH_SIZE = batch_size
    N_EPOCH = epochs
    # Constant target tensors for the WGAN objectives.
    positive_y = np.ones((batch_size, 1), dtype=np.float32)
    zero_y = positive_y * 0
    negative_y = -positive_y
    positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32)
    dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    positive_full_enable_train = np.ones((len(x_train), 1), dtype=np.float32)
    positive_full_enable_val = np.ones((len(x_val), 1), dtype=np.float32)
    best_validation_loss = None
    for epoch in range(N_EPOCH):
        print("Epoch: ", epoch)
        print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE))
        discriminator_loss = []
        generator_loss = []
        minibatches_size = BATCH_SIZE * TRAINING_RATIO
        shuffled_indexes = np.random.permutation(x_train.shape[0])
        shuffled_indexes_2 = np.random.permutation(x_train.shape[0])
        for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))):
            batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size]
            batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size]
            x = x_train[batch_indexes]
            y = y_train[batch_indexes]
            y_2 = y_train[batch_indexes_2]
            x_2 = x_train[batch_indexes_2]
            if epoch < N_WARM_UP:
                # Warm-up: train the generator on the supervised
                # reconstruction loss only.
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    generator_loss.append(generator_model.train_on_batch(
                        [x_batch, positive_y], [y_batch, zero_y]))
            else:
                # Adversarial phase: TRAINING_RATIO discriminator updates,
                # then one generator update on the full minibatch.
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    # The conditioning input plays the role of the noise.
                    # noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32)
                    noise = x_batch
                    discriminator_loss.append(discriminator_model.train_on_batch(
                        [y_batch_2, noise, x_batch_2],
                        [positive_y, negative_y, dummy_y]))
                generator_loss.append(generator_model.train_on_batch(
                    [x, positive_full_y], [y, positive_full_y]))
        predicted_y_train, _ = generator_model.predict(
            [x_train, positive_full_enable_train], batch_size=BATCH_SIZE)
        predicted_y_val, _ = generator_model.predict(
            [x_val, positive_full_enable_val], batch_size=BATCH_SIZE)
        train_loss = log_loss(y_train, predicted_y_train)
        val_loss = log_loss(y_val, predicted_y_val)
        print("train loss: {:.4f}, validation loss: {:.4f}, discriminator loss: {:.4f}".format(
            train_loss, val_loss,
            (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0]))
        if best_validation_loss is None or best_validation_loss > val_loss:
            print('\nEpoch %05d: validation loss improved to %0.5f,'
                  ' saving model to %s and %s'
                  % (epoch + 1, val_loss, generator_model_path, discriminator_model_path))
            best_validation_loss = val_loss
            generator_model.save(generator_model_path, overwrite=True)
            discriminator_model.save(discriminator_model_path, overwrite=True)
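

# End-to-end pipeline: load data, build and train the WGAN, reload the best
# generator, and report protein-centric and function-centric metrics.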
def model(params, batch_size=20, nb_epoch=40, is_train=True):
    start_time = time.time()
    logging.info("Loading Data")
    # Alternative loader for the pickled DATA_ROOT dataset:
    # train, val, test, train_df, valid_df, test_df = load_data()
    # train_df = pd.concat([train_df, valid_df])
    # test_gos = test_df['gos'].values
    # train_data, train_labels = train
    # val_data, val_labels = val
    # test_data, test_labels = test
    train_data, train_labels, test_data, test_labels, val_data, val_labels = load_data2()
    nb_classes = train_labels.shape[1]
    logging.info("Data loaded in %d sec" % (time.time() - start_time))
    logging.info("Training data size: %d" % len(train_data))
    logging.info("Validation data size: %d" % len(val_data))
    logging.info("Test data size: %d" % len(test_data))
    generator_model_path = DATA_ROOT + 'models/new_model_seq_' + FUNCTION + '.h5'
    discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_' + FUNCTION + '.h5'
    logging.info('Starting training the model')
    train_generator = DataGenerator(batch_size, nb_classes)
    train_generator.fit(train_data, train_labels)
    valid_generator = DataGenerator(batch_size, nb_classes)
    valid_generator.fit(val_data, val_labels)
    test_generator = DataGenerator(batch_size, nb_classes)
    test_generator.fit(test_data, test_labels)
    if is_train:
        generator_model, discriminator_model = get_model(params, nb_classes, batch_size)
        train_wgan(generator_model, discriminator_model, batch_size=batch_size,
                   epochs=nb_epoch,
                   x_train=train_data, y_train=train_labels,
                   x_val=val_data, y_val=val_labels,
                   generator_model_path=generator_model_path,
                   discriminator_model_path=discriminator_model_path)
    logging.info('Loading best model')
    model = load_model(
        generator_model_path,
        custom_objects={
            'generator_recunstruction_loss_new': generator_recunstruction_loss_new,
            'wasserstein_loss': wasserstein_loss})
    logging.info('Predicting')
    preds = model.predict_generator(
        test_generator, steps=int(np.ceil(len(test_data) / batch_size)))[0]
    # Optional hierarchy consistency pass (propagate child scores upward):
    # for i in range(len(test_data)):
    #     for j in range(len(functions)):
    #         childs = set(go[functions[j]]['children']).intersection(func_set)
    #         for n_id in childs:
    #             if preds[i, j] < preds[i, go_indexes[n_id]]:
    #                 preds[i, j] = preds[i, go_indexes[n_id]]
    logging.info('Computing performance')
    f, p, r, t, preds_max = compute_performance(preds, test_labels)  # , test_gos)
    roc_auc = compute_roc(preds, test_labels)
    mcc = compute_mcc(preds_max, test_labels)
    aupr, _ = compute_aupr(preds, test_labels)
    m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = \
        micro_macro_function_centric_f1(preds.T, test_labels.T)
    logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f))
    logging.info('ROC AUC: \t %f ' % (roc_auc, ))
    logging.info('MCC: \t %f ' % (mcc, ))
    logging.info('AUPR: \t %f ' % (aupr, ))
    logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max))
    logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max))
    function_centric_performance(functions, preds.T, test_labels.T, train_labels.T)
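

# Term-centric evaluation: for each GO term, sweep thresholds 0.01-0.99 and
# record the best F1 with its precision/recall, plus training support and
# term height; results are written to a tab-separated file.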
def function_centric_performance(functions, preds, labels, labels_train):
    results = []
    preds = np.round(preds, 2)
    for i in range(preds.shape[0]):
        f_max = 0
        p_max = 0
        r_max = 0
        for t in range(1, 100):
            threshold = t / 100.0
            predictions = (preds[i, :] > threshold).astype(np.int32)
            tp = np.sum(predictions * labels[i, :])
            fp = np.sum(predictions) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp > 0:
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                f = 2 * precision * recall / (precision + recall)
            else:
                if fp == 0 and fn == 0:
                    precision = 1
                    recall = 1
                    f = 1
                else:
                    precision = 0
                    recall = 0
                    f = 0
            if f_max < f:
                f_max = f
                p_max = precision
                r_max = recall
        num_prots_train = np.sum(labels_train[i, :])
        height = get_height(go, functions[i])
        results.append([functions[i], num_prots_train, height, f_max, p_max, r_max])
    results = pd.DataFrame(
        results,
        columns=['function', 'num_prots_train', 'height', 'f_max', 'p_max', 'r_max'])
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)
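

# Earlier variant of the term-centric evaluation (adds per-term ROC AUC);
# kept for reference, not called anywhere in this script.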
def function_centric_performance_backup(functions, preds, labels, labels_train):
    results = []
    preds = np.round(preds, 2)
    for i in range(len(functions)):
        f_max = 0
        p_max = 0
        r_max = 0
        x = list()
        y = list()
        total = 0
        for t in range(1, 100):
            threshold = t / 100.0
            predictions = (preds[i, :] > threshold).astype(np.int32)
            tp = np.sum(predictions * labels[i, :])
            fp = np.sum(predictions) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp > 0:
                sn = tp / (1.0 * np.sum(labels[i, :]))
                sp = np.sum((predictions ^ 1) * (labels[i, :] ^ 1))
                sp /= 1.0 * np.sum(labels[i, :] ^ 1)
                fpr = 1 - sp
                x.append(fpr)
                y.append(sn)
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                f = 2 * precision * recall / (precision + recall)
                total += 1
                if f_max < f:
                    f_max = f
                    p_max = precision
                    r_max = recall
        num_prots = np.sum(labels[i, :])
        num_prots_train = np.sum(labels_train[i, :])
        if total > 1:
            roc_auc = auc(x, y)
        else:
            roc_auc = 0
        height = get_height(go, functions[i])
        results.append([functions[i], f_max, p_max, r_max, num_prots,
                        num_prots_train, height, roc_auc])
    results = pd.DataFrame(results)
    # results.to_csv('new_results.txt', sep='\t', index=False)
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)
    # results = np.array(results)
    # p_mean = (np.sum(results[:, 2])) / len(functions)
    # r_mean = (np.sum(results[:, 3])) / len(functions)
    # f_mean = (2 * p_mean * r_mean) / (p_mean + r_mean)
    # roc_auc_mean = (np.sum(results[:, 7])) / len(functions)
    # print('Function centric performance (macro) %f %f %f %f' % (f_mean, p_mean, r_mean, roc_auc_mean))
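

# Earlier micro/macro F1 sweep; superseded by the version below and not
# called anywhere in this script.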
def micro_macro_function_centric_f1_backup(preds, labels):
    preds = np.round(preds, 2)
    m_f1_max = 0
    M_f1_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        m_tp = 0
        m_fp = 0
        m_fn = 0
        M_pr = 0
        M_rc = 0
        total = 0
        p_total = 0
        for i in range(len(preds)):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp > 0:
                pr = tp / (1.0 * (tp + fp))
                rc = tp / (1.0 * (tp + fn))
                m_tp += tp
                m_fp += fp
                m_fn += fn
                M_pr += pr
                M_rc += rc
                p_total += 1
        if p_total == 0:
            continue
        if total > 0:
            m_tp /= total
            m_fn /= total
            m_fp /= total
            m_pr = m_tp / (1.0 * (m_tp + m_fp))
            m_rc = m_tp / (1.0 * (m_tp + m_fn))
            M_pr /= p_total
            M_rc /= total
            m_f1 = 2 * m_pr * m_rc / (m_pr + m_rc)
            M_f1 = 2 * M_pr * M_rc / (M_pr + M_rc)
            if m_f1 > m_f1_max:
                m_f1_max = m_f1
                m_pr_max = m_pr
                m_rc_max = m_rc
            if M_f1 > M_f1_max:
                M_f1_max = M_f1
                M_pr_max = M_pr
                M_rc_max = M_rc
    return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max
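

# Micro- and macro-averaged function-centric F1: sweep 199 thresholds and
# return the best micro and macro precision/recall/F1 (each maximized
# independently).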
def micro_macro_function_centric_f1(preds, labels):
    preds = np.round(preds, 2)
    m_f1_max = 0
    M_f1_max = 0
    # Initialize so the return values are defined even if no threshold
    # yields a positive F1.
    m_pr_max = m_rc_max = 0
    M_pr_max = M_rc_max = 0
    for t in range(1, 200):
        threshold = t / 200.0
        predictions = (preds > threshold).astype(np.int32)
        m_tp = 0
        m_fp = 0
        m_fn = 0
        M_pr = 0
        M_rc = 0
        for i in range(preds.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            m_tp += tp
            m_fp += fp
            m_fn += fn
            if tp > 0:
                pr = 1.0 * tp / (1.0 * (tp + fp))
                rc = 1.0 * tp / (1.0 * (tp + fn))
            else:
                if fp == 0 and fn == 0:
                    pr = 1
                    rc = 1
                else:
                    pr = 0
                    rc = 0
            M_pr += pr
            M_rc += rc
        if m_tp > 0:
            m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp))
            m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn))
            m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc)
        else:
            if m_fp == 0 and m_fn == 0:
                m_pr = 1
                m_rc = 1
                m_f1 = 1
            else:
                m_pr = 0
                m_rc = 0
                m_f1 = 0
        M_pr /= preds.shape[0]
        M_rc /= preds.shape[0]
        if M_pr == 0 and M_rc == 0:
            M_f1 = 0
        else:
            M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc)
        if m_f1 > m_f1_max:
            m_f1_max = m_f1
            m_pr_max = m_pr
            m_rc_max = m_rc
        if M_f1 > M_f1_max:
            M_f1_max = M_f1
            M_pr_max = M_pr
            M_rc_max = M_rc
    return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max
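

# Global metrics computed over the flattened prediction/label matrices.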
def compute_roc(preds, labels):
    # ROC AUC over the flattened label matrix (micro-averaged).
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc


def compute_aupr(preds, labels):
    # Area under the precision-recall curve, also micro-averaged.
    pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten())
    pr_auc = auc(rc, pr)
    M_pr_auc = 0  # macro-averaged AUPR is not computed
    return pr_auc, M_pr_auc


def compute_mcc(preds, labels):
    # Matthews correlation coefficient on binarized predictions.
    mcc = matthews_corrcoef(labels.flatten(), preds.flatten())
    return mcc
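

# Protein-centric Fmax: sweep thresholds and return the best F-measure with
# its precision, recall, threshold, and the corresponding binary predictions.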
def compute_performance(preds, labels):  # , gos
    preds = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    predictions_max = None
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        total = 0
        f = 0.0
        p = 0.0
        r = 0.0
        p_total = 0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            # Optionally extend false negatives with ancestors of the
            # annotated GO terms:
            # all_gos = set()
            # for go_id in gos[i]:
            #     if go_id in all_functions:
            #         all_gos |= get_anchestors(go, go_id)
            # all_gos.discard(GO_ID)
            # all_gos -= func_set
            # fn += len(all_gos)
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                p += precision
                r += recall
        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
            if f_max < f:
                f_max = f
                p_max = p
                r_max = r
                t_max = threshold
                predictions_max = predictions
    return f_max, p_max, r_max, t_max, predictions_max
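

# Nearest-neighbour lookup of GO annotation sets by cosine distance between a
# predicted label vector and stored label vectors. Note that it relies on a
# global `labels_gos` that is not defined in this file; unused here.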
def get_gos(pred):
    mdist = 1.0
    mgos = None
    for i in range(len(labels_gos)):
        labels, gos = labels_gos[i]
        dist = distance.cosine(pred, labels)
        if mdist > dist:
            mdist = dist
            mgos = gos
    return mgos


if __name__ == '__main__':
    main()