
conditional_new_nn_hirearchical_seq_mt.py 31KB

#!/usr/bin/env python
"""
Train and evaluate a conditional WGAN model for GO term prediction from protein
sequence features.
"""
from __future__ import division

import logging
import sys
import time
from collections import deque
from multiprocessing import Pool

import click as ck
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import (
    Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D,
    Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization,
    Activation, Dropout, Lambda)
from keras.losses import binary_crossentropy
from keras.models import Sequential, Model, load_model
from keras.preprocessing import sequence
from scipy.spatial import distance
from sklearn.metrics import log_loss
from sklearn.metrics import roc_curve, auc, matthews_corrcoef
from sklearn.metrics import precision_recall_curve
from utils import (
    get_gene_ontology,
    get_go_set,
    get_anchestors,
    get_parents,
    DataGenerator,
    FUNC_DICT,
    get_height,
    get_ipro)
from conditional_wgan_wrapper_post import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new

# Let TensorFlow grow GPU memory on demand instead of pre-allocating it all.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
sys.setrecursionlimit(100000)

DATA_ROOT = 'data/swiss/'
MAXLEN = 258  # 1000
REPLEN = 256
ind = 0

@ck.command()
@ck.option(
    '--function',
    default='bp',
    help='Ontology id (mf, bp, cc)')
@ck.option(
    '--device',
    default='gpu:0',
    help='GPU or CPU device id')
@ck.option(
    '--org',
    default=None,
    help='Organism id for filtering test set')
@ck.option('--train', default=True, is_flag=True)
@ck.option('--param', default=0, help='Param index 0-7')
def main(function, device, org, train, param):
    global FUNCTION
    FUNCTION = function
    global GO_ID
    GO_ID = FUNC_DICT[FUNCTION]
    global go
    go = get_gene_ontology('go.obo')
    global ORG
    ORG = org
    func_df = pd.read_pickle(DATA_ROOT + FUNCTION + '.pkl')
    global functions
    functions = func_df['functions'].values
    global func_set
    func_set = set(functions)
    global all_functions
    all_functions = get_go_set(go, GO_ID)
    logging.info('Functions: %s %d' % (FUNCTION, len(functions)))
    if ORG is not None:
        logging.info('Organism %s' % ORG)
    global go_indexes
    go_indexes = dict()
    for ind, go_id in enumerate(functions):
        go_indexes[go_id] = ind
    global node_names
    node_names = set()
    with tf.device('/' + device):
        params = {
            'fc_output': 1024,
            'learning_rate': 0.001,
            'embedding_dims': 128,
            'embedding_dropout': 0.2,
            'nb_conv': 1,
            'nb_dense': 1,
            'filter_length': 128,
            'nb_filter': 32,
            'pool_length': 64,
            'stride': 32
        }
        model(params, is_train=train)
    # dims = [64, 128, 256, 512]
    # nb_filters = [16, 32, 64, 128]
    # nb_convs = [1, 2, 3, 4]
    # nb_dense = [1, 2, 3, 4]
    # for i in range(param * 32, param * 32 + 32):
    #     dim = i % 4
    #     i = i / 4
    #     nb_fil = i % 4
    #     i /= 4
    #     conv = i % 4
    #     i /= 4
    #     den = i
    #     params['embedding_dims'] = dims[dim]
    #     params['nb_filter'] = nb_filters[nb_fil]
    #     params['nb_conv'] = nb_convs[conv]
    #     params['nb_dense'] = nb_dense[den]
    # performanc_by_interpro()

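# Example invocation (illustrative only; assumes data/swiss/<function>.pkl, go.obo
# and the data2/ CSV files referenced in load_data2() below are in place):
#   python conditional_new_nn_hirearchical_seq_mt.py --function bp --device gpu:0
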
def load_data2():
    all_data_x_fn = 'data2/all_data_X.csv'
    all_data_x = pd.read_csv(all_data_x_fn, sep='\t', header=0, index_col=0)
    all_proteins_train = [p.replace('"', '') for p in all_data_x.index]
    all_data_x.index = all_proteins_train
    all_data_y_fn = 'data2/all_data_Y.csv'
    all_data_y = pd.read_csv(all_data_y_fn, sep='\t', header=0, index_col=0)
    branch = pd.read_csv('data2/' + FUNCTION + '_branches.txt', sep='\t', header=0, index_col=0)
    all_x = all_data_x.values
    branches = [p for p in branch.index.tolist() if p in all_data_y.columns.tolist()]
    t = pd.DataFrame(all_data_y, columns=branches)
    all_y = t.values
    # Hold out a random 20% of proteins as the test (and validation) split.
    # Sampling without replacement keeps the indices unique and in range.
    number_of_test = int(np.ceil(0.2 * len(all_x)))
    index_test = np.random.choice(len(all_x), number_of_test, replace=False)
    index_train = [p for p in range(len(all_x)) if p not in set(index_test)]
    train_data = all_x[index_train, :]  # [:20000, :]
    test_data = all_x[index_test, :]  # [20000:, :]
    train_labels = all_y[index_train, :]  # [:20000, :]
    test_labels = all_y[index_test, :]  # [20000:, :]
    val_data = test_data
    val_labels = test_labels
    # print(sum(sum(train_labels)))
    # print(train_data.shape)
    print(train_labels.shape)
    print(test_labels.shape)
    return train_data, train_labels, test_data, test_labels, val_data, val_labels

def load_data():
    df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
    n = len(df)
    index = df.index.values
    valid_n = int(n * 0.8)
    train_df = df.loc[index[:valid_n]]
    valid_df = df.loc[index[valid_n:]]
    test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
    print(test_df['orgs'])
    if ORG is not None:
        logging.info('Unfiltered test size: %d' % len(test_df))
        test_df = test_df[test_df['orgs'] == ORG]
        logging.info('Filtered test size: %d' % len(test_df))

    # Filter by type
    # org_df = pd.read_pickle('data/prokaryotes.pkl')
    # orgs = org_df['orgs']
    # test_df = test_df[test_df['orgs'].isin(orgs)]

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    def normalize_minmax(values):
        mn = np.min(values)
        mx = np.max(values)
        if mx - mn != 0.0:
            return (values - mn) / (mx - mn)
        return values - mn

    def get_values(data_frame):
        print(data_frame['labels'].values.shape)
        labels = reshape(data_frame['labels'].values)
        ngrams = sequence.pad_sequences(
            data_frame['ngrams'].values, maxlen=MAXLEN)
        ngrams = reshape(ngrams)
        rep = reshape(data_frame['embeddings'].values)
        data = ngrams
        return data, labels

    train = get_values(train_df)
    valid = get_values(valid_df)
    test = get_values(test_df)
    return train, valid, test, train_df, valid_df, test_df

def get_feature_model(params):
    embedding_dims = params['embedding_dims']
    max_features = 8001
    model = Sequential()
    model.add(Embedding(
        max_features,
        embedding_dims,
        input_length=MAXLEN))
    model.add(SpatialDropout1D(0.4))
    for i in range(params['nb_conv']):
        model.add(Conv1D(
            activation="relu",
            padding="valid",
            strides=1,
            filters=params['nb_filter'],
            kernel_size=params['filter_length']))
    model.add(MaxPooling1D(strides=params['stride'], pool_size=params['pool_length']))
    model.add(Flatten())
    return model

def merge_outputs(outputs, name):
    if len(outputs) == 1:
        return outputs[0]
    # return merge(outputs, mode='concat', name=name, concat_axis=1)
    return Concatenate(axis=1, name=name)(outputs)


def merge_nets(nets, name):
    if len(nets) == 1:
        return nets[0]
    # return merge(nets, mode='sum', name=name)
    return Add(name=name)(nets)


def get_node_name(go_id, unique=False):
    name = go_id.split(':')[1]
    if not unique:
        return name
    if name not in node_names:
        node_names.add(name)
        return name
    i = 1
    while (name + '_' + str(i)) in node_names:
        i += 1
    name = name + '_' + str(i)
    node_names.add(name)
    return name


def get_function_node(name, inputs):
    output_name = name + '_out'
    # net = Dense(256, name=name, activation='relu')(inputs)
    output = Dense(1, name=output_name, activation='sigmoid')(inputs)
    return output, output

def get_generator(params, n_classes):
    inputs = Input(shape=(MAXLEN,), dtype='float32', name='input1')
    # feature_model = get_feature_model(params)(inputs)
    net0 = Dense(150, activation='relu')(inputs)
    net0 = Dense(150, activation='relu')(net0)
    # net0 = Dense(50, activation='relu')(net0)
    net = Dense(70, activation='relu')(net0)
    output = Dense(n_classes, activation='sigmoid')(net)
    model = Model(inputs=inputs, outputs=output)
    return model

def get_discriminator(params, n_classes, dropout_rate=0.5):
    inputs = Input(shape=(n_classes,))
    inputs2 = Input(shape=(MAXLEN,), dtype='int32', name='d_input2')
    x2 = Embedding(8001, 128, input_length=MAXLEN)(inputs2)
    x2 = Conv1D(filters=1, kernel_size=1, padding='valid', activation='relu', strides=1)(x2)
    x2 = Lambda(lambda x: K.squeeze(x, 2))(x2)
    # for i in range(params['nb_conv']):
    #     x2 = Conv1D(activation="relu", padding="valid", strides=1,
    #                 filters=params['nb_filter'], kernel_size=params['filter_length'])(x2)
    # x2 = MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2)
    # x2 = Flatten()(x2)
    size = 40
    x = inputs
    x = Dropout(dropout_rate)(x)
    x = Dense(size)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    size = 40
    x2 = Dropout(dropout_rate)(x2)
    x2 = Dense(size)(x2)
    x2 = BatchNormalization()(x2)
    x2 = Activation('relu')(x2)
    x = Concatenate(axis=1, name='merged2')([x, x2])
    layer_sizes = [80, 40, 30]
    for size in layer_sizes:
        x = Dropout(dropout_rate)(x)
        x = Dense(size)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    outputs = Dense(1)(x)
    model = Model(inputs=[inputs, inputs2], outputs=outputs, name='Discriminator')
    return model

def get_model(params, nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10):
    generator = get_generator(params, nb_classes)
    discriminator = get_discriminator(params, nb_classes)
    generator_model, discriminator_model = \
        WGAN_wrapper(generator=generator,
                     discriminator=discriminator,
                     generator_input_shape=(MAXLEN,),
                     discriminator_input_shape=(nb_classes,),
                     discriminator_input_shape2=(MAXLEN,),
                     batch_size=batch_size,
                     gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    logging.info('Compilation finished')
    return generator_model, discriminator_model

def train_wgan(generator_model, discriminator_model, batch_size, epochs,
               x_train, y_train, x_val, y_val, generator_model_path, discriminator_model_path,
               TRAINING_RATIO=10, N_WARM_UP=0):
    BATCH_SIZE = batch_size
    N_EPOCH = epochs
    positive_y = np.ones((batch_size, 1), dtype=np.float32)
    zero_y = positive_y * 0
    negative_y = -positive_y
    positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32)
    dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    positive_full_enable_train = np.ones((len(x_train), 1), dtype=np.float32)
    positive_full_enable_val = np.ones((len(x_val), 1), dtype=np.float32)
    # positive_enable_train = np.ones((1, batch_size), dtype=np.float32)
    # positive_full_train_enable = np.ones((1, BATCH_SIZE * TRAINING_RATIO), dtype=np.float32)
    best_validation_loss = None
    for epoch in range(N_EPOCH):
        # np.random.shuffle(X_train)
        print("Epoch: ", epoch)
        print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE))
        discriminator_loss = []
        generator_loss = []
        minibatches_size = BATCH_SIZE * TRAINING_RATIO
        shuffled_indexes = np.random.permutation(x_train.shape[0])
        shuffled_indexes_2 = np.random.permutation(x_train.shape[0])
        for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))):
            batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size]
            batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size]
            x = x_train[batch_indexes]
            y = y_train[batch_indexes]
            y_2 = y_train[batch_indexes_2]
            x_2 = x_train[batch_indexes_2]
            if epoch < N_WARM_UP:
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    generator_loss.append(generator_model.train_on_batch(
                        [x_batch, positive_y], [y_batch, zero_y]))
            else:
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    # noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32)
                    noise = x_batch
                    # print(sum(y_batch_2))
                    discriminator_loss.append(discriminator_model.train_on_batch(
                        [y_batch_2, noise, x_batch_2],
                        [positive_y, negative_y, dummy_y]))
                generator_loss.append(generator_model.train_on_batch(
                    [x, positive_full_y], [y, positive_full_y]))
        # Still needs some code to display losses from the generator and discriminator, progress bars, etc.
        predicted_y_train, _ = generator_model.predict(
            [x_train, positive_full_enable_train], batch_size=BATCH_SIZE)
        predicted_y_val, _ = generator_model.predict(
            [x_val, positive_full_enable_val], batch_size=BATCH_SIZE)
        # print(sum(sum(positive_full_enable_train)))
        # print(predicted_y_train)
        train_loss = log_loss(y_train, predicted_y_train)
        val_loss = log_loss(y_val, predicted_y_val)
        print("train loss: {:.4f}, validation loss: {:.4f}, discriminator loss: {:.4f}".format(
            train_loss, val_loss,
            (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0]))
        if best_validation_loss is None or best_validation_loss > val_loss:
            print('\nEpoch %05d: improved from %0.5f,'
                  ' saving model to %s and %s'
                  % (epoch + 1, val_loss, generator_model_path, discriminator_model_path))
            best_validation_loss = val_loss
            generator_model.save(generator_model_path, overwrite=True)
            discriminator_model.save(discriminator_model_path, overwrite=True)

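# Note on the training loop above (descriptive only; the exact loss wiring lives in
# WGAN_wrapper from conditional_wgan_wrapper_post, which is not shown here): each
# outer step draws BATCH_SIZE * TRAINING_RATIO examples, runs TRAINING_RATIO
# discriminator updates with targets [positive_y, negative_y, dummy_y] (the usual
# WGAN-GP pattern for real, generated, and gradient-penalty outputs), then one
# generator update on the full mini-batch. During the first N_WARM_UP epochs only
# the generator is trained against the label reconstruction target. The model with
# the lowest validation log loss is checkpointed to the given paths.
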
def model(params, batch_size=20, nb_epoch=40, is_train=True):
    # set parameters:
    # nb_classes = len(functions)
    start_time = time.time()
    logging.info("Loading Data")
    ##
    # train, val, test, train_df, valid_df, test_df = load_data()
    # train_df = pd.concat([train_df, valid_df])
    # test_gos = test_df['gos'].values
    # train_data, train_labels = train
    # val_data, val_labels = val
    # test_data, test_labels = test
    ##
    train_data, train_labels, test_data, test_labels, val_data, val_labels = load_data2()
    nb_classes = train_labels.shape[1]
    logging.info("Data loaded in %d sec" % (time.time() - start_time))
    logging.info("Training data size: %d" % len(train_data))
    logging.info("Validation data size: %d" % len(val_data))
    logging.info("Test data size: %d" % len(test_data))
    generator_model_path = DATA_ROOT + 'models/new_model_seq_' + FUNCTION + '.h5'
    discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_' + FUNCTION + '.h5'
    logging.info('Starting training the model')
    train_generator = DataGenerator(batch_size, nb_classes)
    train_generator.fit(train_data, train_labels)
    valid_generator = DataGenerator(batch_size, nb_classes)
    valid_generator.fit(val_data, val_labels)
    test_generator = DataGenerator(batch_size, nb_classes)
    test_generator.fit(test_data, test_labels)
    if is_train:
        generator_model, discriminator_model = get_model(params, nb_classes, batch_size)
        train_wgan(generator_model, discriminator_model, batch_size=batch_size, epochs=nb_epoch,
                   x_train=train_data, y_train=train_labels, x_val=val_data, y_val=val_labels,
                   generator_model_path=generator_model_path,
                   discriminator_model_path=discriminator_model_path)
    logging.info('Loading best model')
    model = load_model(generator_model_path,
                       custom_objects={'generator_recunstruction_loss_new': generator_recunstruction_loss_new,
                                       'wasserstein_loss': wasserstein_loss})
    logging.info('Predicting')
    preds = model.predict_generator(test_generator, steps=len(test_data) / batch_size)[0]
    # incon = 0
    # for i in range(len(test_data)):
    #     for j in range(len(functions)):
    #         childs = set(go[functions[j]]['children']).intersection(func_set)
    #         ok = True
    #         for n_id in childs:
    #             if preds[i, j] < preds[i, go_indexes[n_id]]:
    #                 preds[i, j] = preds[i, go_indexes[n_id]]
    #                 ok = False
    #         if not ok:
    #             incon += 1
    logging.info('Computing performance')
    f, p, r, t, preds_max = compute_performance(preds, test_labels)  # , test_gos)
    roc_auc = compute_roc(preds, test_labels)
    mcc = compute_mcc(preds_max, test_labels)
    aupr, _ = compute_aupr(preds, test_labels)
    m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = micro_macro_function_centric_f1(preds.T, test_labels.T)
    logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f))
    logging.info('ROC AUC: \t %f ' % (roc_auc, ))
    logging.info('MCC: \t %f ' % (mcc, ))
    logging.info('AUPR: \t %f ' % (aupr, ))
    logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max))
    logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max))
    function_centric_performance(functions, preds.T, test_labels.T, train_labels.T)

def load_prot_ipro():
    proteins = list()
    ipros = list()
    with open(DATA_ROOT + 'swissprot_ipro.tab') as f:
        for line in f:
            it = line.strip().split('\t')
            if len(it) != 3:
                continue
            prot = it[1]
            iprs = it[2].split(';')
            proteins.append(prot)
            ipros.append(iprs)
    return pd.DataFrame({'proteins': proteins, 'ipros': ipros})

def performanc_by_interpro():
    pred_df = pd.read_pickle(DATA_ROOT + 'test-' + FUNCTION + '-preds.pkl')
    ipro_df = load_prot_ipro()
    df = pred_df.merge(ipro_df, on='proteins', how='left')
    ipro = get_ipro()

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    for ipro_id in ipro:
        if len(ipro[ipro_id]['parents']) > 0:
            continue
        labels = list()
        predictions = list()
        gos = list()
        for i, row in df.iterrows():
            if not isinstance(row['ipros'], list):
                continue
            if ipro_id in row['ipros']:
                labels.append(row['labels'])
                predictions.append(row['predictions'])
                gos.append(row['gos'])
        pr = 0
        rc = 0
        total = 0
        p_total = 0
        for i in range(len(labels)):
            tp = np.sum(labels[i] * predictions[i])
            fp = np.sum(predictions[i]) - tp
            fn = np.sum(labels[i]) - tp
            all_gos = set()
            for go_id in gos[i]:
                if go_id in all_functions:
                    all_gos |= get_anchestors(go, go_id)
            all_gos.discard(GO_ID)
            all_gos -= func_set
            fn += len(all_gos)
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                pr += precision
                rc += recall
        if total > 0 and p_total > 0:
            rc /= total
            pr /= p_total
            if pr + rc > 0:
                f = 2 * pr * rc / (pr + rc)
                logging.info('%s\t%d\t%f\t%f\t%f' % (
                    ipro_id, len(labels), f, pr, rc))

def function_centric_performance(functions, preds, labels, labels_train):
    results = []
    preds = np.round(preds, 2)
    for i in range(preds.shape[0]):
        f_max = 0
        p_max = 0
        r_max = 0
        for t in range(1, 100):
            threshold = t / 100.0
            predictions = (preds[i, :] > threshold).astype(np.int32)
            tp = np.sum(predictions * labels[i, :])
            fp = np.sum(predictions) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp > 0:
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                f = 2 * precision * recall / (precision + recall)
            else:
                if fp == 0 and fn == 0:
                    precision = 1
                    recall = 1
                    f = 1
                else:
                    precision = 0
                    recall = 0
                    f = 0
            if f_max < f:
                f_max = f
                p_max = precision
                r_max = recall
        num_prots_train = np.sum(labels_train[i, :])
        height = get_height(go, functions[i])
        results.append([functions[i], num_prots_train, height, f_max, p_max, r_max])
    results = pd.DataFrame(results)
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)

def function_centric_performance_backup(functions, preds, labels, labels_train):
    results = []
    preds = np.round(preds, 2)
    for i in range(len(functions)):
        f_max = 0
        p_max = 0
        r_max = 0
        x = list()
        y = list()
        total = 0
        for t in range(1, 100):
            threshold = t / 100.0
            predictions = (preds[i, :] > threshold).astype(np.int32)
            tp = np.sum(predictions * labels[i, :])
            fp = np.sum(predictions) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp > 0:
                sn = tp / (1.0 * np.sum(labels[i, :]))
                sp = np.sum((predictions ^ 1) * (labels[i, :] ^ 1))
                sp /= 1.0 * np.sum(labels[i, :] ^ 1)
                fpr = 1 - sp
                x.append(fpr)
                y.append(sn)
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                f = 2 * precision * recall / (precision + recall)
                total += 1
                if f_max < f:
                    f_max = f
                    p_max = precision
                    r_max = recall
        num_prots = np.sum(labels[i, :])
        num_prots_train = np.sum(labels_train[i, :])
        if total > 1:
            roc_auc = auc(x, y)
        else:
            roc_auc = 0
        height = get_height(go, functions[i])
        results.append([functions[i], f_max, p_max, r_max, num_prots, num_prots_train, height, roc_auc])
    results = pd.DataFrame(results)
    # results.to_csv('new_results.txt', sep='\t', index=False)
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)
    # results = np.array(results)
    # p_mean = (np.sum(results[:, 2])) / len(functions)
    # r_mean = (np.sum(results[:, 3])) / len(functions)
    # f_mean = (2 * p_mean * r_mean) / (p_mean + r_mean)
    # roc_auc_mean = (np.sum(results[:, 7])) / len(functions)
    # print('Function centric performance (macro) ' '%f %f %f %f' % (f_mean, p_mean, r_mean, roc_auc_mean))

def micro_macro_function_centric_f1_backup(preds, labels):
    preds = np.round(preds, 2)
    m_f1_max = 0
    M_f1_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        m_tp = 0
        m_fp = 0
        m_fn = 0
        M_pr = 0
        M_rc = 0
        total = 0
        p_total = 0
        for i in range(len(preds)):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp > 0:
                pr = tp / (1.0 * (tp + fp))
                rc = tp / (1.0 * (tp + fn))
                m_tp += tp
                m_fp += fp
                m_fn += fn
                M_pr += pr
                M_rc += rc
                p_total += 1
        if p_total == 0:
            continue
        if total > 0:
            m_tp /= total
            m_fn /= total
            m_fp /= total
            m_pr = m_tp / (1.0 * (m_tp + m_fp))
            m_rc = m_tp / (1.0 * (m_tp + m_fn))
            M_pr /= p_total
            M_rc /= total
            m_f1 = 2 * m_pr * m_rc / (m_pr + m_rc)
            M_f1 = 2 * M_pr * M_rc / (M_pr + M_rc)
            if m_f1 > m_f1_max:
                m_f1_max = m_f1
                m_pr_max = m_pr
                m_rc_max = m_rc
            if M_f1 > M_f1_max:
                M_f1_max = M_f1
                M_pr_max = M_pr
                M_rc_max = M_rc
    return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max

def micro_macro_function_centric_f1(preds, labels):
    preds = np.round(preds, 2)
    m_f1_max = 0
    M_f1_max = 0
    for t in range(1, 200):
        threshold = t / 200.0
        predictions = (preds > threshold).astype(np.int32)
        m_tp = 0
        m_fp = 0
        m_fn = 0
        M_pr = 0
        M_rc = 0
        for i in range(preds.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            m_tp += tp
            m_fp += fp
            m_fn += fn
            if tp > 0:
                pr = 1.0 * tp / (1.0 * (tp + fp))
                rc = 1.0 * tp / (1.0 * (tp + fn))
            else:
                if fp == 0 and fn == 0:
                    pr = 1
                    rc = 1
                else:
                    pr = 0
                    rc = 0
            M_pr += pr
            M_rc += rc
        if m_tp > 0:
            m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp))
            m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn))
            m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc)
        else:
            if m_fp == 0 and m_fn == 0:
                m_pr = 1
                m_rc = 1
                m_f1 = 1
            else:
                m_pr = 0
                m_rc = 0
                m_f1 = 0
        M_pr /= preds.shape[0]
        M_rc /= preds.shape[0]
        if M_pr == 0 and M_rc == 0:
            M_f1 = 0
        else:
            M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc)
        if m_f1 > m_f1_max:
            m_f1_max = m_f1
            m_pr_max = m_pr
            m_rc_max = m_rc
        if M_f1 > M_f1_max:
            M_f1_max = M_f1
            M_pr_max = M_pr
            M_rc_max = M_rc
    return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max

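# Example (illustrative): model() calls this with the transposed matrices, i.e. rows
# are GO terms and columns are proteins:
#   m_pr, m_rc, m_f1, M_pr, M_rc, M_f1 = micro_macro_function_centric_f1(preds.T, test_labels.T)
# Micro values (m_*) pool TP/FP/FN over all terms at the best threshold; macro
# values (M_*) average per-term precision and recall before computing F1.
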
def compute_roc(preds, labels):
    # Micro-averaged ROC AUC over the flattened prediction/label matrices
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc


def compute_aupr(preds, labels):
    # Micro-averaged area under the precision-recall curve
    pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten())
    pr_auc = auc(rc, pr)
    # pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten(), average='macro')
    M_pr_auc = 0
    return pr_auc, M_pr_auc


def compute_mcc(preds, labels):
    # Matthews correlation coefficient on the flattened binarized predictions
    mcc = matthews_corrcoef(labels.flatten(), preds.flatten())
    return mcc

def compute_performance(preds, labels):  # , gos):
    preds = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        total = 0
        f = 0.0
        p = 0.0
        r = 0.0
        p_total = 0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            all_gos = set()
            # for go_id in gos[i]:
            #     if go_id in all_functions:
            #         all_gos |= get_anchestors(go, go_id)
            # all_gos.discard(GO_ID)
            # all_gos -= func_set
            # fn += len(all_gos)
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                p += precision
                r += recall
        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
            if f_max < f:
                f_max = f
                p_max = p
                r_max = r
                t_max = threshold
                predictions_max = predictions
    return f_max, p_max, r_max, t_max, predictions_max

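# Example (illustrative): model() calls this on the raw sigmoid outputs,
#   f, p, r, t, preds_max = compute_performance(preds, test_labels)
# which sweeps thresholds 0.01-0.99 and returns the protein-centric F-max together
# with the precision, recall, threshold, and binarized predictions at that optimum.
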
def get_gos(pred):
    mdist = 1.0
    mgos = None
    for i in range(len(labels_gos)):
        labels, gos = labels_gos[i]
        dist = distance.cosine(pred, labels)
        if mdist > dist:
            mdist = dist
            mgos = gos
    return mgos

def compute_similarity_performance(train_df, test_df, preds):
    logging.info("Computing similarity performance")
    logging.info("Training data size %d" % len(train_df))
    train_labels = train_df['labels'].values
    train_gos = train_df['gos'].values
    global labels_gos
    # Materialize the pairs so get_gos() can index and measure their length.
    labels_gos = list(zip(train_labels, train_gos))
    p = Pool(64)
    pred_gos = p.map(get_gos, preds)
    total = 0
    p = 0.0
    r = 0.0
    f = 0.0
    test_gos = test_df['gos'].values
    for gos, tgos in zip(pred_gos, test_gos):
        preds = set()
        test = set()
        for go_id in gos:
            if go_id in all_functions:
                preds |= get_anchestors(go, go_id)
        for go_id in tgos:
            if go_id in all_functions:
                test |= get_anchestors(go, go_id)
        tp = len(preds.intersection(test))
        fp = len(preds - test)
        fn = len(test - preds)
        if tp == 0 and fp == 0 and fn == 0:
            continue
        total += 1
        if tp != 0:
            precision = tp / (1.0 * (tp + fp))
            recall = tp / (1.0 * (tp + fn))
            p += precision
            r += recall
            f += 2 * precision * recall / (precision + recall)
    return f / total, p / total, r / total

def print_report(report, go_id):
    with open(DATA_ROOT + 'reports.txt', 'a') as f:
        f.write('Classification report for ' + go_id + '\n')
        f.write(report + '\n')


if __name__ == '__main__':
    main()