
conditional_new_nn_hirearchical_seq.py 30KB

#!/usr/bin/env python
"""
Conditional WGAN with a hierarchical, GO-structured sequence model for
protein function prediction: trains a generator/discriminator pair and
evaluates protein-centric and function-centric performance.
"""
from __future__ import division

import logging
import sys
import time
from collections import deque
from multiprocessing import Pool

import click as ck
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import (
    Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D,
    Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization,
    Activation, Dropout, Lambda)
from keras.losses import binary_crossentropy
from keras.models import Sequential, Model, load_model
from keras.preprocessing import sequence
from scipy.spatial import distance
from sklearn.metrics import (
    log_loss, roc_curve, auc, matthews_corrcoef, precision_recall_curve)
from utils import (
    get_gene_ontology,
    get_go_set,
    get_anchestors,
    get_parents,
    DataGenerator,
    FUNC_DICT,
    get_height,
    get_ipro)
from conditional_wgan_wrapper_post import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
sys.setrecursionlimit(100000)

DATA_ROOT = 'data/swiss/'
MAXLEN = 1000
REPLEN = 256
ind = 0


@ck.command()
@ck.option(
    '--function',
    default='bp',
    help='Ontology id (mf, bp, cc)')
@ck.option(
    '--device',
    default='gpu:0',
    help='GPU or CPU device id')
@ck.option(
    '--org',
    default=None,
    help='Organism id for filtering test set')
@ck.option('--train', default=True, is_flag=True)
@ck.option('--param', default=0, help='Param index 0-7')
def main(function, device, org, train, param):
    global FUNCTION
    FUNCTION = function
    global GO_ID
    GO_ID = FUNC_DICT[FUNCTION]
    global go
    go = get_gene_ontology('go.obo')
    global ORG
    ORG = org
    func_df = pd.read_pickle(DATA_ROOT + FUNCTION + '.pkl')
    global functions
    functions = func_df['functions'].values
    global func_set
    func_set = set(functions)
    global all_functions
    all_functions = get_go_set(go, GO_ID)
    logging.info('Functions: %s %d' % (FUNCTION, len(functions)))
    if ORG is not None:
        logging.info('Organism %s' % ORG)
    global go_indexes
    go_indexes = dict()
    for ind, go_id in enumerate(functions):
        go_indexes[go_id] = ind
    global node_names
    node_names = set()
    with tf.device('/' + device):
        params = {
            'fc_output': 1024,
            'learning_rate': 0.001,
            'embedding_dims': 128,
            'embedding_dropout': 0.2,
            'nb_conv': 1,
            'nb_dense': 1,
            'filter_length': 128,
            'nb_filter': 32,
            'pool_length': 64,
            'stride': 32
        }
        model(params, is_train=train)


def load_data():
    df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
    n = len(df)
    index = df.index.values
    valid_n = int(n * 0.8)
    train_df = df.loc[index[:valid_n]]
    valid_df = df.loc[index[valid_n:]]
    test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
    print(test_df['orgs'])
    if ORG is not None:
        logging.info('Unfiltered test size: %d' % len(test_df))
        test_df = test_df[test_df['orgs'] == ORG]
        logging.info('Filtered test size: %d' % len(test_df))
    # Filter by type
    # org_df = pd.read_pickle('data/prokaryotes.pkl')
    # orgs = org_df['orgs']
    # test_df = test_df[test_df['orgs'].isin(orgs)]

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    def normalize_minmax(values):
        mn = np.min(values)
        mx = np.max(values)
        if mx - mn != 0.0:
            return (values - mn) / (mx - mn)
        return values - mn

    def get_values(data_frame):
        print(data_frame['labels'].values.shape)
        labels = reshape(data_frame['labels'].values)
        ngrams = sequence.pad_sequences(
            data_frame['ngrams'].values, maxlen=MAXLEN)
        ngrams = reshape(ngrams)
        rep = reshape(data_frame['embeddings'].values)
        data = ngrams
        return data, labels

    train = get_values(train_df)
    valid = get_values(valid_df)
    test = get_values(test_df)
    return train, valid, test, train_df, valid_df, test_df


def get_feature_model(params):
    embedding_dims = params['embedding_dims']
    max_features = 8001
    model = Sequential()
    model.add(Embedding(
        max_features,
        embedding_dims,
        input_length=MAXLEN))
    model.add(SpatialDropout1D(0.4))
    for i in range(params['nb_conv']):
        model.add(Conv1D(
            activation="relu",
            padding="valid",
            strides=1,
            filters=params['nb_filter'],
            kernel_size=params['filter_length']))
    model.add(MaxPooling1D(strides=params['stride'], pool_size=params['pool_length']))
    model.add(Flatten())
    return model


def merge_outputs(outputs, name):
    if len(outputs) == 1:
        return outputs[0]
    # return merge(outputs, mode='concat', name=name, concat_axis=1)
    return Concatenate(axis=1, name=name)(outputs)


def merge_nets(nets, name):
    if len(nets) == 1:
        return nets[0]
    # return merge(nets, mode='sum', name=name)
    return Add(name=name)(nets)


def get_node_name(go_id, unique=False):
    name = go_id.split(':')[1]
    if not unique:
        return name
    if name not in node_names:
        node_names.add(name)
        return name
    i = 1
    while (name + '_' + str(i)) in node_names:
        i += 1
    name = name + '_' + str(i)
    node_names.add(name)
    return name


def get_layers(inputs):
    q = deque()
    layers = {}
    name = get_node_name(GO_ID)
    layers[GO_ID] = {'net': inputs}
    for node_id in go[GO_ID]['children']:
        if node_id in func_set:
            q.append((node_id, inputs))
    while len(q) > 0:
        node_id, net = q.popleft()
        parent_nets = [inputs]
        # for p_id in get_parents(go, node_id):
        #     if p_id in func_set:
        #         parent_nets.append(layers[p_id]['net'])
        # if len(parent_nets) > 1:
        #     name = get_node_name(node_id) + '_parents'
        #     net = merge(
        #         parent_nets, mode='concat', concat_axis=1, name=name)
        name = get_node_name(node_id)
        net, output = get_function_node(name, inputs)
        if node_id not in layers:
            layers[node_id] = {'net': net, 'output': output}
        for n_id in go[node_id]['children']:
            if n_id in func_set and n_id not in layers:
                ok = True
                for p_id in get_parents(go, n_id):
                    if p_id in func_set and p_id not in layers:
                        ok = False
                if ok:
                    q.append((n_id, net))
    for node_id in functions:
        childs = set(go[node_id]['children']).intersection(func_set)
        if len(childs) > 0:
            outputs = [layers[node_id]['output']]
            for ch_id in childs:
                outputs.append(layers[ch_id]['output'])
            name = get_node_name(node_id) + '_max'
            # layers[node_id]['output'] = merge(
            #     outputs, mode='max', name=name)
            layers[node_id]['output'] = Maximum(name=name)(outputs)
    return layers


def get_function_node(name, inputs):
    output_name = name + '_out'
    # net = Dense(256, name=name, activation='relu')(inputs)
    output = Dense(1, name=output_name, activation='sigmoid')(inputs)
    return output, output


def get_generator(params, n_classes):
    inputs = Input(shape=(MAXLEN,), dtype='int32', name='input1')
    feature_model = get_feature_model(params)(inputs)
    net = Dense(300, activation='relu')(feature_model)
    net = BatchNormalization()(net)
    layers = get_layers(net)
    output_models = []
    for i in range(len(functions)):
        output_models.append(layers[functions[i]]['output'])
    net = Concatenate(axis=1)(output_models)
    output = Dense(n_classes, activation='sigmoid')(net)
    model = Model(inputs=inputs, outputs=output)
    return model


def get_discriminator(params, n_classes, dropout_rate=0.5):
    inputs = Input(shape=(n_classes,))
    inputs2 = Input(shape=(MAXLEN,), dtype='int32', name='d_input2')
    x2 = Embedding(8001, 128, input_length=MAXLEN)(inputs2)
    x2 = Conv1D(filters=1, kernel_size=1, padding='valid', activation='relu', strides=1)(x2)
    x2 = Lambda(lambda x: K.squeeze(x, 2))(x2)
    # for i in range(params['nb_conv']):
    #     x2 = Conv1D(activation="relu", padding="valid", strides=1,
    #                 filters=params['nb_filter'], kernel_size=params['filter_length'])(x2)
    # x2 = MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2)
    # x2 = Flatten()(x2)
    size = 40
    x = inputs
    x = Dropout(dropout_rate)(x)
    x = Dense(size)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    size = 40
    x2 = Dropout(dropout_rate)(x2)
    x2 = Dense(size)(x2)
    x2 = BatchNormalization()(x2)
    x2 = Activation('relu')(x2)
    x = Concatenate(axis=1, name='merged2')([x, x2])
    layer_sizes = [80, 40, 30]
    for size in layer_sizes:
        x = Dropout(dropout_rate)(x)
        x = Dense(size)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    outputs = Dense(1)(x)
    model = Model(inputs=[inputs, inputs2], outputs=outputs, name='Discriminator')
    return model


def get_model(params, nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10):
    generator = get_generator(params, nb_classes)
    discriminator = get_discriminator(params, nb_classes)
    generator_model, discriminator_model = \
        WGAN_wrapper(generator=generator,
                     discriminator=discriminator,
                     generator_input_shape=(MAXLEN,),
                     discriminator_input_shape=(nb_classes,),
                     discriminator_input_shape2=(MAXLEN,),
                     batch_size=batch_size,
                     gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    logging.info('Compilation finished')
    return generator_model, discriminator_model


def train_wgan(generator_model, discriminator_model, batch_size, epochs,
               x_train, y_train, x_val, y_val, generator_model_path, discriminator_model_path,
               TRAINING_RATIO=10, N_WARM_UP=0):
    # Alternating WGAN-style training: the discriminator is updated TRAINING_RATIO
    # times per generator update; during the first N_WARM_UP epochs only the
    # generator is trained (reconstruction warm-up).
    BATCH_SIZE = batch_size
    N_EPOCH = epochs
    positive_y = np.ones((batch_size, 1), dtype=np.float32)
    zero_y = positive_y * 0
    negative_y = -positive_y
    positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32)
    dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    positive_full_enable_train = np.ones((len(x_train), 1), dtype=np.float32)
    positive_full_enable_val = np.ones((len(x_val), 1), dtype=np.float32)
    # positive_enable_train = np.ones((1, batch_size), dtype=np.float32)
    # positive_full_train_enable = np.ones((1, BATCH_SIZE * TRAINING_RATIO), dtype=np.float32)
    best_validation_loss = None
    for epoch in range(N_EPOCH):
        # np.random.shuffle(X_train)
        print("Epoch: ", epoch)
        print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE))
        discriminator_loss = []
        generator_loss = []
        minibatches_size = BATCH_SIZE * TRAINING_RATIO
        shuffled_indexes = np.random.permutation(x_train.shape[0])
        shuffled_indexes_2 = np.random.permutation(x_train.shape[0])
        for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))):
            batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size]
            batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size]
            x = x_train[batch_indexes]
            y = y_train[batch_indexes]
            y_2 = y_train[batch_indexes_2]
            x_2 = x_train[batch_indexes_2]
            if epoch < N_WARM_UP:
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    generator_loss.append(generator_model.train_on_batch([x_batch, positive_y], [y_batch, zero_y]))
            else:
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    # noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32)
                    noise = x_batch
                    # print(sum(y_batch_2))
                    discriminator_loss.append(discriminator_model.train_on_batch(
                        [y_batch_2, noise, x_batch_2],
                        [positive_y, negative_y, dummy_y]))
                generator_loss.append(generator_model.train_on_batch([x, positive_full_y], [y, positive_full_y]))
        # Still needs some code to display losses from the generator and discriminator, progress bars, etc.
        predicted_y_train, _ = generator_model.predict([x_train, positive_full_enable_train], batch_size=BATCH_SIZE)
        predicted_y_val, _ = generator_model.predict([x_val, positive_full_enable_val], batch_size=BATCH_SIZE)
        # print(sum(sum(positive_full_enable_train)))
        # print(predicted_y_train)
        train_loss = log_loss(y_train, predicted_y_train)
        val_loss = log_loss(y_val, predicted_y_val)
        print("train loss: {:.4f}, validation loss: {:.4f}, discriminator loss: {:.4f}".format(
            train_loss, val_loss,
            (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0]))
        if best_validation_loss is None or best_validation_loss > val_loss:
            print('\nEpoch %05d: validation loss improved to %0.5f,'
                  ' saving model to %s and %s'
                  % (epoch + 1, val_loss, generator_model_path, discriminator_model_path))
            best_validation_loss = val_loss
            generator_model.save(generator_model_path, overwrite=True)
            discriminator_model.save(discriminator_model_path, overwrite=True)


def model(params, batch_size=20, nb_epoch=40, is_train=True):
    # set parameters:
    nb_classes = len(functions)
    start_time = time.time()
    logging.info("Loading Data")
    train, val, test, train_df, valid_df, test_df = load_data()
    train_df = pd.concat([train_df, valid_df])
    test_gos = test_df['gos'].values
    train_data, train_labels = train
    val_data, val_labels = val
    test_data, test_labels = test
    logging.info("Data loaded in %d sec" % (time.time() - start_time))
    logging.info("Training data size: %d" % len(train_data))
    logging.info("Validation data size: %d" % len(val_data))
    logging.info("Test data size: %d" % len(test_data))
    generator_model_path = DATA_ROOT + 'models/new_model_seq_' + FUNCTION + '.h5'
    discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_' + FUNCTION + '.h5'
    logging.info('Starting training the model')
    train_generator = DataGenerator(batch_size, nb_classes)
    train_generator.fit(train_data, train_labels)
    valid_generator = DataGenerator(batch_size, nb_classes)
    valid_generator.fit(val_data, val_labels)
    test_generator = DataGenerator(batch_size, nb_classes)
    test_generator.fit(test_data, test_labels)
    if is_train:
        generator_model, discriminator_model = get_model(params, nb_classes, batch_size)
        train_wgan(generator_model, discriminator_model, batch_size=batch_size, epochs=nb_epoch,
                   x_train=train_data, y_train=train_labels, x_val=val_data, y_val=val_labels,
                   generator_model_path=generator_model_path,
                   discriminator_model_path=discriminator_model_path)
    logging.info('Loading best model')
    model = load_model(generator_model_path,
                       custom_objects={'generator_recunstruction_loss_new': generator_recunstruction_loss_new,
                                       'wasserstein_loss': wasserstein_loss})
    logging.info('Predicting')
    preds = model.predict_generator(test_generator, steps=len(test_data) / batch_size)[0]
    logging.info('Computing performance')
    f, p, r, t, preds_max = compute_performance(preds, test_labels)  # , test_gos)
    roc_auc = compute_roc(preds, test_labels)
    mcc = compute_mcc(preds_max, test_labels)
    aupr, _ = compute_aupr(preds, test_labels)
    m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = micro_macro_function_centric_f1(preds.T, test_labels.T)
    logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f))
    logging.info('ROC AUC: \t %f ' % (roc_auc, ))
    logging.info('MCC: \t %f ' % (mcc, ))
    logging.info('AUPR: \t %f ' % (aupr, ))
    logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max))
    logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max))
    function_centric_performance(functions, preds.T, test_labels.T, train_labels.T)


def load_prot_ipro():
    proteins = list()
    ipros = list()
    with open(DATA_ROOT + 'swissprot_ipro.tab') as f:
        for line in f:
            it = line.strip().split('\t')
            if len(it) != 3:
                continue
            prot = it[1]
            iprs = it[2].split(';')
            proteins.append(prot)
            ipros.append(iprs)
    return pd.DataFrame({'proteins': proteins, 'ipros': ipros})


def performanc_by_interpro():
    pred_df = pd.read_pickle(DATA_ROOT + 'test-' + FUNCTION + '-preds.pkl')
    ipro_df = load_prot_ipro()
    df = pred_df.merge(ipro_df, on='proteins', how='left')
    ipro = get_ipro()

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    for ipro_id in ipro:
        if len(ipro[ipro_id]['parents']) > 0:
            continue
        labels = list()
        predictions = list()
        gos = list()
        for i, row in df.iterrows():
            if not isinstance(row['ipros'], list):
                continue
            if ipro_id in row['ipros']:
                labels.append(row['labels'])
                predictions.append(row['predictions'])
                gos.append(row['gos'])
        pr = 0
        rc = 0
        total = 0
        p_total = 0
        for i in range(len(labels)):
            tp = np.sum(labels[i] * predictions[i])
            fp = np.sum(predictions[i]) - tp
            fn = np.sum(labels[i]) - tp
            all_gos = set()
            for go_id in gos[i]:
                if go_id in all_functions:
                    all_gos |= get_anchestors(go, go_id)
            all_gos.discard(GO_ID)
            all_gos -= func_set
            fn += len(all_gos)
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                pr += precision
                rc += recall
        if total > 0 and p_total > 0:
            rc /= total
            pr /= p_total
            if pr + rc > 0:
                f = 2 * pr * rc / (pr + rc)
                logging.info('%s\t%d\t%f\t%f\t%f' % (
                    ipro_id, len(labels), f, pr, rc))


def function_centric_performance(functions, preds, labels, labels_train):
    results = []
    preds = np.round(preds, 2)
    for i in range(preds.shape[0]):
        f_max = 0
        p_max = 0
        r_max = 0
        for t in range(1, 100):
            threshold = t / 100.0
            predictions = (preds[i, :] > threshold).astype(np.int32)
            tp = np.sum(predictions * labels[i, :])
            fp = np.sum(predictions) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp > 0:
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                f = 2 * precision * recall / (precision + recall)
            else:
                if fp == 0 and fn == 0:
                    precision = 1
                    recall = 1
                    f = 1
                else:
                    precision = 0
                    recall = 0
                    f = 0
            if f_max < f:
                f_max = f
                p_max = precision
                r_max = recall
        num_prots_train = np.sum(labels_train[i, :])
        height = get_height(go, functions[i])
        results.append([functions[i], num_prots_train, height, f_max, p_max, r_max])
    results = pd.DataFrame(results)
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)


def function_centric_performance_backup(functions, preds, labels, labels_train):
    results = []
    preds = np.round(preds, 2)
    for i in range(len(functions)):
        f_max = 0
        p_max = 0
        r_max = 0
        x = list()
        y = list()
        total = 0
        for t in range(1, 100):
            threshold = t / 100.0
            predictions = (preds[i, :] > threshold).astype(np.int32)
            tp = np.sum(predictions * labels[i, :])
            fp = np.sum(predictions) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp > 0:
                sn = tp / (1.0 * np.sum(labels[i, :]))
                sp = np.sum((predictions ^ 1) * (labels[i, :] ^ 1))
                sp /= 1.0 * np.sum(labels[i, :] ^ 1)
                fpr = 1 - sp
                x.append(fpr)
                y.append(sn)
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                f = 2 * precision * recall / (precision + recall)
                total += 1
                if f_max < f:
                    f_max = f
                    p_max = precision
                    r_max = recall
        num_prots = np.sum(labels[i, :])
        num_prots_train = np.sum(labels_train[i, :])
        if total > 1:
            roc_auc = auc(x, y)
        else:
            roc_auc = 0
        height = get_height(go, functions[i])
        results.append([functions[i], f_max, p_max, r_max, num_prots, num_prots_train, height, roc_auc])
    results = pd.DataFrame(results)
    # results.to_csv('new_results.txt', sep='\t', index=False)
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)
    # results = np.array(results)
    # p_mean = (np.sum(results[:, 2])) / len(functions)
    # r_mean = (np.sum(results[:, 3])) / len(functions)
    # f_mean = (2 * p_mean * r_mean) / (p_mean + r_mean)
    # roc_auc_mean = (np.sum(results[:, 7])) / len(functions)
    # print('Function centric performance (macro) ' '%f %f %f %f' % (f_mean, p_mean, r_mean, roc_auc_mean))


def micro_macro_function_centric_f1_backup(preds, labels):
    preds = np.round(preds, 2)
    m_f1_max = 0
    M_f1_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        m_tp = 0
        m_fp = 0
        m_fn = 0
        M_pr = 0
        M_rc = 0
        total = 0
        p_total = 0
        for i in range(len(preds)):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp > 0:
                pr = tp / (1.0 * (tp + fp))
                rc = tp / (1.0 * (tp + fn))
                m_tp += tp
                m_fp += fp
                m_fn += fn
                M_pr += pr
                M_rc += rc
                p_total += 1
        if p_total == 0:
            continue
        if total > 0:
            m_tp /= total
            m_fn /= total
            m_fp /= total
            m_pr = m_tp / (1.0 * (m_tp + m_fp))
            m_rc = m_tp / (1.0 * (m_tp + m_fn))
            M_pr /= p_total
            M_rc /= total
            m_f1 = 2 * m_pr * m_rc / (m_pr + m_rc)
            M_f1 = 2 * M_pr * M_rc / (M_pr + M_rc)
            if m_f1 > m_f1_max:
                m_f1_max = m_f1
                m_pr_max = m_pr
                m_rc_max = m_rc
            if M_f1 > M_f1_max:
                M_f1_max = M_f1
                M_pr_max = M_pr
                M_rc_max = M_rc
    return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max


def micro_macro_function_centric_f1(preds, labels):
    preds = np.round(preds, 2)
    m_f1_max = 0
    M_f1_max = 0
    for t in range(1, 200):
        threshold = t / 200.0
        predictions = (preds > threshold).astype(np.int32)
        m_tp = 0
        m_fp = 0
        m_fn = 0
        M_pr = 0
        M_rc = 0
        for i in range(preds.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            m_tp += tp
            m_fp += fp
            m_fn += fn
            if tp > 0:
                pr = 1.0 * tp / (1.0 * (tp + fp))
                rc = 1.0 * tp / (1.0 * (tp + fn))
            else:
                if fp == 0 and fn == 0:
                    pr = 1
                    rc = 1
                else:
                    pr = 0
                    rc = 0
            M_pr += pr
            M_rc += rc
        if m_tp > 0:
            m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp))
            m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn))
            m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc)
        else:
            if m_fp == 0 and m_fn == 0:
                m_pr = 1
                m_rc = 1
                m_f1 = 1
            else:
                m_pr = 0
                m_rc = 0
                m_f1 = 0
        M_pr /= preds.shape[0]
        M_rc /= preds.shape[0]
        if M_pr == 0 and M_rc == 0:
            M_f1 = 0
        else:
            M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc)
        if m_f1 > m_f1_max:
            m_f1_max = m_f1
            m_pr_max = m_pr
            m_rc_max = m_rc
        if M_f1 > M_f1_max:
            M_f1_max = M_f1
            M_pr_max = M_pr
            M_rc_max = M_rc
    return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max


def compute_roc(preds, labels):
    # Compute a ROC curve and its area over all (protein, term) pairs
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc


def compute_aupr(preds, labels):
    # Compute a precision-recall curve and its area over all (protein, term) pairs
    pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten())
    pr_auc = auc(rc, pr)
    # pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten(), average='macro')
    M_pr_auc = 0
    return pr_auc, M_pr_auc


def compute_mcc(preds, labels):
    # Compute the Matthews correlation coefficient on the flattened predictions
    mcc = matthews_corrcoef(labels.flatten(), preds.flatten())
    return mcc


def compute_performance(preds, labels):  # , gos):
    preds = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        total = 0
        f = 0.0
        p = 0.0
        r = 0.0
        p_total = 0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            all_gos = set()
            # for go_id in gos[i]:
            #     if go_id in all_functions:
            #         all_gos |= get_anchestors(go, go_id)
            # all_gos.discard(GO_ID)
            # all_gos -= func_set
            # fn += len(all_gos)
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                p += precision
                r += recall
        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
            if f_max < f:
                f_max = f
                p_max = p
                r_max = r
                t_max = threshold
                predictions_max = predictions
    return f_max, p_max, r_max, t_max, predictions_max


def get_gos(pred):
    mdist = 1.0
    mgos = None
    for i in range(len(labels_gos)):
        labels, gos = labels_gos[i]
        dist = distance.cosine(pred, labels)
        if mdist > dist:
            mdist = dist
            mgos = gos
    return mgos


def compute_similarity_performance(train_df, test_df, preds):
    logging.info("Computing similarity performance")
    logging.info("Training data size %d" % len(train_df))
    train_labels = train_df['labels'].values
    train_gos = train_df['gos'].values
    global labels_gos
    # Materialise the pairs so get_gos() can index them by position
    labels_gos = list(zip(train_labels, train_gos))
    p = Pool(64)
    pred_gos = p.map(get_gos, preds)
    total = 0
    p = 0.0
    r = 0.0
    f = 0.0
    test_gos = test_df['gos'].values
    for gos, tgos in zip(pred_gos, test_gos):
        preds = set()
        test = set()
        for go_id in gos:
            if go_id in all_functions:
                preds |= get_anchestors(go, go_id)
        for go_id in tgos:
            if go_id in all_functions:
                test |= get_anchestors(go, go_id)
        tp = len(preds.intersection(test))
        fp = len(preds - test)
        fn = len(test - preds)
        if tp == 0 and fp == 0 and fn == 0:
            continue
        total += 1
        if tp != 0:
            precision = tp / (1.0 * (tp + fp))
            recall = tp / (1.0 * (tp + fn))
            p += precision
            r += recall
            f += 2 * precision * recall / (precision + recall)
    return f / total, p / total, r / total


def print_report(report, go_id):
    with open(DATA_ROOT + 'reports.txt', 'a') as f:
        f.write('Classification report for ' + go_id + '\n')
        f.write(report + '\n')


if __name__ == '__main__':
    main()
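

# Example invocation (a sketch, not part of the original pipeline). The flag
# values below are illustrative and assume that go.obo and the pickled
# DataFrames under data/swiss/ (e.g. bp.pkl, train-bp.pkl, test-bp.pkl) are
# already in place:
#
#   python conditional_new_nn_hirearchical_seq.py --function bp --device gpu:0
#   python conditional_new_nn_hirearchical_seq.py --function cc --org <organism-id>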