
conditional_new_nn_hirearchical_seq.py 30KB

#!/usr/bin/env python
"""
Parts of this code are based on the work in https://github.com/bio-ontology-research-group/deepgo.
"""
from __future__ import division

import logging
import sys
import time
from collections import deque
from multiprocessing import Pool

import click as ck
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import (
    Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D,
    Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization, Activation, Dropout)
from keras.losses import binary_crossentropy
from keras.models import Sequential, Model, load_model
from keras.preprocessing import sequence
from scipy.spatial import distance
from sklearn.metrics import log_loss
from sklearn.metrics import roc_curve, auc, matthews_corrcoef
from keras.layers import Lambda
from sklearn.metrics import precision_recall_curve

from utils import (
    get_gene_ontology,
    get_go_set,
    get_anchestors,
    get_parents,
    DataGenerator,
    FUNC_DICT,
    get_height,
    get_ipro)
from conditional_wgan_wrapper_post import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
sys.setrecursionlimit(100000)

DATA_ROOT = 'data/swiss/'
MAXLEN = 1000
REPLEN = 256
ind = 0

@ck.command()
@ck.option(
    '--function',
    default='bp',
    help='Ontology id (mf, bp, cc)')
@ck.option(
    '--device',
    default='gpu:0',
    help='GPU or CPU device id')
@ck.option(
    '--org',
    default=None,
    help='Organism id for filtering test set')
@ck.option('--train', default=True, is_flag=True)
@ck.option('--param', default=0, help='Param index 0-7')
def main(function, device, org, train, param):
    global FUNCTION
    FUNCTION = function
    global GO_ID
    GO_ID = FUNC_DICT[FUNCTION]
    global go
    go = get_gene_ontology('go.obo')
    global ORG
    ORG = org
    func_df = pd.read_pickle(DATA_ROOT + FUNCTION + '.pkl')
    global functions
    functions = func_df['functions'].values
    global func_set
    func_set = set(functions)
    global all_functions
    all_functions = get_go_set(go, GO_ID)
    logging.info('Functions: %s %d' % (FUNCTION, len(functions)))
    if ORG is not None:
        logging.info('Organism %s' % ORG)
    global go_indexes
    go_indexes = dict()
    for ind, go_id in enumerate(functions):
        go_indexes[go_id] = ind
    global node_names
    node_names = set()
    with tf.device('/' + device):
        params = {
            'fc_output': 1024,
            'learning_rate': 0.001,
            'embedding_dims': 128,
            'embedding_dropout': 0.2,
            'nb_conv': 1,
            'nb_dense': 1,
            'filter_length': 128,
            'nb_filter': 32,
            'pool_length': 64,
            'stride': 32
        }
        model(params, is_train=train)

def load_data():
    df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
    n = len(df)
    index = df.index.values
    valid_n = int(n * 0.8)
    train_df = df.loc[index[:valid_n]]
    valid_df = df.loc[index[valid_n:]]
    test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
    print(test_df['orgs'])
    if ORG is not None:
        logging.info('Unfiltered test size: %d' % len(test_df))
        test_df = test_df[test_df['orgs'] == ORG]
        logging.info('Filtered test size: %d' % len(test_df))
    # Filter by type
    # org_df = pd.read_pickle('data/prokaryotes.pkl')
    # orgs = org_df['orgs']
    # test_df = test_df[test_df['orgs'].isin(orgs)]

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    def normalize_minmax(values):
        mn = np.min(values)
        mx = np.max(values)
        if mx - mn != 0.0:
            return (values - mn) / (mx - mn)
        return values - mn

    def get_values(data_frame):
        print(data_frame['labels'].values.shape)
        labels = reshape(data_frame['labels'].values)
        ngrams = sequence.pad_sequences(
            data_frame['ngrams'].values, maxlen=MAXLEN)
        ngrams = reshape(ngrams)
        rep = reshape(data_frame['embeddings'].values)
        data = ngrams
        return data, labels

    train = get_values(train_df)
    valid = get_values(valid_df)
    test = get_values(test_df)
    return train, valid, test, train_df, valid_df, test_df
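
# Sequence feature extractor: an embedding over the n-gram token vocabulary (8001 entries,
# presumably amino-acid trigram indices as in DeepGO) followed by 1D convolution, max pooling
# and flattening. The number of convolutional blocks and the filter/pool sizes come from the
# `params` dictionary assembled in main().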
def get_feature_model(params):
    embedding_dims = params['embedding_dims']
    max_features = 8001
    model = Sequential()
    model.add(Embedding(
        max_features,
        embedding_dims,
        input_length=MAXLEN))
    model.add(SpatialDropout1D(0.4))
    for i in range(params['nb_conv']):
        model.add(Conv1D(
            activation="relu",
            padding="valid",
            strides=1,
            filters=params['nb_filter'],
            kernel_size=params['filter_length']))
    model.add(MaxPooling1D(strides=params['stride'], pool_size=params['pool_length']))
    model.add(Flatten())
    return model

def merge_outputs(outputs, name):
    if len(outputs) == 1:
        return outputs[0]
    ## return merge(outputs, mode='concat', name=name, concat_axis=1)
    return Concatenate(axis=1, name=name)(outputs)


def merge_nets(nets, name):
    if len(nets) == 1:
        return nets[0]
    ## return merge(nets, mode='sum', name=name)
    return Add(name=name)(nets)

def get_node_name(go_id, unique=False):
    name = go_id.split(':')[1]
    if not unique:
        return name
    if name not in node_names:
        node_names.add(name)
        return name
    i = 1
    while (name + '_' + str(i)) in node_names:
        i += 1
    name = name + '_' + str(i)
    node_names.add(name)
    return name
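
# Builds one output node per GO term, following the ontology hierarchy: starting from the
# children of the root term (GO_ID), terms are visited breadth-first and each receives a
# sigmoid node from get_function_node(). Afterwards, each parent's output is replaced by an
# element-wise Maximum over itself and its children, so a parent's score is never lower than
# any descendant's (hierarchical consistency).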
def get_layers(inputs):
    q = deque()
    layers = {}
    name = get_node_name(GO_ID)
    layers[GO_ID] = {'net': inputs}
    for node_id in go[GO_ID]['children']:
        if node_id in func_set:
            q.append((node_id, inputs))
    while len(q) > 0:
        node_id, net = q.popleft()
        parent_nets = [inputs]
        # for p_id in get_parents(go, node_id):
        #     if p_id in func_set:
        #         parent_nets.append(layers[p_id]['net'])
        # if len(parent_nets) > 1:
        #     name = get_node_name(node_id) + '_parents'
        #     net = merge(
        #         parent_nets, mode='concat', concat_axis=1, name=name)
        name = get_node_name(node_id)
        net, output = get_function_node(name, inputs)
        if node_id not in layers:
            layers[node_id] = {'net': net, 'output': output}
        for n_id in go[node_id]['children']:
            if n_id in func_set and n_id not in layers:
                ok = True
                for p_id in get_parents(go, n_id):
                    if p_id in func_set and p_id not in layers:
                        ok = False
                if ok:
                    q.append((n_id, net))
    for node_id in functions:
        childs = set(go[node_id]['children']).intersection(func_set)
        if len(childs) > 0:
            outputs = [layers[node_id]['output']]
            for ch_id in childs:
                outputs.append(layers[ch_id]['output'])
            name = get_node_name(node_id) + '_max'
            ## layers[node_id]['output'] = merge(
            ##     outputs, mode='max', name=name)
            layers[node_id]['output'] = Maximum(name=name)(outputs)
    return layers

def get_function_node(name, inputs):
    output_name = name + '_out'
    # net = Dense(256, name=name, activation='relu')(inputs)
    output = Dense(1, name=output_name, activation='sigmoid')(inputs)
    return output, output
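
# Generator: CNN sequence features -> Dense(300) -> the hierarchical per-term outputs built
# by get_layers(), concatenated and passed through a final sigmoid layer of size n_classes.
# In the GAN setup this network produces the label vector conditioned on the input sequence.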
def get_generator(params, n_classes):
    inputs = Input(shape=(MAXLEN,), dtype='int32', name='input1')
    feature_model = get_feature_model(params)(inputs)
    net = Dense(300, activation='relu')(feature_model)
    net = BatchNormalization()(net)
    layers = get_layers(net)
    output_models = []
    for i in range(len(functions)):
        output_models.append(layers[functions[i]]['output'])
    net = Concatenate(axis=1)(output_models)
    output = Dense(n_classes, activation='sigmoid')(net)
    model = Model(inputs=inputs, outputs=output)
    return model
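
# Discriminator (critic): scores a (GO label vector, n-gram sequence) pair. The sequence is
# embedded and compressed with a kernel-size-1 convolution, both branches go through small
# dense blocks, and the concatenated representation feeds a fully connected stack. The final
# Dense(1) is left without an activation, as is usual for a Wasserstein critic.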
def get_discriminator(params, n_classes, dropout_rate=0.5):
    inputs = Input(shape=(n_classes,))
    inputs2 = Input(shape=(MAXLEN,), dtype='int32', name='d_input2')
    x2 = Embedding(8001, 128, input_length=MAXLEN)(inputs2)
    x2 = Conv1D(filters=1, kernel_size=1, padding='valid', activation='relu', strides=1)(x2)
    x2 = Lambda(lambda x: K.squeeze(x, 2))(x2)
    # for i in range(params['nb_conv']):
    #     x2 = Conv1D(activation="relu", padding="valid", strides=1,
    #                 filters=params['nb_filter'], kernel_size=params['filter_length'])(x2)
    # x2 = MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2)
    # x2 = Flatten()(x2)
    size = 40
    x = inputs
    x = Dropout(dropout_rate)(x)
    x = Dense(size)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    size = 40
    x2 = Dropout(dropout_rate)(x2)
    x2 = Dense(size)(x2)
    x2 = BatchNormalization()(x2)
    x2 = Activation('relu')(x2)
    x = Concatenate(axis=1, name='merged2')([x, x2])
    layer_sizes = [80, 40, 30]
    for size in layer_sizes:
        x = Dropout(dropout_rate)(x)
        x = Dense(size)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    outputs = Dense(1)(x)
    model = Model(inputs=[inputs, inputs2], outputs=outputs, name='Discriminator')
    return model

def get_model(params, nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10):
    generator = get_generator(params, nb_classes)
    discriminator = get_discriminator(params, nb_classes)
    generator_model, discriminator_model = \
        WGAN_wrapper(generator=generator,
                     discriminator=discriminator,
                     generator_input_shape=(MAXLEN,),
                     discriminator_input_shape=(nb_classes,),
                     discriminator_input_shape2=(MAXLEN,),
                     batch_size=batch_size,
                     gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    logging.info('Compilation finished')
    return generator_model, discriminator_model
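
# WGAN-GP style training loop. The critic is updated TRAINING_RATIO times per generator
# update; positive_y, negative_y and dummy_y are the targets for the real, fake and
# gradient-penalty outputs that WGAN_wrapper (from conditional_wgan_wrapper_post) is assumed
# to expose. During the first N_WARM_UP epochs only the generator is trained. After every
# epoch the models with the lowest validation log loss so far are checkpointed.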
def train_wgan(generator_model, discriminator_model, batch_size, epochs,
               x_train, y_train, x_val, y_val, generator_model_path, discriminator_model_path,
               TRAINING_RATIO=10, N_WARM_UP=0):
    BATCH_SIZE = batch_size
    N_EPOCH = epochs
    positive_y = np.ones((batch_size, 1), dtype=np.float32)
    zero_y = positive_y * 0
    negative_y = -positive_y
    positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32)
    dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    positive_full_enable_train = np.ones((len(x_train), 1), dtype=np.float32)
    positive_full_enable_val = np.ones((len(x_val), 1), dtype=np.float32)
    # positive_enable_train = np.ones((1, batch_size), dtype=np.float32)
    # positive_full_train_enable = np.ones((1, BATCH_SIZE * TRAINING_RATIO), dtype=np.float32)
    best_validation_loss = None
    for epoch in range(N_EPOCH):
        # np.random.shuffle(X_train)
        print("Epoch: ", epoch)
        print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE))
        discriminator_loss = []
        generator_loss = []
        minibatches_size = BATCH_SIZE * TRAINING_RATIO
        shuffled_indexes = np.random.permutation(x_train.shape[0])
        shuffled_indexes_2 = np.random.permutation(x_train.shape[0])
        for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))):
            batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size]
            batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size]
            x = x_train[batch_indexes]
            y = y_train[batch_indexes]
            y_2 = y_train[batch_indexes_2]
            x_2 = x_train[batch_indexes_2]
            if epoch < N_WARM_UP:
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    generator_loss.append(generator_model.train_on_batch(
                        [x_batch, positive_y], [y_batch, zero_y]))
            else:
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    # noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32)
                    noise = x_batch
                    # print(sum(y_batch_2))
                    discriminator_loss.append(discriminator_model.train_on_batch(
                        [y_batch_2, noise, x_batch_2],
                        [positive_y, negative_y, dummy_y]))
                generator_loss.append(generator_model.train_on_batch(
                    [x, positive_full_y], [y, positive_full_y]))
        # Still needs some code to display losses from the generator and discriminator, progress bars, etc.
        predicted_y_train, _ = generator_model.predict(
            [x_train, positive_full_enable_train], batch_size=BATCH_SIZE)
        predicted_y_val, _ = generator_model.predict(
            [x_val, positive_full_enable_val], batch_size=BATCH_SIZE)
        # print(sum(sum(positive_full_enable_train)))
        # print(predicted_y_train)
        train_loss = log_loss(y_train, predicted_y_train)
        val_loss = log_loss(y_val, predicted_y_val)
        print("train loss: {:.4f}, validation loss: {:.4f}, discriminator loss: {:.4f}".format(
            train_loss, val_loss,
            (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0]))
        if best_validation_loss is None or best_validation_loss > val_loss:
            print('\nEpoch %05d: validation loss improved to %0.5f,'
                  ' saving model to %s and %s'
                  % (epoch + 1, val_loss, generator_model_path, discriminator_model_path))
            best_validation_loss = val_loss
            generator_model.save(generator_model_path, overwrite=True)
            discriminator_model.save(discriminator_model_path, overwrite=True)
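
# End-to-end pipeline: load the data, build and train the conditional WGAN, reload the best
# generator checkpoint, predict on the test set, and report protein-centric and
# function-centric metrics (Fmax, ROC AUC, MCC, AUPR, micro/macro F1).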
def model(params, batch_size=20, nb_epoch=40, is_train=True):
    # set parameters:
    nb_classes = len(functions)
    start_time = time.time()
    logging.info("Loading Data")
    train, val, test, train_df, valid_df, test_df = load_data()
    train_df = pd.concat([train_df, valid_df])
    test_gos = test_df['gos'].values
    train_data, train_labels = train
    val_data, val_labels = val
    test_data, test_labels = test
    logging.info("Data loaded in %d sec" % (time.time() - start_time))
    logging.info("Training data size: %d" % len(train_data))
    logging.info("Validation data size: %d" % len(val_data))
    logging.info("Test data size: %d" % len(test_data))
    generator_model_path = DATA_ROOT + 'models/new_model_seq_' + FUNCTION + '.h5'
    discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_' + FUNCTION + '.h5'
    logging.info('Starting training the model')
    train_generator = DataGenerator(batch_size, nb_classes)
    train_generator.fit(train_data, train_labels)
    valid_generator = DataGenerator(batch_size, nb_classes)
    valid_generator.fit(val_data, val_labels)
    test_generator = DataGenerator(batch_size, nb_classes)
    test_generator.fit(test_data, test_labels)
    if is_train:
        generator_model, discriminator_model = get_model(params, nb_classes, batch_size)
        train_wgan(generator_model, discriminator_model, batch_size=batch_size, epochs=nb_epoch,
                   x_train=train_data, y_train=train_labels, x_val=val_data, y_val=val_labels,
                   generator_model_path=generator_model_path,
                   discriminator_model_path=discriminator_model_path)
    logging.info('Loading best model')
    model = load_model(generator_model_path,
                       custom_objects={'generator_recunstruction_loss_new': generator_recunstruction_loss_new,
                                       'wasserstein_loss': wasserstein_loss})
    logging.info('Predicting')
    preds = model.predict_generator(test_generator, steps=len(test_data) / batch_size)[0]
    logging.info('Computing performance')
    f, p, r, t, preds_max = compute_performance(preds, test_labels)  # , test_gos)
    roc_auc = compute_roc(preds, test_labels)
    mcc = compute_mcc(preds_max, test_labels)
    aupr, _ = compute_aupr(preds, test_labels)
    m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = micro_macro_function_centric_f1(preds.T, test_labels.T)
    logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f))
    logging.info('ROC AUC: \t %f ' % (roc_auc, ))
    logging.info('MCC: \t %f ' % (mcc, ))
    logging.info('AUPR: \t %f ' % (aupr, ))
    logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max))
    logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max))
    function_centric_performance(functions, preds.T, test_labels.T, train_labels.T)
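
# The two helpers below evaluate stored predictions grouped by top-level InterPro families;
# swissprot_ipro.tab is expected to map proteins to their InterPro identifiers.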
def load_prot_ipro():
    proteins = list()
    ipros = list()
    with open(DATA_ROOT + 'swissprot_ipro.tab') as f:
        for line in f:
            it = line.strip().split('\t')
            if len(it) != 3:
                continue
            prot = it[1]
            iprs = it[2].split(';')
            proteins.append(prot)
            ipros.append(iprs)
    return pd.DataFrame({'proteins': proteins, 'ipros': ipros})

def performanc_by_interpro():
    pred_df = pd.read_pickle(DATA_ROOT + 'test-' + FUNCTION + '-preds.pkl')
    ipro_df = load_prot_ipro()
    df = pred_df.merge(ipro_df, on='proteins', how='left')
    ipro = get_ipro()

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    for ipro_id in ipro:
        if len(ipro[ipro_id]['parents']) > 0:
            continue
        labels = list()
        predictions = list()
        gos = list()
        for i, row in df.iterrows():
            if not isinstance(row['ipros'], list):
                continue
            if ipro_id in row['ipros']:
                labels.append(row['labels'])
                predictions.append(row['predictions'])
                gos.append(row['gos'])
        pr = 0
        rc = 0
        total = 0
        p_total = 0
        for i in range(len(labels)):
            tp = np.sum(labels[i] * predictions[i])
            fp = np.sum(predictions[i]) - tp
            fn = np.sum(labels[i]) - tp
            all_gos = set()
            for go_id in gos[i]:
                if go_id in all_functions:
                    all_gos |= get_anchestors(go, go_id)
            all_gos.discard(GO_ID)
            all_gos -= func_set
            fn += len(all_gos)
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                pr += precision
                rc += recall
        if total > 0 and p_total > 0:
            rc /= total
            pr /= p_total
            if pr + rc > 0:
                f = 2 * pr * rc / (pr + rc)
                logging.info('%s\t%d\t%f\t%f\t%f' % (
                    ipro_id, len(labels), f, pr, rc))

def function_centric_performance(functions, preds, labels, labels_train):
    results = []
    preds = np.round(preds, 2)
    for i in range(preds.shape[0]):
        f_max = 0
        p_max = 0
        r_max = 0
        for t in range(1, 100):
            threshold = t / 100.0
            predictions = (preds[i, :] > threshold).astype(np.int32)
            tp = np.sum(predictions * labels[i, :])
            fp = np.sum(predictions) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp > 0:
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                f = 2 * precision * recall / (precision + recall)
            else:
                if fp == 0 and fn == 0:
                    precision = 1
                    recall = 1
                    f = 1
                else:
                    precision = 0
                    recall = 0
                    f = 0
            if f_max < f:
                f_max = f
                p_max = precision
                r_max = recall
        num_prots_train = np.sum(labels_train[i, :])
        height = get_height(go, functions[i])
        results.append([functions[i], num_prots_train, height, f_max, p_max, r_max])
    results = pd.DataFrame(results)
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)

def function_centric_performance_backup(functions, preds, labels, labels_train):
    results = []
    preds = np.round(preds, 2)
    for i in range(len(functions)):
        f_max = 0
        p_max = 0
        r_max = 0
        x = list()
        y = list()
        total = 0
        for t in range(1, 100):
            threshold = t / 100.0
            predictions = (preds[i, :] > threshold).astype(np.int32)
            tp = np.sum(predictions * labels[i, :])
            fp = np.sum(predictions) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp > 0:
                sn = tp / (1.0 * np.sum(labels[i, :]))
                sp = np.sum((predictions ^ 1) * (labels[i, :] ^ 1))
                sp /= 1.0 * np.sum(labels[i, :] ^ 1)
                fpr = 1 - sp
                x.append(fpr)
                y.append(sn)
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                f = 2 * precision * recall / (precision + recall)
                total += 1
                if f_max < f:
                    f_max = f
                    p_max = precision
                    r_max = recall
        num_prots = np.sum(labels[i, :])
        num_prots_train = np.sum(labels_train[i, :])
        if total > 1:
            roc_auc = auc(x, y)
        else:
            roc_auc = 0
        height = get_height(go, functions[i])
        results.append([functions[i], f_max, p_max, r_max, num_prots, num_prots_train, height, roc_auc])
    results = pd.DataFrame(results)
    # results.to_csv('new_results.txt', sep='\t', index=False)
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)
    # results = np.array(results)
    # p_mean = (np.sum(results[:, 2])) / len(functions)
    # r_mean = (np.sum(results[:, 3])) / len(functions)
    # f_mean = (2 * p_mean * r_mean) / (p_mean + r_mean)
    # roc_auc_mean = (np.sum(results[:, 7])) / len(functions)
    # print('Function centric performance (macro) ' '%f %f %f %f' % (f_mean, p_mean, r_mean, roc_auc_mean))

def micro_macro_function_centric_f1_backup(preds, labels):
    preds = np.round(preds, 2)
    m_f1_max = 0
    M_f1_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        m_tp = 0
        m_fp = 0
        m_fn = 0
        M_pr = 0
        M_rc = 0
        total = 0
        p_total = 0
        for i in range(len(preds)):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp > 0:
                pr = tp / (1.0 * (tp + fp))
                rc = tp / (1.0 * (tp + fn))
                m_tp += tp
                m_fp += fp
                m_fn += fn
                M_pr += pr
                M_rc += rc
                p_total += 1
        if p_total == 0:
            continue
        if total > 0:
            m_tp /= total
            m_fn /= total
            m_fp /= total
            m_pr = m_tp / (1.0 * (m_tp + m_fp))
            m_rc = m_tp / (1.0 * (m_tp + m_fn))
            M_pr /= p_total
            M_rc /= total
            m_f1 = 2 * m_pr * m_rc / (m_pr + m_rc)
            M_f1 = 2 * M_pr * M_rc / (M_pr + M_rc)
            if m_f1 > m_f1_max:
                m_f1_max = m_f1
                m_pr_max = m_pr
                m_rc_max = m_rc
            if M_f1 > M_f1_max:
                M_f1_max = M_f1
                M_pr_max = M_pr
                M_rc_max = M_rc
    return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max

def micro_macro_function_centric_f1(preds, labels):
    preds = np.round(preds, 2)
    m_f1_max = 0
    M_f1_max = 0
    for t in range(1, 200):
        threshold = t / 200.0
        predictions = (preds > threshold).astype(np.int32)
        m_tp = 0
        m_fp = 0
        m_fn = 0
        M_pr = 0
        M_rc = 0
        for i in range(preds.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            m_tp += tp
            m_fp += fp
            m_fn += fn
            if tp > 0:
                pr = 1.0 * tp / (1.0 * (tp + fp))
                rc = 1.0 * tp / (1.0 * (tp + fn))
            else:
                if fp == 0 and fn == 0:
                    pr = 1
                    rc = 1
                else:
                    pr = 0
                    rc = 0
            M_pr += pr
            M_rc += rc
        if m_tp > 0:
            m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp))
            m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn))
            m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc)
        else:
            if m_fp == 0 and m_fn == 0:
                m_pr = 1
                m_rc = 1
                m_f1 = 1
            else:
                m_pr = 0
                m_rc = 0
                m_f1 = 0
        M_pr /= preds.shape[0]
        M_rc /= preds.shape[0]
        if M_pr == 0 and M_rc == 0:
            M_f1 = 0
        else:
            M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc)
        if m_f1 > m_f1_max:
            m_f1_max = m_f1
            m_pr_max = m_pr
            m_rc_max = m_rc
        if M_f1 > M_f1_max:
            M_f1_max = M_f1
            M_pr_max = M_pr
            M_rc_max = M_rc
    return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max

def compute_roc(preds, labels):
    # Compute the ROC curve and its area over the flattened predictions
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc


def compute_aupr(preds, labels):
    # Compute the precision-recall curve and its area over the flattened predictions
    pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten())
    pr_auc = auc(rc, pr)
    # pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten(), average='macro')
    M_pr_auc = 0
    return pr_auc, M_pr_auc


def compute_mcc(preds, labels):
    # Compute the Matthews correlation coefficient over the flattened predictions
    mcc = matthews_corrcoef(labels.flatten(), preds.flatten())
    return mcc
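
# Protein-centric Fmax: thresholds from 0.01 to 0.99 are swept; for each threshold precision
# is averaged over proteins with at least one correct prediction and recall over all proteins
# with any prediction or annotation, and the threshold maximizing F1 is kept.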
def compute_performance(preds, labels):  # , gos):
    preds = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        total = 0
        f = 0.0
        p = 0.0
        r = 0.0
        p_total = 0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            all_gos = set()
            # for go_id in gos[i]:
            #     if go_id in all_functions:
            #         all_gos |= get_anchestors(go, go_id)
            # all_gos.discard(GO_ID)
            # all_gos -= func_set
            # fn += len(all_gos)
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                p += precision
                r += recall
        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
            if f_max < f:
                f_max = f
                p_max = p
                r_max = r
                t_max = threshold
                predictions_max = predictions
    return f_max, p_max, r_max, t_max, predictions_max

def get_gos(pred):
    mdist = 1.0
    mgos = None
    for i in range(len(labels_gos)):
        labels, gos = labels_gos[i]
        dist = distance.cosine(pred, labels)
        if mdist > dist:
            mdist = dist
            mgos = gos
    return mgos

def compute_similarity_performance(train_df, test_df, preds):
    logging.info("Computing similarity performance")
    logging.info("Training data size %d" % len(train_df))
    train_labels = train_df['labels'].values
    train_gos = train_df['gos'].values
    global labels_gos
    labels_gos = zip(train_labels, train_gos)
    p = Pool(64)
    pred_gos = p.map(get_gos, preds)
    total = 0
    p = 0.0
    r = 0.0
    f = 0.0
    test_gos = test_df['gos'].values
    for gos, tgos in zip(pred_gos, test_gos):
        preds = set()
        test = set()
        for go_id in gos:
            if go_id in all_functions:
                preds |= get_anchestors(go, go_id)
        for go_id in tgos:
            if go_id in all_functions:
                test |= get_anchestors(go, go_id)
        tp = len(preds.intersection(test))
        fp = len(preds - test)
        fn = len(test - preds)
        if tp == 0 and fp == 0 and fn == 0:
            continue
        total += 1
        if tp != 0:
            precision = tp / (1.0 * (tp + fp))
            recall = tp / (1.0 * (tp + fn))
            p += precision
            r += recall
            f += 2 * precision * recall / (precision + recall)
    return f / total, p / total, r / total

def print_report(report, go_id):
    with open(DATA_ROOT + 'reports.txt', 'a') as f:
        f.write('Classification report for ' + go_id + '\n')
        f.write(report + '\n')


if __name__ == '__main__':
    main()