
pfp_conditional_wgan_exp1.py 31KB

#!/usr/bin/env python
"""
Conditional WGAN for protein function prediction (experiment 1).
"""
from __future__ import division

import logging
import sys
import time
from collections import deque
from multiprocessing import Pool

import click as ck
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.layers import (
    Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D, AveragePooling1D, Multiply, GaussianNoise,
    Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization, Activation, Dropout,
    Lambda, LeakyReLU, RepeatVector)
from keras.losses import binary_crossentropy
from keras.models import Sequential, Model, load_model
from keras.preprocessing import sequence
from scipy.spatial import distance
from sklearn.metrics import log_loss, roc_curve, auc, matthews_corrcoef, precision_recall_curve

from utils import (
    get_gene_ontology,
    get_go_set,
    get_anchestors,
    get_parents,
    DataGenerator,
    FUNC_DICT,
    get_height)
from conditional_wgan_wrapper_exp1 import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
sys.setrecursionlimit(100000)

DATA_ROOT = 'Data/data1/'
MAXLEN = 1000
REPLEN = 256
ind = 0
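# Example invocation (illustrative only; the option names come from the click options
# below, the data layout under Data/data1/ is assumed):
#   python pfp_conditional_wgan_exp1.py --function bp --device gpu:0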


@ck.command()
@ck.option(
    '--function',
    default='bp',
    help='Ontology id (mf, bp, cc)')
@ck.option(
    '--device',
    default='gpu:0',
    help='GPU or CPU device id')
@ck.option(
    '--org',
    default=None,
    help='Organism id for filtering test set')
@ck.option('--train', default=True, is_flag=True)
def main(function, device, org, train):
    global FUNCTION
    FUNCTION = function
    global GO_ID
    GO_ID = FUNC_DICT[FUNCTION]
    global go
    go = get_gene_ontology('go.obo')
    global ORG
    ORG = org
    func_df = pd.read_pickle(DATA_ROOT + FUNCTION + '.pkl')
    global functions
    functions = func_df['functions'].values
    global embedds
    embedds = load_go_embedding(functions)
    global all_functions
    all_functions = get_go_set(go, GO_ID)
    logging.info('Functions: %s %d' % (FUNCTION, len(functions)))
    if ORG is not None:
        logging.info('Organism %s' % ORG)
    global node_names
    node_names = set()
    with tf.device('/' + device):
        params = {
            'fc_output': 1024,
            'embedding_dims': 128,
            'embedding_dropout': 0.2,
            'nb_conv': 2,
            'nb_dense': 1,
            'filter_length': 128,
            'nb_filter': 32,
            'pool_length': 64,
            'stride': 32
        }
        model(params, is_train=train)


def load_go_embedding(functions):
    # The index file maps GO terms to integer node ids; the embedding file maps
    # node ids to pre-trained embedding vectors.
    f_embedding = open('data/swiss/term-graph-cellular-component-100d.emb', 'r')
    f_index = open('data/swiss/term-id-cellular-component.txt', 'r')
    embeddings = f_embedding.readlines()
    index = f_index.readlines()
    f_embedding.close()
    f_index.close()
    dict1 = {}
    for i in range(1, len(index)):
        # split() returns a list; a map() object is not subscriptable in Python 3
        ind = index[i].split()
        dict1[ind[0]] = ind[2]
    dict2 = {}
    for i in range(1, len(embeddings)):
        embedding = embeddings[i].split()
        dict2[str(embedding[0])] = embedding[1:]
    embedds = []
    for i in range(len(functions)):
        term = functions[i]
        if term in dict1:
            term_index = dict1[term]
            term_embedd = dict2[str(term_index)]
        else:
            # Terms without a pre-trained embedding get a constant vector
            term_embedd = np.ones([100])
        embedds.append(term_embedd)
    return np.array(embedds)
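# Expected file layout (inferred from the parsing above; not documented elsewhere):
#   term-id-cellular-component.txt          header line, then "<GO term> <...> <node id>" per line
#   term-graph-cellular-component-100d.emb  header line, then "<node id> <100 floats>" per line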


def load_data():
    df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
    n = len(df)
    index = df.index.values
    valid_n = int(n * 0.8)
    train_df = df.loc[index[:valid_n]]
    valid_df = df.loc[index[valid_n:]]
    test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
    print(test_df['orgs'])
    if ORG is not None:
        logging.info('Unfiltered test size: %d' % len(test_df))
        test_df = test_df[test_df['orgs'] == ORG]
        logging.info('Filtered test size: %d' % len(test_df))

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    def normalize_minmax(values):
        mn = np.min(values)
        mx = np.max(values)
        if mx - mn != 0.0:
            return (values - mn) / (mx - mn)
        return values - mn

    def get_values(data_frame):
        print(data_frame['labels'].values.shape)
        labels = reshape(data_frame['labels'].values)
        ngrams = sequence.pad_sequences(
            data_frame['ngrams'].values, maxlen=MAXLEN)
        ngrams = reshape(ngrams)
        rep = reshape(data_frame['embeddings'].values)
        data = ngrams
        return data, labels

    train = get_values(train_df)
    valid = get_values(valid_df)
    test = get_values(test_df)
    return train, valid, test, train_df, valid_df, test_df


def load_data0():
    df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
    n = len(df)
    index = df.index.values
    valid_n = int(n * 0.8)
    train_df = df.loc[index[:valid_n]]
    valid_df = df.loc[index[valid_n:]]
    test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
    print(test_df['orgs'])
    if ORG is not None:
        logging.info('Unfiltered test size: %d' % len(test_df))
        test_df = test_df[test_df['orgs'] == ORG]
        logging.info('Filtered test size: %d' % len(test_df))

    def reshape(values):
        values = np.hstack(values).reshape(
            len(values), len(values[0]))
        return values

    def normalize_minmax(values):
        mn = np.min(values)
        mx = np.max(values)
        if mx - mn != 0.0:
            return (values - mn) / (mx - mn)
        return values - mn

    def get_values(data_frame, p_list):
        # Keep only the proteins whose accession numbers appear in p_list
        accession_numbers = data_frame['accessions'].values
        all_labels = data_frame['labels'].values
        all_ngrams = reshape(sequence.pad_sequences(data_frame['ngrams'].values, maxlen=MAXLEN))
        print(len(all_labels))
        print(len(p_list))
        labels = []
        data = []
        for i in range(accession_numbers.shape[0]):
            if accession_numbers[i] in p_list:
                labels.append(all_labels[i])
                data.append(all_ngrams[i])
        return np.asarray(data), np.asarray(labels)

    dev_F = pd.read_csv('dev-' + FUNCTION + '-input.tsv', sep='\t', header=0, index_col=0)
    test_F = pd.read_csv('test-' + FUNCTION + '-input.tsv', sep='\t', header=0, index_col=0)
    train_F = pd.read_csv('train-' + FUNCTION + '-input.tsv', sep='\t', header=0, index_col=0)
    train_l = train_F.index.tolist()
    test_l = test_F.index.tolist()
    dev_l = dev_F.index.tolist()
    train = get_values(train_df, train_l)
    valid = get_values(valid_df, dev_l)
    test = get_values(test_df, test_l)
    return train, valid, test, train_df, valid_df, test_df


class TrainValTensorBoard(TensorBoard):
    def __init__(self, log_dir='./t_logs', **kwargs):
        # Make the original `TensorBoard` log to its own directory
        training_log_dir = log_dir
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)
        # Log the validation, discriminator and generator metrics to separate directories
        self.val_log_dir = './v_logs'
        self.dis_log_dir = './d_logs'
        self.gen_log_dir = './g_logs'

    def set_model(self, model):
        # Set up writers for validation, discriminator and generator metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        self.dis_writer = tf.summary.FileWriter(self.dis_log_dir)
        self.gen_writer = tf.summary.FileWriter(self.gen_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with
        # `self.val_writer`. Also rename the keys so that they can
        # be plotted on the same figure with the training metrics
        logs = logs or {}
        val_logs = {k.replace('val_', 'v_'): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        dis_logs = {k.replace('discriminator_', 'd_'): v for k, v in logs.items() if k.startswith('discriminator_')}
        for name, value in dis_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.dis_writer.add_summary(summary, epoch)
        self.dis_writer.flush()

        gen_logs = {k.replace('generator_', 'g_'): v for k, v in logs.items() if k.startswith('generator_')}
        for name, value in gen_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.gen_writer.add_summary(summary, epoch)
        self.gen_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        t_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        tr_logs = {k: v for k, v in t_logs.items() if not k.startswith('discriminator_')}
        tra_logs = {k: v for k, v in tr_logs.items() if not k.startswith('generator_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, tra_logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
        self.gen_writer.close()
        self.dis_writer.close()


def get_feature_model(params):
    # Sequence feature extractor: n-gram embedding followed by two 1D convolutions
    embedding_dims = params['embedding_dims']
    max_features = 8001
    model = Sequential()
    model.add(Embedding(
        max_features,
        embedding_dims,
        input_length=MAXLEN))
    model.add(SpatialDropout1D(0.4))
    model.add(Conv1D(
        padding="valid",
        strides=1,
        filters=params['nb_filter'],
        kernel_size=params['filter_length']))
    model.add(LeakyReLU())
    # model.add(BatchNormalization())
    model.add(Conv1D(
        padding="valid",
        strides=1,
        filters=params['nb_filter'],
        kernel_size=params['filter_length']))
    model.add(LeakyReLU())
    # model.add(BatchNormalization())  # problem in loading the best model
    model.add(AveragePooling1D(strides=params['stride'], pool_size=params['pool_length']))
    model.add(Flatten())
    return model


# def get_node_name(go_id, unique=False):
#     name = go_id.split(':')[1]
#     if not unique:
#         return name
#     if name not in node_names:
#         node_names.add(name)
#         return name
#     i = 1
#     while (name + '_' + str(i)) in node_names:
#         i += 1
#     name = name + '_' + str(i)
#     node_names.add(name)
#     return name


# def get_layers(inputs):
#     q = deque()
#     layers = {}
#     name = get_node_name(GO_ID)
#     layers[GO_ID] = {'net': inputs}
#     for node_id in go[GO_ID]['children']:
#         if node_id in func_set:
#             q.append((node_id, inputs))
#     while len(q) > 0:
#         node_id, net = q.popleft()
#         name = get_node_name(node_id)
#         net, output = get_function_node(name, inputs)
#         if node_id not in layers:
#             layers[node_id] = {'net': net, 'output': output}
#         for n_id in go[node_id]['children']:
#             if n_id in func_set and n_id not in layers:
#                 ok = True
#                 for p_id in get_parents(go, n_id):
#                     if p_id in func_set and p_id not in layers:
#                         ok = False
#                 if ok:
#                     q.append((n_id, net))
#     for node_id in functions:
#         childs = set(go[node_id]['children']).intersection(func_set)
#         if len(childs) > 0:
#             outputs = [layers[node_id]['output']]
#             for ch_id in childs:
#                 outputs.append(layers[ch_id]['output'])
#             name = get_node_name(node_id) + '_max'
#             layers[node_id]['output'] = Maximum(name=name)(outputs)
#     return layers


# def get_function_node(name, inputs):
#     output_name = name + '_out'
#     output = Dense(1, name=output_name, activation='sigmoid')(inputs)
#     return output, output


def get_generator(params, n_classes):
    inputs = Input(shape=(MAXLEN,), dtype='int32', name='input1')
    feature_model = get_feature_model(params)(inputs)
    net = Dense(200)(feature_model)  # bp (dis1 g1) 400 - cc 200 - mf 200
    net = BatchNormalization()(net)
    net = LeakyReLU()(net)
    # To impose the DAG hierarchy directly on the model (inspired by DeepGO):
    # layers = get_layers(net)
    # output_models = []
    # for i in range(len(functions)):
    #     output_models.append(layers[functions[i]]['output'])
    # net = Concatenate(axis=1)(output_models)
    output = Dense(n_classes, activation='tanh')(net)  # 'sigmoid'
    model = Model(inputs=inputs, outputs=output)
    return model


def get_discriminator(params, n_classes, dropout_rate=0.5):
    # Label branch: the (real or generated) label vector
    inputs = Input(shape=(n_classes,))
    # n_inputs = GaussianNoise(0.1)(inputs)
    # x = Conv1D(filters=8, kernel_size=128, padding='valid', strides=1)(n_inputs)
    # x = BatchNormalization()(x)
    # x = LeakyReLU()(x)
    # x = Conv1D(filters=8, kernel_size=128, padding='valid', strides=1)(x)
    # x = BatchNormalization()(x)
    # x = LeakyReLU()(x)
    # x = AveragePooling1D(strides=params['stride'], pool_size=params['pool_length'])(x)
    # x = Flatten()(x)
    # x = Lambda(lambda y: K.squeeze(y, 2))(x)

    # Condition branch: the protein sequence n-grams
    inputs2 = Input(shape=(MAXLEN,), dtype='int32', name='d_input2')
    x2 = Embedding(8001, 128, input_length=MAXLEN)(inputs2)
    x2 = SpatialDropout1D(0.4)(x2)
    x2 = Conv1D(filters=8, kernel_size=128, padding='valid', strides=1)(x2)
    x2 = BatchNormalization()(x2)
    x2 = LeakyReLU()(x2)
    # x2 = Conv1D(filters=1, kernel_size=1, padding='valid', strides=1)(x2)
    # x2 = BatchNormalization()(x2)
    # x2 = LeakyReLU()(x2)
    # x2 = AveragePooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2)
    x2 = Flatten()(x2)
    size = 100
    x2 = Dropout(dropout_rate)(x2)
    x2 = Dense(size)(x2)
    x2 = BatchNormalization()(x2)
    x2 = LeakyReLU()(x2)

    size = 100
    x = Dropout(dropout_rate)(inputs)
    x = Dense(size)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    # Merge the two branches and score with a small MLP (linear output for WGAN)
    x = Concatenate(axis=1, name='merged2')([x, x2])
    layer_sizes = [30, 30, 30, 30, 30, 30]
    for size in layer_sizes:
        x = Dropout(dropout_rate)(x)
        x = Dense(size)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)
    outputs = Dense(1)(x)
    model = Model(inputs=[inputs, inputs2], outputs=[outputs], name='Discriminator')
    return model


def rescale_labels(labels):
    rescaled_labels = 2 * (labels - 0.5)
    return rescaled_labels


def rescale_back_labels(labels):
    rescaled_back_labels = (labels / 2) + 0.5
    return rescaled_back_labels
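# Illustrative note (not part of the original script): the two helpers map {0, 1}
# labels to the generator's tanh range [-1, 1] and back, and are exact inverses:
#   rescale_labels(np.array([0.0, 1.0]))         -> array([-1.,  1.])
#   rescale_back_labels(np.array([-1.0, 1.0]))   -> array([0., 1.])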


def get_model(params, nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10):
    generator = get_generator(params, nb_classes)
    discriminator = get_discriminator(params, nb_classes)
    generator_model, discriminator_model = \
        WGAN_wrapper(generator=generator,
                     discriminator=discriminator,
                     generator_input_shape=(MAXLEN,),
                     discriminator_input_shape=(nb_classes,),
                     discriminator_input_shape2=(MAXLEN,),
                     batch_size=batch_size,
                     gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT,
                     embeddings=embedds)
    logging.info('Compilation finished')
    return generator_model, discriminator_model


def train_wgan(generator_model, discriminator_model, batch_size, epochs,
               x_train, y_train, x_val, y_val, x_test, y_test,
               generator_model_path, discriminator_model_path,
               TRAINING_RATIO=10, N_WARM_UP=20):
    # For the first N_WARM_UP epochs only the generator is trained (reconstruction loss);
    # afterwards the discriminator is updated TRAINING_RATIO times per generator update,
    # following the usual WGAN-GP schedule.
    BATCH_SIZE = batch_size
    N_EPOCH = epochs
    positive_y = np.ones((batch_size, 1), dtype=np.float32)
    zero_y = positive_y * 0
    negative_y = -positive_y
    positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32)
    dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    positive_full_enable_train = np.ones((len(x_train), 1), dtype=np.float32)
    positive_full_enable_val = np.ones((len(x_val), 1), dtype=np.float32)
    positive_full_enable_test = np.ones((len(x_test), 1), dtype=np.float32)
    best_validation_loss = None
    callback = TrainValTensorBoard(write_graph=False)  # TensorBoard(log_path)
    callback.set_model(generator_model)
    for epoch in range(N_EPOCH):
        # np.random.shuffle(X_train)
        print("Epoch: ", epoch)
        print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE))
        discriminator_loss = []
        generator_loss = []
        minibatches_size = BATCH_SIZE * TRAINING_RATIO
        shuffled_indexes = np.random.permutation(x_train.shape[0])
        shuffled_indexes_2 = np.random.permutation(x_train.shape[0])
        for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))):
            batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size]
            batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size]
            x = x_train[batch_indexes]
            y = y_train[batch_indexes]
            y_2 = y_train[batch_indexes_2]
            x_2 = x_train[batch_indexes_2]
            if epoch < N_WARM_UP:
                # Warm-up: train the generator alone on the reconstruction objective
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    generator_loss.append(generator_model.train_on_batch([x_batch, positive_y], [y_batch, zero_y]))
                    discriminator_loss.append(0)
            else:
                # Adversarial phase: TRAINING_RATIO discriminator updates per generator update
                for j in range(TRAINING_RATIO):
                    x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    # noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32)
                    noise = x_batch
                    discriminator_loss.append(discriminator_model.train_on_batch(
                        [y_batch_2, noise, x_batch_2],
                        [positive_y, negative_y, dummy_y]))
                generator_loss.append(generator_model.train_on_batch([x, positive_full_y], [y, positive_full_y]))
        predicted_y_train = generator_model.predict([x_train, positive_full_enable_train], batch_size=BATCH_SIZE)[0]
        predicted_y_val = generator_model.predict([x_val, positive_full_enable_val], batch_size=BATCH_SIZE)[0]
        train_loss = log_loss(rescale_back_labels(y_train), rescale_back_labels(predicted_y_train))
        val_loss = log_loss(y_val, rescale_back_labels(predicted_y_val))
        callback.on_epoch_end(epoch,
                              logs={'discriminator_loss': (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1),
                                    'generator_loss': (np.sum(np.asarray(generator_loss)) if generator_loss else -1),
                                    'train_bce_loss': train_loss,
                                    'val_bce_loss': val_loss})
        print("train bce loss: {:.4f}, validation bce loss: {:.4f}, generator loss: {:.4f}, discriminator loss: {:.4f}".format(
            train_loss, val_loss,
            (np.sum(np.asarray(generator_loss)) if generator_loss else -1) / x_train.shape[0],
            (np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0]))
        if best_validation_loss is None or best_validation_loss > val_loss:
            print('\nEpoch %05d: validation loss improved to %0.5f,'
                  ' saving model to %s and %s'
                  % (epoch + 1, val_loss, generator_model_path, discriminator_model_path))
            best_validation_loss = val_loss
            generator_model.save(generator_model_path, overwrite=True)
            discriminator_model.save(discriminator_model_path, overwrite=True)


def model(params, batch_size=64, nb_epoch=10000, is_train=True):
    nb_classes = len(functions)
    logging.info("Loading Data")
    train, val, test, train_df, valid_df, test_df = load_data()
    train_data, train_labels = train
    val_data, val_labels = val
    test_data, test_labels = test
    train_labels = rescale_labels(train_labels)
    logging.info("Training data size: %d" % len(train_data))
    logging.info("Validation data size: %d" % len(val_data))
    logging.info("Test data size: %d" % len(test_data))
    generator_model_path = DATA_ROOT + 'models/new_model_seq_' + FUNCTION + '.h5'
    discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_' + FUNCTION + '.h5'
    logging.info('Starting training the model')
    train_generator = DataGenerator(batch_size, nb_classes)
    train_generator.fit(train_data, train_labels)
    valid_generator = DataGenerator(batch_size, nb_classes)
    valid_generator.fit(val_data, val_labels)
    test_generator = DataGenerator(batch_size, nb_classes)
    test_generator.fit(test_data, test_labels)
    if is_train:
        generator_model, discriminator_model = get_model(params, nb_classes, batch_size)
        # Loading a pretrained model
        # generator_model.load_weights(DATA_ROOT + 'models/trained_bce/new_model_seq_' + FUNCTION + '.h5')
        train_wgan(generator_model, discriminator_model, batch_size=batch_size, epochs=nb_epoch,
                   x_train=train_data, y_train=train_labels, x_val=val_data, y_val=val_labels,
                   x_test=test_data, y_test=test_labels,
                   generator_model_path=generator_model_path,
                   discriminator_model_path=discriminator_model_path)
    logging.info('Loading best model')
    model = load_model(generator_model_path,
                       custom_objects={'generator_recunstruction_loss_new': generator_recunstruction_loss_new,
                                       'wasserstein_loss': wasserstein_loss})
    logging.info('Predicting')
    start_time = time.time()
    preds = model.predict_generator(test_generator, steps=len(test_data) / batch_size)[0]
    end_time = time.time()
    logging.info('%f seconds prediction time for %d protein sequences' % (end_time - start_time, len(test_data)))
    preds = rescale_back_labels(preds)
    # preds = propagate_score(preds, functions)
    logging.info('Computing performance')
    f, p, r, t, preds_max = compute_performance(preds, test_labels)
    micro_roc_auc = compute_roc(preds, test_labels)
    macro_roc_auc = compute_roc_macro(preds, test_labels)
    mcc = compute_mcc(preds_max, test_labels)
    aupr, _ = compute_aupr(preds, test_labels)
    tpr_score, total = compute_tpr_score(preds_max, functions)
    m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = micro_macro_function_centric_f1(preds.T, test_labels.T)
    function_centric_performance(functions, preds.T, test_labels.T, rescale_back_labels(train_labels).T, t)
    logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f))
    logging.info('Micro ROC AUC: \t %f ' % (micro_roc_auc, ))
    logging.info('Macro ROC AUC: \t %f ' % (macro_roc_auc, ))
    logging.info('MCC: \t %f ' % (mcc, ))
    logging.info('AUPR: \t %f ' % (aupr, ))
    logging.info('TPR_Score from total: \t %f \t %f ' % (tpr_score, total))
    logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max))
    logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max))


def function_centric_performance(functions, preds, labels, labels_train, threshold):
    # Per-function precision/recall/F1 at the given threshold, written to a TSV file
    results = []
    preds = np.round(preds, 2)
    for i in range(preds.shape[0]):
        predictions = (preds[i, :] > threshold).astype(np.int32)
        tp = np.sum(predictions * labels[i, :])
        fp = np.sum(predictions) - tp
        fn = np.sum(labels[i, :]) - tp
        if tp > 0:
            precision = tp / (1.0 * (tp + fp))
            recall = tp / (1.0 * (tp + fn))
            f = 2 * precision * recall / (precision + recall)
        else:
            if fp == 0 and fn == 0:
                precision = 1
                recall = 1
                f = 1
            else:
                precision = 0
                recall = 0
                f = 0
        f_max = f
        p_max = precision
        r_max = recall
        num_prots_train = np.sum(labels_train[i, :])
        height = get_height(go, functions[i])
        results.append([functions[i], num_prots_train, height, f_max, p_max, r_max])
    results = pd.DataFrame(results)
    results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False)


def micro_macro_function_centric_f1(preds, labels):
    predictions = np.round(preds)
    m_tp = 0
    m_fp = 0
    m_fn = 0
    M_pr = 0
    M_rc = 0
    for i in range(preds.shape[0]):
        tp = np.sum(predictions[i, :] * labels[i, :])
        fp = np.sum(predictions[i, :]) - tp
        fn = np.sum(labels[i, :]) - tp
        m_tp += tp
        m_fp += fp
        m_fn += fn
        if tp > 0:
            pr = 1.0 * tp / (1.0 * (tp + fp))
            rc = 1.0 * tp / (1.0 * (tp + fn))
        else:
            if fp == 0 and fn == 0:
                pr = 1
                rc = 1
            else:
                pr = 0
                rc = 0
        M_pr += pr
        M_rc += rc
    if m_tp > 0:
        m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp))
        m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn))
        m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc)
    else:
        if m_fp == 0 and m_fn == 0:
            m_pr = 1
            m_rc = 1
            m_f1 = 1
        else:
            m_pr = 0
            m_rc = 0
            m_f1 = 0
    M_pr /= preds.shape[0]
    M_rc /= preds.shape[0]
    if M_pr == 0 and M_rc == 0:
        M_f1 = 0
    else:
        M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc)
    return m_pr, m_rc, m_f1, M_pr, M_rc, M_f1
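# Illustrative note (not from the original source): micro-averaging pools TP/FP/FN over all
# functions before computing precision/recall; macro-averaging computes them per function and
# then averages. E.g. with two functions scoring (tp=1, fp=1, fn=0) and (tp=1, fp=0, fn=1):
#   micro: pr = rc = f1 = 2/3        macro: pr = rc = f1 = 0.75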


def compute_roc(preds, labels):
    # Micro-averaged ROC AUC over all (protein, function) pairs
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc


def compute_roc_macro(yhat_raw, y):
    if yhat_raw.shape[0] <= 1:
        return
    fpr = {}
    tpr = {}
    roc_auc = {}
    # get AUC for each label individually
    relevant_labels = []
    auc_labels = {}
    for i in range(y.shape[1]):
        # only if there are true positives for this label
        if y[:, i].sum() > 0:
            fpr[i], tpr[i], _ = roc_curve(y[:, i], yhat_raw[:, i])
            if len(fpr[i]) > 1 and len(tpr[i]) > 1:
                auc_score = auc(fpr[i], tpr[i])
                if not np.isnan(auc_score):
                    auc_labels["auc_%d" % i] = auc_score
                    relevant_labels.append(i)
    # macro-AUC: just average the auc scores
    aucs = []
    for i in relevant_labels:
        aucs.append(auc_labels['auc_%d' % i])
    roc_auc_macro = np.mean(aucs)
    return roc_auc_macro


def compute_aupr(preds, labels):
    # Micro-averaged area under the precision-recall curve
    pr, rc, threshold = precision_recall_curve(labels.flatten(), preds.flatten())
    pr_auc = auc(rc, pr)
    M_pr_auc = 0
    return pr_auc, M_pr_auc


def compute_mcc(preds, labels):
    # Matthews correlation coefficient over the flattened binary predictions
    mcc = matthews_corrcoef(labels.flatten(), preds.flatten())
    return mcc


def compute_performance(preds, labels):
    # Protein-centric Fmax: sweep thresholds from 0.01 to 0.99 and keep the best F1
    predicts = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    predictions_max = None
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (predicts > threshold).astype(np.int32)
        total = 0
        f = 0.0
        p = 0.0
        r = 0.0
        p_total = 0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                p += precision
                r += recall
        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
            if f_max < f:
                f_max = f
                p_max = p
                r_max = r
                t_max = threshold
                predictions_max = predictions
    return f_max, p_max, r_max, t_max, predictions_max


def get_anchestors(go, go_id):
    # Note: shadows the `get_anchestors` imported from utils; returns the set of
    # ancestors of go_id (including go_id itself) following 'is_a' links.
    go_set = set()
    q = deque()
    q.append(go_id)
    while len(q) > 0:
        g_id = q.popleft()
        go_set.add(g_id)
        for parent_id in go[g_id]['is_a']:
            if parent_id in go:
                q.append(parent_id)
    return go_set


def compute_tpr_score(preds_max, functions):
    # Count true-path-rule violations: cases where a term is predicted positive
    # while one of its ancestors in the ontology is not.
    my_dict = {}
    for i in range(len(functions)):
        my_dict[functions[i]] = i
    total = 0
    anc = np.zeros([len(functions), len(functions)])
    for i in range(len(functions)):
        anchestors = list(get_anchestors(go, functions[i]))
        for j in range(len(anchestors)):
            if anchestors[j] in my_dict.keys():
                anc[i, my_dict[anchestors[j]]] = 1
                total += 1
    tpr_score = 0
    for i in range(preds_max.shape[0]):
        for j in range(preds_max.shape[1]):
            if preds_max[i, j] == 1:
                for k in range(anc.shape[1]):
                    if anc[j, k] == 1:
                        if preds_max[i, k] != 1:
                            tpr_score = tpr_score + 1
    return tpr_score / preds_max.shape[0], total


def propagate_score(preds, functions):
    # Propagate scores up the ontology: an ancestor receives at least the score
    # of each of its descendants.
    my_dict = {}
    for i in range(len(functions)):
        my_dict[functions[i]] = i
    anc = np.zeros([len(functions), len(functions)])
    for i in range(len(functions)):
        anchestors = list(get_anchestors(go, functions[i]))
        for j in range(len(anchestors)):
            if anchestors[j] in my_dict.keys():
                anc[i, my_dict[anchestors[j]]] = 1
    new_preds = np.array(preds)
    for i in range(new_preds.shape[1]):
        for k in range(anc.shape[1]):
            if anc[i, k] == 1:
                new_preds[new_preds[:, i] > new_preds[:, k], k] = new_preds[new_preds[:, i] > new_preds[:, k], i]
    return new_preds
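# Illustrative example (assuming functions[k] is an ancestor of functions[i]): if a protein
# is scored 0.9 for the child term and 0.4 for its ancestor, propagate_score raises the
# ancestor's score to 0.9, so thresholded predictions respect the true-path rule.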


if __name__ == '__main__':
    main()