@@ -1,902 +0,0 @@ | |||
#!/usr/bin/env python | |||
""" | |||
parts of codes are based on the work in https://github.com/bio-ontology-research-group/deepgo. | |||
""" | |||
from __future__ import division | |||
import logging | |||
import sys | |||
import time | |||
from collections import deque | |||
from multiprocessing import Pool | |||
import click as ck | |||
import numpy as np | |||
import pandas as pd | |||
import tensorflow as tf | |||
from keras import backend as K | |||
from keras.callbacks import EarlyStopping, ModelCheckpoint | |||
from keras.layers import ( | |||
Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D, | |||
Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization, Activation, Dropout) | |||
from keras.losses import binary_crossentropy | |||
from keras.models import Sequential, Model, load_model | |||
from keras.preprocessing import sequence | |||
from scipy.spatial import distance | |||
from sklearn.metrics import log_loss | |||
from sklearn.metrics import roc_curve, auc, matthews_corrcoef | |||
from keras.layers import Lambda | |||
from sklearn.metrics import precision_recall_curve | |||
from utils import ( | |||
get_gene_ontology, | |||
get_go_set, | |||
get_anchestors, | |||
get_parents, | |||
DataGenerator, | |||
FUNC_DICT, | |||
get_height, | |||
get_ipro) | |||
from conditional_wgan_wrapper_post import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new | |||
config = tf.ConfigProto() | |||
config.gpu_options.allow_growth = True | |||
sess = tf.Session(config=config) | |||
K.set_session(sess) | |||
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO) | |||
sys.setrecursionlimit(100000) | |||
DATA_ROOT = 'data/swiss/' | |||
MAXLEN = 1000 | |||
REPLEN = 256 | |||
ind = 0 | |||
@ck.command() | |||
@ck.option( | |||
'--function', | |||
default='bp', | |||
help='Ontology id (mf, bp, cc)') | |||
@ck.option( | |||
'--device', | |||
default='gpu:0', | |||
help='GPU or CPU device id') | |||
@ck.option( | |||
'--org', | |||
default= None, | |||
help='Organism id for filtering test set') | |||
@ck.option('--train',default = True, is_flag=True) | |||
@ck.option('--param', default=0, help='Param index 0-7') | |||
def main(function, device, org, train, param): | |||
global FUNCTION | |||
FUNCTION = function | |||
global GO_ID | |||
GO_ID = FUNC_DICT[FUNCTION] | |||
global go | |||
go = get_gene_ontology('go.obo') | |||
global ORG | |||
ORG = org | |||
func_df = pd.read_pickle(DATA_ROOT + FUNCTION + '.pkl') | |||
global functions | |||
functions = func_df['functions'].values | |||
global func_set | |||
func_set = set(functions) | |||
global all_functions | |||
all_functions = get_go_set(go, GO_ID) | |||
logging.info('Functions: %s %d' % (FUNCTION, len(functions))) | |||
if ORG is not None: | |||
logging.info('Organism %s' % ORG) | |||
global go_indexes | |||
go_indexes = dict() | |||
for ind, go_id in enumerate(functions): | |||
go_indexes[go_id] = ind | |||
global node_names | |||
node_names = set() | |||
with tf.device('/' + device): | |||
params = { | |||
'fc_output': 1024, | |||
'learning_rate': 0.001, | |||
'embedding_dims': 128, | |||
'embedding_dropout': 0.2, | |||
'nb_conv': 1, | |||
'nb_dense': 1, | |||
'filter_length': 128, | |||
'nb_filter': 32, | |||
'pool_length': 64, | |||
'stride': 32 | |||
} | |||
model(params, is_train=train) | |||
def load_data(): | |||
df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl') | |||
n = len(df) | |||
index = df.index.values | |||
valid_n = int(n * 0.8) | |||
train_df = df.loc[index[:valid_n]] | |||
valid_df = df.loc[index[valid_n:]] | |||
test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl') | |||
print( test_df['orgs'] ) | |||
if ORG is not None: | |||
logging.info('Unfiltered test size: %d' % len(test_df)) | |||
test_df = test_df[test_df['orgs'] == ORG] | |||
logging.info('Filtered test size: %d' % len(test_df)) | |||
# Filter by type | |||
# org_df = pd.read_pickle('data/prokaryotes.pkl') | |||
# orgs = org_df['orgs'] | |||
# test_df = test_df[test_df['orgs'].isin(orgs)] | |||
def reshape(values): | |||
values = np.hstack(values).reshape( | |||
len(values), len(values[0])) | |||
return values | |||
def normalize_minmax(values): | |||
mn = np.min(values) | |||
mx = np.max(values) | |||
if mx - mn != 0.0: | |||
return (values - mn) / (mx - mn) | |||
return values - mn | |||
def get_values(data_frame): | |||
print(data_frame['labels'].values.shape) | |||
labels = reshape(data_frame['labels'].values) | |||
ngrams = sequence.pad_sequences( | |||
data_frame['ngrams'].values, maxlen=MAXLEN) | |||
ngrams = reshape(ngrams) | |||
rep = reshape(data_frame['embeddings'].values) | |||
data = ngrams | |||
return data, labels | |||
train = get_values(train_df) | |||
valid = get_values(valid_df) | |||
test = get_values(test_df) | |||
return train, valid, test, train_df, valid_df, test_df | |||
def get_feature_model(params): | |||
embedding_dims = params['embedding_dims'] | |||
max_features = 8001 | |||
model = Sequential() | |||
model.add(Embedding( | |||
max_features, | |||
embedding_dims, | |||
input_length=MAXLEN)) | |||
model.add(SpatialDropout1D(0.4)) | |||
for i in range(params['nb_conv']): | |||
model.add(Conv1D( | |||
activation="relu", | |||
padding="valid", | |||
strides=1, | |||
filters=params['nb_filter'], | |||
kernel_size=params['filter_length'])) | |||
model.add(MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])) | |||
model.add(Flatten()) | |||
return model | |||
def merge_outputs(outputs, name): | |||
if len(outputs) == 1: | |||
return outputs[0] | |||
## return merge(outputs, mode='concat', name=name, concat_axis=1) | |||
return Concatenate(axis=1, name=name)(outputs) | |||
def merge_nets(nets, name): | |||
if len(nets) == 1: | |||
return nets[0] | |||
## return merge(nets, mode='sum', name=name) | |||
return Add(name=name)(nets) | |||
def get_node_name(go_id, unique=False): | |||
name = go_id.split(':')[1] | |||
if not unique: | |||
return name | |||
if name not in node_names: | |||
node_names.add(name) | |||
return name | |||
i = 1 | |||
while (name + '_' + str(i)) in node_names: | |||
i += 1 | |||
name = name + '_' + str(i) | |||
node_names.add(name) | |||
return name | |||
def get_layers(inputs): | |||
q = deque() | |||
layers = {} | |||
name = get_node_name(GO_ID) | |||
layers[GO_ID] = {'net': inputs} | |||
for node_id in go[GO_ID]['children']: | |||
if node_id in func_set: | |||
q.append((node_id, inputs)) | |||
while len(q) > 0: | |||
node_id, net = q.popleft() | |||
parent_nets = [inputs] | |||
# for p_id in get_parents(go, node_id): | |||
# if p_id in func_set: | |||
# parent_nets.append(layers[p_id]['net']) | |||
# if len(parent_nets) > 1: | |||
# name = get_node_name(node_id) + '_parents' | |||
# net = merge( | |||
# parent_nets, mode='concat', concat_axis=1, name=name) | |||
name = get_node_name(node_id) | |||
net, output = get_function_node(name, inputs) | |||
if node_id not in layers: | |||
layers[node_id] = {'net': net, 'output': output} | |||
for n_id in go[node_id]['children']: | |||
if n_id in func_set and n_id not in layers: | |||
ok = True | |||
for p_id in get_parents(go, n_id): | |||
if p_id in func_set and p_id not in layers: | |||
ok = False | |||
if ok: | |||
q.append((n_id, net)) | |||
for node_id in functions: | |||
childs = set(go[node_id]['children']).intersection(func_set) | |||
if len(childs) > 0: | |||
outputs = [layers[node_id]['output']] | |||
for ch_id in childs: | |||
outputs.append(layers[ch_id]['output']) | |||
name = get_node_name(node_id) + '_max' | |||
## layers[node_id]['output'] = merge( | |||
## outputs, mode='max', name=name) | |||
layers[node_id]['output'] = Maximum(name=name)(outputs) | |||
return layers | |||
def get_function_node(name, inputs): | |||
output_name = name + '_out' | |||
# net = Dense(256, name=name, activation='relu')(inputs) | |||
output = Dense(1, name=output_name, activation='sigmoid')(inputs) | |||
return output, output | |||
def get_generator(params, n_classes): | |||
inputs = Input(shape=(MAXLEN,), dtype='int32', name='input1') | |||
feature_model = get_feature_model(params)(inputs) | |||
net = Dense(300, activation='relu')(feature_model) | |||
net = BatchNormalization()(net) | |||
layers = get_layers(net) | |||
output_models = [] | |||
for i in range(len(functions)): | |||
output_models.append(layers[functions[i]]['output']) | |||
net = Concatenate(axis=1)(output_models) | |||
output = Dense(n_classes, activation='sigmoid')(net) | |||
model = Model(inputs=inputs, outputs=output) | |||
return model | |||
def get_discriminator(params, n_classes, dropout_rate=0.5): | |||
inputs = Input(shape=(n_classes, )) | |||
inputs2 = Input(shape =(MAXLEN,), dtype ='int32', name='d_input2') | |||
x2 = Embedding(8001,128, input_length=MAXLEN)(inputs2) | |||
x2 = Conv1D(filters =1 , kernel_size= 1, padding = 'valid', activation ='relu', strides=1)(x2) | |||
x2 = Lambda(lambda x: K.squeeze(x, 2))(x2) | |||
#for i in range(params['nb_conv']): | |||
# x2 = Conv1D ( activation="relu", padding="valid", strides=1, filters=params['nb_filter'],kernel_size=params['filter_length'])(x2) | |||
#x2 =MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2) | |||
#x2 = Flatten()(x2) | |||
size = 40 | |||
x = inputs | |||
x = Dropout(dropout_rate)(x) | |||
x = Dense(size)(x) | |||
x = BatchNormalization()(x) | |||
x = Activation('relu')(x) | |||
size = 40 | |||
x2 = Dropout(dropout_rate)(x2) | |||
x2 = Dense(size)(x2) | |||
x2 = BatchNormalization()(x2) | |||
x2 = Activation('relu')(x2) | |||
x = Concatenate(axis =1 , name = 'merged2')([x, x2]) | |||
layer_sizes = [80, 40,30] | |||
for size in layer_sizes: | |||
x = Dropout(dropout_rate)(x) | |||
x = Dense(size)(x) | |||
x = BatchNormalization()(x) | |||
x = Activation('relu')(x) | |||
outputs = Dense(1)(x) | |||
model = Model(inputs = [inputs ,inputs2], outputs=outputs, name='Discriminator') | |||
return model | |||
def get_model(params,nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10): | |||
generator = get_generator(params, nb_classes) | |||
discriminator = get_discriminator(params, nb_classes) | |||
generator_model, discriminator_model = \ | |||
WGAN_wrapper(generator=generator, | |||
discriminator=discriminator, | |||
generator_input_shape=(MAXLEN,), | |||
discriminator_input_shape=(nb_classes,), | |||
discriminator_input_shape2 = (MAXLEN, ), | |||
batch_size=batch_size, | |||
gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT) | |||
logging.info('Compilation finished') | |||
return generator_model, discriminator_model | |||
def train_wgan(generator_model, discriminator_model, batch_size, epochs, | |||
x_train, y_train, x_val, y_val, generator_model_path, discriminator_model_path, | |||
TRAINING_RATIO=10, N_WARM_UP=0): | |||
BATCH_SIZE = batch_size | |||
N_EPOCH = epochs | |||
positive_y = np.ones((batch_size, 1), dtype=np.float32) | |||
zero_y = positive_y * 0 | |||
negative_y = -positive_y | |||
positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32) | |||
dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32) | |||
positive_full_enable_train = np.ones((len(x_train), 1), dtype = np.float32 ) | |||
positive_full_enable_val = np.ones((len(x_val), 1), dtype =np.float32 ) | |||
#positive_enable_train = np.ones((1, batch_size),dtype = np.float32 ) | |||
#positive_full_train_enable = np.ones((1,BATCH_SIZE * TRAINING_RATIO ), dtype=np.float32 ) | |||
best_validation_loss = None | |||
for epoch in range(N_EPOCH): | |||
# np.random.shuffle(X_train) | |||
print("Epoch: ", epoch) | |||
print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE)) | |||
discriminator_loss = [] | |||
generator_loss = [] | |||
minibatches_size = BATCH_SIZE * TRAINING_RATIO | |||
shuffled_indexes = np.random.permutation(x_train.shape[0]) | |||
shuffled_indexes_2 = np.random.permutation(x_train.shape[0]) | |||
for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))): | |||
batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size] | |||
batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size] | |||
x = x_train[batch_indexes] | |||
y = y_train[batch_indexes] | |||
y_2 = y_train[batch_indexes_2] | |||
x_2 = x_train[batch_indexes_2] | |||
if epoch < N_WARM_UP: | |||
for j in range(TRAINING_RATIO): | |||
x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
generator_loss.append(generator_model.train_on_batch([x_batch, positive_y], [y_batch, zero_y])) | |||
else: | |||
for j in range(TRAINING_RATIO): | |||
x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
# noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32) | |||
noise = x_batch | |||
#print(sum(y_batch_2)) | |||
discriminator_loss.append(discriminator_model.train_on_batch( | |||
[y_batch_2, noise, x_batch_2 ], | |||
[positive_y, negative_y, dummy_y])) | |||
generator_loss.append(generator_model.train_on_batch([x,positive_full_y], [y, positive_full_y])) | |||
# Still needs some code to display losses from the generator and discriminator, progress bars, etc. | |||
predicted_y_train, _ = generator_model.predict([x_train , positive_full_enable_train], batch_size=BATCH_SIZE) | |||
predicted_y_val, _ = generator_model.predict([ x_val , positive_full_enable_val ], batch_size=BATCH_SIZE) | |||
#print(sum(sum(positive_full_enable_train))) | |||
#print(predicted_y_train) | |||
train_loss = log_loss(y_train, predicted_y_train) | |||
val_loss = log_loss(y_val, predicted_y_val) | |||
print("train loss: {:.4f}, validation loss: {:.4f}, discriminator loss: {:.4f}".format( | |||
train_loss, val_loss, | |||
(np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0])) | |||
if best_validation_loss is None or best_validation_loss > val_loss: | |||
print('\nEpoch %05d: improved from %0.5f,' | |||
' saving model to %s and %s' | |||
% (epoch + 1, val_loss, generator_model_path, discriminator_model_path)) | |||
best_validation_loss = val_loss | |||
generator_model.save(generator_model_path, overwrite=True) | |||
discriminator_model.save(discriminator_model_path, overwrite=True) | |||
def model(params, batch_size=20, nb_epoch=40, is_train=True): | |||
# set parameters: | |||
nb_classes = len(functions) | |||
start_time = time.time() | |||
logging.info("Loading Data") | |||
train, val, test, train_df, valid_df, test_df = load_data() | |||
train_df = pd.concat([train_df, valid_df]) | |||
test_gos = test_df['gos'].values | |||
train_data, train_labels = train | |||
val_data, val_labels = val | |||
test_data, test_labels = test | |||
logging.info("Data loaded in %d sec" % (time.time() - start_time)) | |||
logging.info("Training data size: %d" % len(train_data)) | |||
logging.info("Validation data size: %d" % len(val_data)) | |||
logging.info("Test data size: %d" % len(test_data)) | |||
generator_model_path = DATA_ROOT + 'models/new_model_seq_' + FUNCTION + '.h5' | |||
discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_' + FUNCTION + '.h5' | |||
logging.info('Starting training the model') | |||
train_generator = DataGenerator(batch_size, nb_classes) | |||
train_generator.fit(train_data, train_labels) | |||
valid_generator = DataGenerator(batch_size, nb_classes) | |||
valid_generator.fit(val_data, val_labels) | |||
test_generator = DataGenerator(batch_size, nb_classes) | |||
test_generator.fit(test_data, test_labels) | |||
if is_train: | |||
generator_model, discriminator_model = get_model(params, nb_classes, batch_size) | |||
train_wgan(generator_model, discriminator_model, batch_size=batch_size, epochs=nb_epoch, | |||
x_train=train_data, y_train=train_labels, x_val=val_data, y_val=val_labels, | |||
generator_model_path=generator_model_path, | |||
discriminator_model_path=discriminator_model_path) | |||
logging.info('Loading best model') | |||
model = load_model(generator_model_path, | |||
custom_objects={'generator_recunstruction_loss_new': generator_recunstruction_loss_new, | |||
'wasserstein_loss': wasserstein_loss}) | |||
logging.info('Predicting') | |||
preds = model.predict_generator(test_generator, steps=len(test_data) / batch_size)[0] | |||
logging.info('Computing performance') | |||
f, p, r, t, preds_max = compute_performance(preds, test_labels) #, test_gos) | |||
roc_auc = compute_roc(preds, test_labels) | |||
mcc = compute_mcc(preds_max, test_labels) | |||
aupr , _ = compute_aupr(preds, test_labels) | |||
m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = micro_macro_function_centric_f1(preds.T, test_labels.T) | |||
logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f)) | |||
logging.info('ROC AUC: \t %f ' % (roc_auc, )) | |||
logging.info('MCC: \t %f ' % (mcc, )) | |||
logging.info('AUPR: \t %f ' % (aupr, )) | |||
logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max) ) | |||
logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max) ) | |||
function_centric_performance(functions, preds.T, test_labels.T, train_labels.T) | |||
def load_prot_ipro(): | |||
proteins = list() | |||
ipros = list() | |||
with open(DATA_ROOT + 'swissprot_ipro.tab') as f: | |||
for line in f: | |||
it = line.strip().split('\t') | |||
if len(it) != 3: | |||
continue | |||
prot = it[1] | |||
iprs = it[2].split(';') | |||
proteins.append(prot) | |||
ipros.append(iprs) | |||
return pd.DataFrame({'proteins': proteins, 'ipros': ipros}) | |||
def performanc_by_interpro(): | |||
pred_df = pd.read_pickle(DATA_ROOT + 'test-' + FUNCTION + '-preds.pkl') | |||
ipro_df = load_prot_ipro() | |||
df = pred_df.merge(ipro_df, on='proteins', how='left') | |||
ipro = get_ipro() | |||
def reshape(values): | |||
values = np.hstack(values).reshape( | |||
len(values), len(values[0])) | |||
return values | |||
for ipro_id in ipro: | |||
if len(ipro[ipro_id]['parents']) > 0: | |||
continue | |||
labels = list() | |||
predictions = list() | |||
gos = list() | |||
for i, row in df.iterrows(): | |||
if not isinstance(row['ipros'], list): | |||
continue | |||
if ipro_id in row['ipros']: | |||
labels.append(row['labels']) | |||
predictions.append(row['predictions']) | |||
gos.append(row['gos']) | |||
pr = 0 | |||
rc = 0 | |||
total = 0 | |||
p_total = 0 | |||
for i in range(len(labels)): | |||
tp = np.sum(labels[i] * predictions[i]) | |||
fp = np.sum(predictions[i]) - tp | |||
fn = np.sum(labels[i]) - tp | |||
all_gos = set() | |||
for go_id in gos[i]: | |||
if go_id in all_functions: | |||
all_gos |= get_anchestors(go, go_id) | |||
all_gos.discard(GO_ID) | |||
all_gos -= func_set | |||
fn += len(all_gos) | |||
if tp == 0 and fp == 0 and fn == 0: | |||
continue | |||
total += 1 | |||
if tp != 0: | |||
p_total += 1 | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
pr += precision | |||
rc += recall | |||
if total > 0 and p_total > 0: | |||
rc /= total | |||
pr /= p_total | |||
if pr + rc > 0: | |||
f = 2 * pr * rc / (pr + rc) | |||
logging.info('%s\t%d\t%f\t%f\t%f' % ( | |||
ipro_id, len(labels), f, pr, rc)) | |||
def function_centric_performance(functions, preds, labels, labels_train): | |||
results = [] | |||
preds = np.round(preds, 2) | |||
for i in range(preds.shape[0]): | |||
f_max = 0 | |||
p_max = 0 | |||
r_max = 0 | |||
for t in range(1, 100): | |||
threshold = t / 100.0 | |||
predictions = (preds[i, :] > threshold).astype(np.int32) | |||
tp = np.sum(predictions * labels[i, :]) | |||
fp = np.sum(predictions) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
if tp > 0: | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
f = 2 * precision * recall / (precision + recall) | |||
else: | |||
if fp == 0 and fn == 0: | |||
precision = 1 | |||
recall = 1 | |||
f = 1 | |||
else: | |||
precision = 0 | |||
recall = 0 | |||
f = 0 | |||
if f_max < f: | |||
f_max = f | |||
p_max = precision | |||
r_max = recall | |||
num_prots_train = np.sum(labels_train[i, :]) | |||
height = get_height(go, functions[i]) | |||
results.append([functions[i], num_prots_train, height, f_max, p_max, r_max]) | |||
results = pd.DataFrame(results) | |||
results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False) | |||
def function_centric_performance_backup(functions, preds, labels, labels_train): | |||
results = [] | |||
preds = np.round(preds, 2) | |||
for i in range(len(functions)): | |||
f_max = 0 | |||
p_max = 0 | |||
r_max = 0 | |||
x = list() | |||
y = list() | |||
total = 0 | |||
for t in range(1, 100): | |||
threshold = t / 100.0 | |||
predictions = (preds[i, :] > threshold).astype(np.int32) | |||
tp = np.sum(predictions * labels[i, :]) | |||
fp = np.sum(predictions) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
if tp >0: | |||
sn = tp / (1.0 * np.sum(labels[i, :])) | |||
sp = np.sum((predictions ^ 1) * (labels[i, :] ^ 1)) | |||
sp /= 1.0 * np.sum(labels[i, :] ^ 1) | |||
fpr = 1 - sp | |||
x.append(fpr) | |||
y.append(sn) | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
f = 2 * precision * recall / (precision + recall) | |||
total +=1 | |||
if f_max < f: | |||
f_max = f | |||
p_max = precision | |||
r_max = recall | |||
num_prots = np.sum(labels[i, :]) | |||
num_prots_train = np.sum(labels_train[i,:]) | |||
if total >1 : | |||
roc_auc = auc(x, y) | |||
else: | |||
roc_auc =0 | |||
height = get_height(go , functions[i]) | |||
results.append([functions[i], f_max, p_max, r_max, num_prots, num_prots_train, height,roc_auc]) | |||
results = pd.DataFrame(results) | |||
#results.to_csv('new_results.txt' , sep='\t' , index = False) | |||
results.to_csv('Con_GodGanSeq_results_'+FUNCTION +'.txt', sep='\t', index=False) | |||
#results = np.array(results) | |||
#p_mean = (np.sum(results[:,2])) / len(functions) | |||
#r_mean = (np.sum(results[:,3])) / len(functions) | |||
#f_mean = (2*p_mean*r_mean)/(p_mean+r_mean) | |||
#roc_auc_mean = (np.sum(results[:,7])) / len(functions) | |||
#print('Function centric performance (macro) ' '%f %f %f %f' % (f_mean, p_mean, r_mean, roc_auc_mean)) | |||
def micro_macro_function_centric_f1_backup(preds, labels): | |||
preds = np.round(preds, 2) | |||
m_f1_max = 0 | |||
M_f1_max = 0 | |||
for t in range(1, 100): | |||
threshold = t / 100.0 | |||
predictions = (preds > threshold).astype(np.int32) | |||
m_tp = 0 | |||
m_fp = 0 | |||
m_fn = 0 | |||
M_pr = 0 | |||
M_rc = 0 | |||
total = 0 | |||
p_total = 0 | |||
for i in range(len(preds)): | |||
tp = np.sum(predictions[i, :] * labels[i, :]) | |||
fp = np.sum(predictions[i, :]) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
if tp == 0 and fp == 0 and fn == 0: | |||
continue | |||
total += 1 | |||
if tp > 0: | |||
pr = tp / (1.0 * (tp + fp)) | |||
rc = tp / (1.0 * (tp + fn)) | |||
m_tp += tp | |||
m_fp += fp | |||
m_fn += fn | |||
M_pr += pr | |||
M_rc += rc | |||
p_total += 1 | |||
if p_total == 0: | |||
continue | |||
if total > 0: | |||
m_tp /= total | |||
m_fn /= total | |||
m_fp /= total | |||
m_pr = m_tp / (1.0 * (m_tp + m_fp)) | |||
m_rc = m_tp / (1.0 * (m_tp + m_fn)) | |||
M_pr /= p_total | |||
M_rc /= total | |||
m_f1 = 2 * m_pr * m_rc / (m_pr + m_rc) | |||
M_f1 = 2 * M_pr * M_rc / (M_pr + M_rc) | |||
if m_f1 > m_f1_max: | |||
m_f1_max = m_f1 | |||
m_pr_max = m_pr | |||
m_rc_max = m_rc | |||
if M_f1 > M_f1_max: | |||
M_f1_max = M_f1 | |||
M_pr_max = M_pr | |||
M_rc_max = M_rc | |||
return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max | |||
def micro_macro_function_centric_f1(preds, labels): | |||
preds = np.round(preds, 2) | |||
m_f1_max = 0 | |||
M_f1_max = 0 | |||
for t in range(1, 200): | |||
threshold = t / 200.0 | |||
predictions = (preds > threshold).astype(np.int32) | |||
m_tp = 0 | |||
m_fp = 0 | |||
m_fn = 0 | |||
M_pr = 0 | |||
M_rc = 0 | |||
for i in range(preds.shape[0]): | |||
tp = np.sum(predictions[i, :] * labels[i, :]) | |||
fp = np.sum(predictions[i, :]) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
m_tp += tp | |||
m_fp += fp | |||
m_fn += fn | |||
if tp > 0: | |||
pr = 1.0 * tp / (1.0 * (tp + fp)) | |||
rc = 1.0 * tp / (1.0 * (tp + fn)) | |||
else: | |||
if fp == 0 and fn == 0: | |||
pr = 1 | |||
rc = 1 | |||
else: | |||
pr = 0 | |||
rc = 0 | |||
M_pr += pr | |||
M_rc += rc | |||
if m_tp > 0: | |||
m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp)) | |||
m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn)) | |||
m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc) | |||
else: | |||
if m_fp == 0 and m_fn == 0: | |||
m_pr = 1 | |||
m_rc = 1 | |||
m_f1 = 1 | |||
else: | |||
m_pr = 0 | |||
m_rc = 0 | |||
m_f1 = 0 | |||
M_pr /= preds.shape[0] | |||
M_rc /= preds.shape[0] | |||
if M_pr == 0 and M_rc == 0: | |||
M_f1 = 0 | |||
else: | |||
M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc) | |||
if m_f1 > m_f1_max: | |||
m_f1_max = m_f1 | |||
m_pr_max = m_pr | |||
m_rc_max = m_rc | |||
if M_f1 > M_f1_max: | |||
M_f1_max = M_f1 | |||
M_pr_max = M_pr | |||
M_rc_max = M_rc | |||
return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max | |||
def compute_roc(preds, labels): | |||
# Compute ROC curve and ROC area for each class | |||
fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten()) | |||
roc_auc = auc(fpr, tpr) | |||
return roc_auc | |||
def compute_aupr(preds, labels): | |||
# Compute ROC curve and ROC area for each class | |||
pr, rc, threshold =precision_recall_curve(labels.flatten(), preds.flatten()) | |||
pr_auc = auc(rc, pr) | |||
#pr, rc, threshold =precision_recall_curve(labels.flatten(), preds.flatten(),average ='macro' ) | |||
M_pr_auc = 0 | |||
return pr_auc, M_pr_auc | |||
def compute_mcc(preds, labels): | |||
# Compute ROC curve and ROC area for each class | |||
mcc = matthews_corrcoef(labels.flatten(), preds.flatten()) | |||
return mcc | |||
def compute_performance(preds, labels): #, gos): | |||
preds = np.round(preds, 2) | |||
f_max = 0 | |||
p_max = 0 | |||
r_max = 0 | |||
t_max = 0 | |||
for t in range(1, 100): | |||
threshold = t / 100.0 | |||
predictions = (preds > threshold).astype(np.int32) | |||
total = 0 | |||
f = 0.0 | |||
p = 0.0 | |||
r = 0.0 | |||
p_total = 0 | |||
for i in range(labels.shape[0]): | |||
tp = np.sum(predictions[i, :] * labels[i, :]) | |||
fp = np.sum(predictions[i, :]) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
all_gos = set() | |||
#for go_id in gos[i]: | |||
# if go_id in all_functions: | |||
# all_gos |= get_anchestors(go, go_id) | |||
#all_gos.discard(GO_ID) | |||
#all_gos -= func_set | |||
#fn += len(all_gos) | |||
if tp == 0 and fp == 0 and fn == 0: | |||
continue | |||
total += 1 | |||
if tp != 0: | |||
p_total += 1 | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
p += precision | |||
r += recall | |||
if p_total == 0: | |||
continue | |||
r /= total | |||
p /= p_total | |||
if p + r > 0: | |||
f = 2 * p * r / (p + r) | |||
if f_max < f: | |||
f_max = f | |||
p_max = p | |||
r_max = r | |||
t_max = threshold | |||
predictions_max = predictions | |||
return f_max, p_max, r_max, t_max, predictions_max | |||
def get_gos(pred): | |||
mdist = 1.0 | |||
mgos = None | |||
for i in range(len(labels_gos)): | |||
labels, gos = labels_gos[i] | |||
dist = distance.cosine(pred, labels) | |||
if mdist > dist: | |||
mdist = dist | |||
mgos = gos | |||
return mgos | |||
def compute_similarity_performance(train_df, test_df, preds): | |||
logging.info("Computing similarity performance") | |||
logging.info("Training data size %d" % len(train_df)) | |||
train_labels = train_df['labels'].values | |||
train_gos = train_df['gos'].values | |||
global labels_gos | |||
labels_gos = zip(train_labels, train_gos) | |||
p = Pool(64) | |||
pred_gos = p.map(get_gos, preds) | |||
total = 0 | |||
p = 0.0 | |||
r = 0.0 | |||
f = 0.0 | |||
test_gos = test_df['gos'].values | |||
for gos, tgos in zip(pred_gos, test_gos): | |||
preds = set() | |||
test = set() | |||
for go_id in gos: | |||
if go_id in all_functions: | |||
preds |= get_anchestors(go, go_id) | |||
for go_id in tgos: | |||
if go_id in all_functions: | |||
test |= get_anchestors(go, go_id) | |||
tp = len(preds.intersection(test)) | |||
fp = len(preds - test) | |||
fn = len(test - preds) | |||
if tp == 0 and fp == 0 and fn == 0: | |||
continue | |||
total += 1 | |||
if tp != 0: | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
p += precision | |||
r += recall | |||
f += 2 * precision * recall / (precision + recall) | |||
return f / total, p / total, r / total | |||
def print_report(report, go_id): | |||
with open(DATA_ROOT + 'reports.txt', 'a') as f: | |||
f.write('Classification report for ' + go_id + '\n') | |||
f.write(report + '\n') | |||
if __name__ == '__main__': | |||
main() |
@@ -1,917 +0,0 @@ | |||
#!/usr/bin/env python | |||
""" | |||
""" | |||
from __future__ import division | |||
import logging | |||
import sys | |||
import time | |||
from collections import deque | |||
from multiprocessing import Pool | |||
import click as ck | |||
import numpy as np | |||
import pandas as pd | |||
import tensorflow as tf | |||
from keras import backend as K | |||
from keras.callbacks import EarlyStopping, ModelCheckpoint | |||
from keras.layers import ( | |||
Dense, Input, SpatialDropout1D, Conv1D, MaxPooling1D, | |||
Flatten, Concatenate, Add, Maximum, Embedding, BatchNormalization, Activation, Dropout) | |||
from keras.losses import binary_crossentropy | |||
from keras.models import Sequential, Model, load_model | |||
from keras.preprocessing import sequence | |||
from scipy.spatial import distance | |||
from sklearn.metrics import log_loss | |||
from sklearn.metrics import roc_curve, auc, matthews_corrcoef | |||
from keras.layers import Lambda | |||
from sklearn.metrics import precision_recall_curve | |||
from utils import ( | |||
get_gene_ontology, | |||
get_go_set, | |||
get_anchestors, | |||
get_parents, | |||
DataGenerator, | |||
FUNC_DICT, | |||
get_height, | |||
get_ipro) | |||
from conditional_wgan_wrapper_post import WGAN_wrapper, wasserstein_loss, generator_recunstruction_loss_new | |||
config = tf.ConfigProto() | |||
config.gpu_options.allow_growth = True | |||
sess = tf.Session(config=config) | |||
K.set_session(sess) | |||
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO) | |||
sys.setrecursionlimit(100000) | |||
DATA_ROOT = 'data/swiss/' | |||
MAXLEN = 258 #1000 | |||
REPLEN = 256 | |||
ind = 0 | |||
@ck.command() | |||
@ck.option( | |||
'--function', | |||
default='bp', | |||
help='Ontology id (mf, bp, cc)') | |||
@ck.option( | |||
'--device', | |||
default='gpu:0', | |||
help='GPU or CPU device id') | |||
@ck.option( | |||
'--org', | |||
default= None, | |||
help='Organism id for filtering test set') | |||
@ck.option('--train',default = True, is_flag=True) | |||
@ck.option('--param', default=0, help='Param index 0-7') | |||
def main(function, device, org, train, param): | |||
global FUNCTION | |||
FUNCTION = function | |||
global GO_ID | |||
GO_ID = FUNC_DICT[FUNCTION] | |||
global go | |||
go = get_gene_ontology('go.obo') | |||
global ORG | |||
ORG = org | |||
func_df = pd.read_pickle(DATA_ROOT + FUNCTION + '.pkl') | |||
global functions | |||
functions = func_df['functions'].values | |||
global func_set | |||
func_set = set(functions) | |||
global all_functions | |||
all_functions = get_go_set(go, GO_ID) | |||
logging.info('Functions: %s %d' % (FUNCTION, len(functions))) | |||
if ORG is not None: | |||
logging.info('Organism %s' % ORG) | |||
global go_indexes | |||
go_indexes = dict() | |||
for ind, go_id in enumerate(functions): | |||
go_indexes[go_id] = ind | |||
global node_names | |||
node_names = set() | |||
with tf.device('/' + device): | |||
params = { | |||
'fc_output': 1024, | |||
'learning_rate': 0.001, | |||
'embedding_dims': 128, | |||
'embedding_dropout': 0.2, | |||
'nb_conv': 1, | |||
'nb_dense': 1, | |||
'filter_length': 128, | |||
'nb_filter': 32, | |||
'pool_length': 64, | |||
'stride': 32 | |||
} | |||
model(params, is_train=train) | |||
#dims = [64, 128, 256, 512] | |||
#nb_filters = [16, 32, 64, 128] | |||
#nb_convs = [1, 2, 3, 4] | |||
#nb_dense = [1, 2, 3, 4] | |||
#for i in range(param * 32, param * 32 + 32): | |||
# dim = i % 4 | |||
# i = i / 4 | |||
# nb_fil = i % 4 | |||
# i /= 4 | |||
# conv = i % 4 | |||
# i /= 4 | |||
# den = i | |||
# params['embedding_dims'] = dims[dim] | |||
# params['nb_filter'] = nb_filters[nb_fil] | |||
# params['nb_conv'] = nb_convs[conv] | |||
# params['nb_dense'] = nb_dense[den] | |||
# performanc_by_interpro() | |||
def load_data2(): | |||
all_data_x_fn = 'data2/all_data_X.csv' | |||
all_data_x = pd.read_csv(all_data_x_fn, sep='\t', header=0, index_col=0) | |||
all_proteins_train = [p.replace('"', '') for p in all_data_x.index] | |||
all_data_x.index = all_proteins_train | |||
all_data_y_fn = 'data2/all_data_Y.csv' | |||
all_data_y = pd.read_csv(all_data_y_fn, sep='\t', header=0, index_col=0) | |||
branch = pd.read_csv('data2/'+FUNCTION +'_branches.txt', sep='\t', header=0, index_col=0) | |||
all_x = all_data_x.values | |||
branches = [p for p in branch.index.tolist() if p in all_data_y.columns.tolist()] | |||
t= pd.DataFrame(all_data_y, columns=branches) | |||
all_y = t.values | |||
number_of_test = int(np.ceil(0.2 * len(all_x))) | |||
index = np.random.rand(1,number_of_test) | |||
index_test = [int(p) for p in np.ceil(index*len(all_x))[0] ] | |||
index_train = [p for p in range(len(all_x)) if p not in index_test] | |||
train_data = all_x[index_train, : ] #[ :20000, : ] | |||
test_data = all_x[index_test, : ] #[20000: , : ] | |||
train_labels = all_y[index_train, : ] #[ :20000, : ] | |||
test_labels = all_y[index_test, :] #[20000: , : ] | |||
val_data = test_data | |||
val_labels = test_labels | |||
#print(sum(sum(train_labels))) | |||
#print(train_data.shape) | |||
print(train_labels.shape) | |||
print(test_labels.shape) | |||
return train_data, train_labels, test_data, test_labels, val_data, val_labels | |||
def load_data(): | |||
df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl') | |||
n = len(df) | |||
index = df.index.values | |||
valid_n = int(n * 0.8) | |||
train_df = df.loc[index[:valid_n]] | |||
valid_df = df.loc[index[valid_n:]] | |||
test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl') | |||
print( test_df['orgs'] ) | |||
if ORG is not None: | |||
logging.info('Unfiltered test size: %d' % len(test_df)) | |||
test_df = test_df[test_df['orgs'] == ORG] | |||
logging.info('Filtered test size: %d' % len(test_df)) | |||
# Filter by type | |||
# org_df = pd.read_pickle('data/prokaryotes.pkl') | |||
# orgs = org_df['orgs'] | |||
# test_df = test_df[test_df['orgs'].isin(orgs)] | |||
def reshape(values): | |||
values = np.hstack(values).reshape( | |||
len(values), len(values[0])) | |||
return values | |||
def normalize_minmax(values): | |||
mn = np.min(values) | |||
mx = np.max(values) | |||
if mx - mn != 0.0: | |||
return (values - mn) / (mx - mn) | |||
return values - mn | |||
def get_values(data_frame): | |||
print(data_frame['labels'].values.shape) | |||
labels = reshape(data_frame['labels'].values) | |||
ngrams = sequence.pad_sequences( | |||
data_frame['ngrams'].values, maxlen=MAXLEN) | |||
ngrams = reshape(ngrams) | |||
rep = reshape(data_frame['embeddings'].values) | |||
data = ngrams | |||
return data, labels | |||
train = get_values(train_df) | |||
valid = get_values(valid_df) | |||
test = get_values(test_df) | |||
return train, valid, test, train_df, valid_df, test_df | |||
def get_feature_model(params): | |||
embedding_dims = params['embedding_dims'] | |||
max_features = 8001 | |||
model = Sequential() | |||
model.add(Embedding( | |||
max_features, | |||
embedding_dims, | |||
input_length=MAXLEN)) | |||
model.add(SpatialDropout1D(0.4)) | |||
for i in range(params['nb_conv']): | |||
model.add(Conv1D( | |||
activation="relu", | |||
padding="valid", | |||
strides=1, | |||
filters=params['nb_filter'], | |||
kernel_size=params['filter_length'])) | |||
model.add(MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])) | |||
model.add(Flatten()) | |||
return model | |||
def merge_outputs(outputs, name): | |||
if len(outputs) == 1: | |||
return outputs[0] | |||
## return merge(outputs, mode='concat', name=name, concat_axis=1) | |||
return Concatenate(axis=1, name=name)(outputs) | |||
def merge_nets(nets, name): | |||
if len(nets) == 1: | |||
return nets[0] | |||
## return merge(nets, mode='sum', name=name) | |||
return Add(name=name)(nets) | |||
def get_node_name(go_id, unique=False): | |||
name = go_id.split(':')[1] | |||
if not unique: | |||
return name | |||
if name not in node_names: | |||
node_names.add(name) | |||
return name | |||
i = 1 | |||
while (name + '_' + str(i)) in node_names: | |||
i += 1 | |||
name = name + '_' + str(i) | |||
node_names.add(name) | |||
return name | |||
def get_function_node(name, inputs): | |||
output_name = name + '_out' | |||
# net = Dense(256, name=name, activation='relu')(inputs) | |||
output = Dense(1, name=output_name, activation='sigmoid')(inputs) | |||
return output, output | |||
def get_generator(params, n_classes): | |||
inputs = Input(shape=(MAXLEN,), dtype='float32', name='input1') | |||
#feature_model = get_feature_model(params)(inputs) | |||
net0 = Dense(150, activation='relu')(inputs) | |||
net0 = Dense(150, activation='relu')(net0) | |||
#net0 = Dense(50, activation='relu')(net0) | |||
net = Dense(70, activation = 'relu')(net0) | |||
output = Dense(n_classes, activation='sigmoid')(net) | |||
model = Model(inputs=inputs, outputs=output) | |||
return model | |||
def get_discriminator(params, n_classes, dropout_rate=0.5): | |||
inputs = Input(shape=(n_classes, )) | |||
inputs2 = Input(shape =(MAXLEN,), dtype ='int32', name='d_input2') | |||
x2 = Embedding(8001,128, input_length=MAXLEN)(inputs2) | |||
x2 = Conv1D(filters =1 , kernel_size= 1, padding = 'valid', activation ='relu', strides=1)(x2) | |||
x2 = Lambda(lambda x: K.squeeze(x, 2))(x2) | |||
#for i in range(params['nb_conv']): | |||
# x2 = Conv1D ( activation="relu", padding="valid", strides=1, filters=params['nb_filter'],kernel_size=params['filter_length'])(x2) | |||
#x2 =MaxPooling1D(strides=params['stride'], pool_size=params['pool_length'])(x2) | |||
#x2 = Flatten()(x2) | |||
size = 40 | |||
x = inputs | |||
x = Dropout(dropout_rate)(x) | |||
x = Dense(size)(x) | |||
x = BatchNormalization()(x) | |||
x = Activation('relu')(x) | |||
size = 40 | |||
x2 = Dropout(dropout_rate)(x2) | |||
x2 = Dense(size)(x2) | |||
x2 = BatchNormalization()(x2) | |||
x2 = Activation('relu')(x2) | |||
x = Concatenate(axis =1 , name = 'merged2')([x, x2]) | |||
layer_sizes = [80, 40,30] | |||
for size in layer_sizes: | |||
x = Dropout(dropout_rate)(x) | |||
x = Dense(size)(x) | |||
x = BatchNormalization()(x) | |||
x = Activation('relu')(x) | |||
outputs = Dense(1)(x) | |||
model = Model(inputs = [inputs ,inputs2], outputs=outputs, name='Discriminator') | |||
return model | |||
def get_model(params,nb_classes, batch_size, GRADIENT_PENALTY_WEIGHT=10): | |||
generator = get_generator(params, nb_classes) | |||
discriminator = get_discriminator(params, nb_classes) | |||
generator_model, discriminator_model = \ | |||
WGAN_wrapper(generator=generator, | |||
discriminator=discriminator, | |||
generator_input_shape=(MAXLEN,), | |||
discriminator_input_shape=(nb_classes,), | |||
discriminator_input_shape2 = (MAXLEN, ), | |||
batch_size=batch_size, | |||
gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT) | |||
logging.info('Compilation finished') | |||
return generator_model, discriminator_model | |||
def train_wgan(generator_model, discriminator_model, batch_size, epochs, | |||
x_train, y_train, x_val, y_val, generator_model_path, discriminator_model_path, | |||
TRAINING_RATIO=10, N_WARM_UP=0): | |||
BATCH_SIZE = batch_size | |||
N_EPOCH = epochs | |||
positive_y = np.ones((batch_size, 1), dtype=np.float32) | |||
zero_y = positive_y * 0 | |||
negative_y = -positive_y | |||
positive_full_y = np.ones((BATCH_SIZE * TRAINING_RATIO, 1), dtype=np.float32) | |||
dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32) | |||
positive_full_enable_train = np.ones((len(x_train), 1), dtype = np.float32 ) | |||
positive_full_enable_val = np.ones((len(x_val), 1), dtype =np.float32 ) | |||
#positive_enable_train = np.ones((1, batch_size),dtype = np.float32 ) | |||
#positive_full_train_enable = np.ones((1,BATCH_SIZE * TRAINING_RATIO ), dtype=np.float32 ) | |||
best_validation_loss = None | |||
for epoch in range(N_EPOCH): | |||
# np.random.shuffle(X_train) | |||
print("Epoch: ", epoch) | |||
print("Number of batches: ", int(y_train.shape[0] // BATCH_SIZE)) | |||
discriminator_loss = [] | |||
generator_loss = [] | |||
minibatches_size = BATCH_SIZE * TRAINING_RATIO | |||
shuffled_indexes = np.random.permutation(x_train.shape[0]) | |||
shuffled_indexes_2 = np.random.permutation(x_train.shape[0]) | |||
for i in range(int(y_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))): | |||
batch_indexes = shuffled_indexes[i * minibatches_size:(i + 1) * minibatches_size] | |||
batch_indexes_2 = shuffled_indexes_2[i * minibatches_size:(i + 1) * minibatches_size] | |||
x = x_train[batch_indexes] | |||
y = y_train[batch_indexes] | |||
y_2 = y_train[batch_indexes_2] | |||
x_2 = x_train[batch_indexes_2] | |||
if epoch < N_WARM_UP: | |||
for j in range(TRAINING_RATIO): | |||
x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
y_batch = y[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
generator_loss.append(generator_model.train_on_batch([x_batch, positive_y], [y_batch, zero_y])) | |||
else: | |||
for j in range(TRAINING_RATIO): | |||
x_batch = x[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
y_batch_2 = y_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
x_batch_2 = x_2[j * BATCH_SIZE:(j + 1) * BATCH_SIZE] | |||
# noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32) | |||
noise = x_batch | |||
#print(sum(y_batch_2)) | |||
discriminator_loss.append(discriminator_model.train_on_batch( | |||
[y_batch_2, noise, x_batch_2 ], | |||
[positive_y, negative_y, dummy_y])) | |||
generator_loss.append(generator_model.train_on_batch([x,positive_full_y], [y, positive_full_y])) | |||
# Still needs some code to display losses from the generator and discriminator, progress bars, etc. | |||
predicted_y_train, _ = generator_model.predict([x_train , positive_full_enable_train], batch_size=BATCH_SIZE) | |||
predicted_y_val, _ = generator_model.predict([ x_val , positive_full_enable_val ], batch_size=BATCH_SIZE) | |||
#print(sum(sum(positive_full_enable_train))) | |||
#print(predicted_y_train) | |||
train_loss = log_loss(y_train, predicted_y_train) | |||
val_loss = log_loss(y_val, predicted_y_val) | |||
print("train loss: {:.4f}, validation loss: {:.4f}, discriminator loss: {:.4f}".format( | |||
train_loss, val_loss, | |||
(np.sum(np.asarray(discriminator_loss)) if discriminator_loss else -1) / x_train.shape[0])) | |||
if best_validation_loss is None or best_validation_loss > val_loss: | |||
print('\nEpoch %05d: improved from %0.5f,' | |||
' saving model to %s and %s' | |||
% (epoch + 1, val_loss, generator_model_path, discriminator_model_path)) | |||
best_validation_loss = val_loss | |||
generator_model.save(generator_model_path, overwrite=True) | |||
discriminator_model.save(discriminator_model_path, overwrite=True) | |||
def model(params, batch_size=20, nb_epoch=40, is_train=True): | |||
# set parameters: | |||
#nb_classes = len(functions) | |||
start_time = time.time() | |||
logging.info("Loading Data") | |||
## | |||
#train, val, test, train_df, valid_df, test_df = load_data() | |||
#train_df = pd.concat([train_df, valid_df]) | |||
#test_gos = test_df['gos'].values | |||
#train_data, train_labels = train | |||
#val_data, val_labels = val | |||
#test_data, test_labels = test | |||
## | |||
train_data, train_labels, test_data, test_labels, val_data, val_labels = load_data2() | |||
nb_classes = train_labels.shape[1] | |||
logging.info("Data loaded in %d sec" % (time.time() - start_time)) | |||
logging.info("Training data size: %d" % len(train_data)) | |||
logging.info("Validation data size: %d" % len(val_data)) | |||
logging.info("Test data size: %d" % len(test_data)) | |||
generator_model_path = DATA_ROOT + 'models/new_model_seq_' + FUNCTION + '.h5' | |||
discriminator_model_path = DATA_ROOT + 'models/new_model_disc_seq_' + FUNCTION + '.h5' | |||
logging.info('Starting training the model') | |||
train_generator = DataGenerator(batch_size, nb_classes) | |||
train_generator.fit(train_data, train_labels) | |||
valid_generator = DataGenerator(batch_size, nb_classes) | |||
valid_generator.fit(val_data, val_labels) | |||
test_generator = DataGenerator(batch_size, nb_classes) | |||
test_generator.fit(test_data, test_labels) | |||
if is_train: | |||
generator_model, discriminator_model = get_model(params, nb_classes, batch_size) | |||
train_wgan(generator_model, discriminator_model, batch_size=batch_size, epochs=nb_epoch, | |||
x_train=train_data, y_train=train_labels, x_val=val_data, y_val=val_labels, | |||
generator_model_path=generator_model_path, | |||
discriminator_model_path=discriminator_model_path) | |||
logging.info('Loading best model') | |||
model = load_model(generator_model_path, | |||
custom_objects={'generator_recunstruction_loss_new': generator_recunstruction_loss_new, | |||
'wasserstein_loss': wasserstein_loss}) | |||
logging.info('Predicting') | |||
preds = model.predict_generator(test_generator, steps=len(test_data) / batch_size)[0] | |||
# incon = 0 | |||
# for i in xrange(len(test_data)): | |||
# for j in xrange(len(functions)): | |||
# childs = set(go[functions[j]]['children']).intersection(func_set) | |||
# ok = True` | |||
# for n_id in childs: | |||
# if preds[i, j] < preds[i, go_indexes[n_id]]: | |||
# preds[i, j] = preds[i, go_indexes[n_id]] | |||
# ok = False | |||
# if not ok: | |||
# incon += 1 | |||
logging.info('Computing performance') | |||
f, p, r, t, preds_max = compute_performance(preds, test_labels) #, test_gos) | |||
roc_auc = compute_roc(preds, test_labels) | |||
mcc = compute_mcc(preds_max, test_labels) | |||
aupr , _ = compute_aupr(preds, test_labels) | |||
m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max = micro_macro_function_centric_f1(preds.T, test_labels.T) | |||
logging.info('Protein centric macro Th, PR, RC, F1: \t %f %f %f %f' % (t, p, r, f)) | |||
logging.info('ROC AUC: \t %f ' % (roc_auc, )) | |||
logging.info('MCC: \t %f ' % (mcc, )) | |||
logging.info('AUPR: \t %f ' % (aupr, )) | |||
logging.info('Function centric macro PR, RC, F1: \t %f %f %f' % (M_pr_max, M_rc_max, M_f1_max) ) | |||
logging.info('Function centric micro PR, RC, F1: \t %f %f %f' % (m_pr_max, m_rc_max, m_f1_max) ) | |||
function_centric_performance(functions, preds.T, test_labels.T, train_labels.T) | |||
def load_prot_ipro(): | |||
proteins = list() | |||
ipros = list() | |||
with open(DATA_ROOT + 'swissprot_ipro.tab') as f: | |||
for line in f: | |||
it = line.strip().split('\t') | |||
if len(it) != 3: | |||
continue | |||
prot = it[1] | |||
iprs = it[2].split(';') | |||
proteins.append(prot) | |||
ipros.append(iprs) | |||
return pd.DataFrame({'proteins': proteins, 'ipros': ipros}) | |||
def performanc_by_interpro(): | |||
pred_df = pd.read_pickle(DATA_ROOT + 'test-' + FUNCTION + '-preds.pkl') | |||
ipro_df = load_prot_ipro() | |||
df = pred_df.merge(ipro_df, on='proteins', how='left') | |||
ipro = get_ipro() | |||
def reshape(values): | |||
values = np.hstack(values).reshape( | |||
len(values), len(values[0])) | |||
return values | |||
for ipro_id in ipro: | |||
if len(ipro[ipro_id]['parents']) > 0: | |||
continue | |||
labels = list() | |||
predictions = list() | |||
gos = list() | |||
for i, row in df.iterrows(): | |||
if not isinstance(row['ipros'], list): | |||
continue | |||
if ipro_id in row['ipros']: | |||
labels.append(row['labels']) | |||
predictions.append(row['predictions']) | |||
gos.append(row['gos']) | |||
pr = 0 | |||
rc = 0 | |||
total = 0 | |||
p_total = 0 | |||
for i in range(len(labels)): | |||
tp = np.sum(labels[i] * predictions[i]) | |||
fp = np.sum(predictions[i]) - tp | |||
fn = np.sum(labels[i]) - tp | |||
all_gos = set() | |||
for go_id in gos[i]: | |||
if go_id in all_functions: | |||
all_gos |= get_anchestors(go, go_id) | |||
all_gos.discard(GO_ID) | |||
all_gos -= func_set | |||
fn += len(all_gos) | |||
if tp == 0 and fp == 0 and fn == 0: | |||
continue | |||
total += 1 | |||
if tp != 0: | |||
p_total += 1 | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
pr += precision | |||
rc += recall | |||
if total > 0 and p_total > 0: | |||
rc /= total | |||
pr /= p_total | |||
if pr + rc > 0: | |||
f = 2 * pr * rc / (pr + rc) | |||
logging.info('%s\t%d\t%f\t%f\t%f' % ( | |||
ipro_id, len(labels), f, pr, rc)) | |||
def function_centric_performance(functions, preds, labels, labels_train): | |||
results = [] | |||
preds = np.round(preds, 2) | |||
for i in range(preds.shape[0]): | |||
f_max = 0 | |||
p_max = 0 | |||
r_max = 0 | |||
for t in range(1, 100): | |||
threshold = t / 100.0 | |||
predictions = (preds[i, :] > threshold).astype(np.int32) | |||
tp = np.sum(predictions * labels[i, :]) | |||
fp = np.sum(predictions) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
if tp > 0: | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
f = 2 * precision * recall / (precision + recall) | |||
else: | |||
if fp == 0 and fn == 0: | |||
precision = 1 | |||
recall = 1 | |||
f = 1 | |||
else: | |||
precision = 0 | |||
recall = 0 | |||
f = 0 | |||
if f_max < f: | |||
f_max = f | |||
p_max = precision | |||
r_max = recall | |||
num_prots_train = np.sum(labels_train[i, :]) | |||
height = get_height(go, functions[i]) | |||
results.append([functions[i], num_prots_train, height, f_max, p_max, r_max]) | |||
results = pd.DataFrame(results) | |||
results.to_csv('Con_GodGanSeq_results_' + FUNCTION + '.txt', sep='\t', index=False) | |||
def function_centric_performance_backup(functions, preds, labels, labels_train): | |||
results = [] | |||
preds = np.round(preds, 2) | |||
for i in range(len(functions)): | |||
f_max = 0 | |||
p_max = 0 | |||
r_max = 0 | |||
x = list() | |||
y = list() | |||
total = 0 | |||
for t in range(1, 100): | |||
threshold = t / 100.0 | |||
predictions = (preds[i, :] > threshold).astype(np.int32) | |||
tp = np.sum(predictions * labels[i, :]) | |||
fp = np.sum(predictions) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
if tp >0: | |||
sn = tp / (1.0 * np.sum(labels[i, :])) | |||
sp = np.sum((predictions ^ 1) * (labels[i, :] ^ 1)) | |||
sp /= 1.0 * np.sum(labels[i, :] ^ 1) | |||
fpr = 1 - sp | |||
x.append(fpr) | |||
y.append(sn) | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
f = 2 * precision * recall / (precision + recall) | |||
total +=1 | |||
if f_max < f: | |||
f_max = f | |||
p_max = precision | |||
r_max = recall | |||
num_prots = np.sum(labels[i, :]) | |||
num_prots_train = np.sum(labels_train[i,:]) | |||
if total >1 : | |||
roc_auc = auc(x, y) | |||
else: | |||
roc_auc =0 | |||
height = get_height(go , functions[i]) | |||
results.append([functions[i], f_max, p_max, r_max, num_prots, num_prots_train, height,roc_auc]) | |||
results = pd.DataFrame(results) | |||
#results.to_csv('new_results.txt' , sep='\t' , index = False) | |||
results.to_csv('Con_GodGanSeq_results_'+FUNCTION +'.txt', sep='\t', index=False) | |||
#results = np.array(results) | |||
#p_mean = (np.sum(results[:,2])) / len(functions) | |||
#r_mean = (np.sum(results[:,3])) / len(functions) | |||
#f_mean = (2*p_mean*r_mean)/(p_mean+r_mean) | |||
#roc_auc_mean = (np.sum(results[:,7])) / len(functions) | |||
#print('Function centric performance (macro) ' '%f %f %f %f' % (f_mean, p_mean, r_mean, roc_auc_mean)) | |||
def micro_macro_function_centric_f1_backup(preds, labels): | |||
preds = np.round(preds, 2) | |||
m_f1_max = 0 | |||
M_f1_max = 0 | |||
for t in range(1, 100): | |||
threshold = t / 100.0 | |||
predictions = (preds > threshold).astype(np.int32) | |||
m_tp = 0 | |||
m_fp = 0 | |||
m_fn = 0 | |||
M_pr = 0 | |||
M_rc = 0 | |||
total = 0 | |||
p_total = 0 | |||
for i in range(len(preds)): | |||
tp = np.sum(predictions[i, :] * labels[i, :]) | |||
fp = np.sum(predictions[i, :]) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
if tp == 0 and fp == 0 and fn == 0: | |||
continue | |||
total += 1 | |||
if tp > 0: | |||
pr = tp / (1.0 * (tp + fp)) | |||
rc = tp / (1.0 * (tp + fn)) | |||
m_tp += tp | |||
m_fp += fp | |||
m_fn += fn | |||
M_pr += pr | |||
M_rc += rc | |||
p_total += 1 | |||
if p_total == 0: | |||
continue | |||
if total > 0: | |||
m_tp /= total | |||
m_fn /= total | |||
m_fp /= total | |||
m_pr = m_tp / (1.0 * (m_tp + m_fp)) | |||
m_rc = m_tp / (1.0 * (m_tp + m_fn)) | |||
M_pr /= p_total | |||
M_rc /= total | |||
m_f1 = 2 * m_pr * m_rc / (m_pr + m_rc) | |||
M_f1 = 2 * M_pr * M_rc / (M_pr + M_rc) | |||
if m_f1 > m_f1_max: | |||
m_f1_max = m_f1 | |||
m_pr_max = m_pr | |||
m_rc_max = m_rc | |||
if M_f1 > M_f1_max: | |||
M_f1_max = M_f1 | |||
M_pr_max = M_pr | |||
M_rc_max = M_rc | |||
return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max | |||
def micro_macro_function_centric_f1(preds, labels): | |||
preds = np.round(preds, 2) | |||
m_f1_max = 0 | |||
M_f1_max = 0 | |||
for t in range(1, 200): | |||
threshold = t / 200.0 | |||
predictions = (preds > threshold).astype(np.int32) | |||
m_tp = 0 | |||
m_fp = 0 | |||
m_fn = 0 | |||
M_pr = 0 | |||
M_rc = 0 | |||
for i in range(preds.shape[0]): | |||
tp = np.sum(predictions[i, :] * labels[i, :]) | |||
fp = np.sum(predictions[i, :]) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
m_tp += tp | |||
m_fp += fp | |||
m_fn += fn | |||
if tp > 0: | |||
pr = 1.0 * tp / (1.0 * (tp + fp)) | |||
rc = 1.0 * tp / (1.0 * (tp + fn)) | |||
else: | |||
if fp == 0 and fn == 0: | |||
pr = 1 | |||
rc = 1 | |||
else: | |||
pr = 0 | |||
rc = 0 | |||
M_pr += pr | |||
M_rc += rc | |||
if m_tp > 0: | |||
m_pr = 1.0 * m_tp / (1.0 * (m_tp + m_fp)) | |||
m_rc = 1.0 * m_tp / (1.0 * (m_tp + m_fn)) | |||
m_f1 = 2.0 * m_pr * m_rc / (m_pr + m_rc) | |||
else: | |||
if m_fp == 0 and m_fn == 0: | |||
m_pr = 1 | |||
m_rc = 1 | |||
m_f1 = 1 | |||
else: | |||
m_pr = 0 | |||
m_rc = 0 | |||
m_f1 = 0 | |||
M_pr /= preds.shape[0] | |||
M_rc /= preds.shape[0] | |||
if M_pr == 0 and M_rc == 0: | |||
M_f1 = 0 | |||
else: | |||
M_f1 = 2.0 * M_pr * M_rc / (M_pr + M_rc) | |||
if m_f1 > m_f1_max: | |||
m_f1_max = m_f1 | |||
m_pr_max = m_pr | |||
m_rc_max = m_rc | |||
if M_f1 > M_f1_max: | |||
M_f1_max = M_f1 | |||
M_pr_max = M_pr | |||
M_rc_max = M_rc | |||
return m_pr_max, m_rc_max, m_f1_max, M_pr_max, M_rc_max, M_f1_max | |||
def compute_roc(preds, labels): | |||
# Compute ROC curve and ROC area for each class | |||
fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten()) | |||
roc_auc = auc(fpr, tpr) | |||
return roc_auc | |||
def compute_aupr(preds, labels): | |||
# Compute ROC curve and ROC area for each class | |||
pr, rc, threshold =precision_recall_curve(labels.flatten(), preds.flatten()) | |||
pr_auc = auc(rc, pr) | |||
#pr, rc, threshold =precision_recall_curve(labels.flatten(), preds.flatten(),average ='macro' ) | |||
M_pr_auc = 0 | |||
return pr_auc, M_pr_auc | |||
def compute_mcc(preds, labels): | |||
# Compute ROC curve and ROC area for each class | |||
mcc = matthews_corrcoef(labels.flatten(), preds.flatten()) | |||
return mcc | |||
def compute_performance(preds, labels): #, gos): | |||
preds = np.round(preds, 2) | |||
f_max = 0 | |||
p_max = 0 | |||
r_max = 0 | |||
t_max = 0 | |||
for t in range(1, 100): | |||
threshold = t / 100.0 | |||
predictions = (preds > threshold).astype(np.int32) | |||
total = 0 | |||
f = 0.0 | |||
p = 0.0 | |||
r = 0.0 | |||
p_total = 0 | |||
for i in range(labels.shape[0]): | |||
tp = np.sum(predictions[i, :] * labels[i, :]) | |||
fp = np.sum(predictions[i, :]) - tp | |||
fn = np.sum(labels[i, :]) - tp | |||
all_gos = set() | |||
#for go_id in gos[i]: | |||
# if go_id in all_functions: | |||
# all_gos |= get_anchestors(go, go_id) | |||
#all_gos.discard(GO_ID) | |||
#all_gos -= func_set | |||
#fn += len(all_gos) | |||
if tp == 0 and fp == 0 and fn == 0: | |||
continue | |||
total += 1 | |||
if tp != 0: | |||
p_total += 1 | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
p += precision | |||
r += recall | |||
if p_total == 0: | |||
continue | |||
r /= total | |||
p /= p_total | |||
if p + r > 0: | |||
f = 2 * p * r / (p + r) | |||
if f_max < f: | |||
f_max = f | |||
p_max = p | |||
r_max = r | |||
t_max = threshold | |||
predictions_max = predictions | |||
return f_max, p_max, r_max, t_max, predictions_max | |||
def get_gos(pred): | |||
mdist = 1.0 | |||
mgos = None | |||
for i in range(len(labels_gos)): | |||
labels, gos = labels_gos[i] | |||
dist = distance.cosine(pred, labels) | |||
if mdist > dist: | |||
mdist = dist | |||
mgos = gos | |||
return mgos | |||
def compute_similarity_performance(train_df, test_df, preds): | |||
logging.info("Computing similarity performance") | |||
logging.info("Training data size %d" % len(train_df)) | |||
train_labels = train_df['labels'].values | |||
train_gos = train_df['gos'].values | |||
global labels_gos | |||
labels_gos = zip(train_labels, train_gos) | |||
p = Pool(64) | |||
pred_gos = p.map(get_gos, preds) | |||
total = 0 | |||
p = 0.0 | |||
r = 0.0 | |||
f = 0.0 | |||
test_gos = test_df['gos'].values | |||
for gos, tgos in zip(pred_gos, test_gos): | |||
preds = set() | |||
test = set() | |||
for go_id in gos: | |||
if go_id in all_functions: | |||
preds |= get_anchestors(go, go_id) | |||
for go_id in tgos: | |||
if go_id in all_functions: | |||
test |= get_anchestors(go, go_id) | |||
tp = len(preds.intersection(test)) | |||
fp = len(preds - test) | |||
fn = len(test - preds) | |||
if tp == 0 and fp == 0 and fn == 0: | |||
continue | |||
total += 1 | |||
if tp != 0: | |||
precision = tp / (1.0 * (tp + fp)) | |||
recall = tp / (1.0 * (tp + fn)) | |||
p += precision | |||
r += recall | |||
f += 2 * precision * recall / (precision + recall) | |||
return f / total, p / total, r / total | |||
def print_report(report, go_id): | |||
with open(DATA_ROOT + 'reports.txt', 'a') as f: | |||
f.write('Classification report for ' + go_id + '\n') | |||
f.write(report + '\n') | |||
if __name__ == '__main__': | |||
main() |
@@ -1,182 +0,0 @@ | |||
from functools import partial | |||
import numpy as np | |||
from keras import backend as K | |||
from keras.layers import Input | |||
from keras.layers.merge import _Merge | |||
from keras.models import Model | |||
from keras.optimizers import Adam | |||
from keras.losses import binary_crossentropy | |||
def wasserstein_loss(y_true, y_pred): | |||
"""Calculates the Wasserstein loss for a sample batch. | |||
The Wasserstein loss function is very simple to calculate. In a standard GAN, the discriminator | |||
has a sigmoid output, representing the probability that samples are real or generated. In Wasserstein | |||
GANs, however, the output is linear with no activation function! Instead of being constrained to [0, 1], | |||
the discriminator wants to make the distance between its output for real and generated samples as large as possible. | |||
The most natural way to achieve this is to label generated samples -1 and real samples 1, instead of the | |||
0 and 1 used in normal GANs, so that multiplying the outputs by the labels will give you the loss immediately. | |||
Note that the nature of this loss means that it can be (and frequently will be) less than 0.""" | |||
return K.mean(y_true * y_pred) | |||
def generator_recunstruction_loss(y_true, y_pred, enableTrain): | |||
return binary_crossentropy(y_true, y_pred) * enableTrain | |||
global enable_train | |||
enable_train = Input(shape = (1,)) | |||
global generator_recunstruction_loss_new | |||
generator_recunstruction_loss_new = partial(generator_recunstruction_loss, enableTrain = enable_train) | |||
generator_recunstruction_loss_new.__name__ = 'generator_recunstruction_loss_new' | |||
def WGAN_wrapper(generator, discriminator, generator_input_shape, discriminator_input_shape, discriminator_input_shape2, | |||
batch_size, gradient_penalty_weight): | |||
BATCH_SIZE = batch_size | |||
GRADIENT_PENALTY_WEIGHT = gradient_penalty_weight | |||
def set_trainable_state(model, state): | |||
for layer in model.layers: | |||
layer.trainable = state | |||
model.trainable = state | |||
def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight): | |||
"""Calculates the gradient penalty loss for a batch of "averaged" samples. | |||
In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the loss function | |||
that penalizes the network if the gradient norm moves away from 1. However, it is impossible to evaluate | |||
this function at all points in the input space. The compromise used in the paper is to choose random points | |||
on the lines between real and generated samples, and check the gradients at these points. Note that it is the | |||
gradient w.r.t. the input averaged samples, not the weights of the discriminator, that we're penalizing! | |||
In order to evaluate the gradients, we must first run samples through the generator and evaluate the loss. | |||
Then we get the gradients of the discriminator w.r.t. the input averaged samples. | |||
The l2 norm and penalty can then be calculated for this gradient. | |||
Note that this loss function requires the original averaged samples as input, but Keras only supports passing | |||
y_true and y_pred to loss functions. To get around this, we make a partial() of the function with the | |||
averaged_samples argument, and use that for model training.""" | |||
# first get the gradients: | |||
# assuming: - that y_pred has dimensions (batch_size, 1) | |||
# - averaged_samples has dimensions (batch_size, nbr_features) | |||
# gradients afterwards has dimension (batch_size, nbr_features), basically | |||
# a list of nbr_features-dimensional gradient vectors | |||
gradients = K.gradients(y_pred, averaged_samples)[0] | |||
# compute the euclidean norm by squaring ... | |||
gradients_sqr = K.square(gradients) | |||
# ... summing over the rows ... | |||
gradients_sqr_sum = K.sum(gradients_sqr, | |||
axis=np.arange(1, len(gradients_sqr.shape))) | |||
# ... and sqrt | |||
gradient_l2_norm = K.sqrt(gradients_sqr_sum) | |||
# compute lambda * (1 - ||grad||)^2 still for each single sample | |||
gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm) | |||
# return the mean as loss over all the batch samples | |||
return K.mean(gradient_penalty) | |||
class RandomWeightedAverage(_Merge): | |||
"""Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a random point on the line | |||
between each pair of input points. | |||
Inheriting from _Merge is a little messy but it was the quickest solution I could think of. | |||
Improvements appreciated.""" | |||
def _merge_function(self, inputs): | |||
weights = K.random_uniform((BATCH_SIZE, 1)) | |||
print(inputs[0]) | |||
return (weights * inputs[0]) + ((1 - weights) * inputs[1]) | |||
# The generator_model is used when we want to train the generator layers. | |||
# As such, we ensure that the discriminator layers are not trainable. | |||
# Note that once we compile this model, updating .trainable will have no effect within it. As such, it | |||
# won't cause problems if we later set discriminator.trainable = True for the discriminator_model, as long | |||
# as we compile the generator_model first. | |||
set_trainable_state(discriminator, False) | |||
set_trainable_state(generator, True) | |||
#enable_train = Input(shape = (1,)) | |||
generator_input = Input(shape=generator_input_shape) | |||
generator_layers = generator(generator_input) | |||
#discriminator_noise_in = Input(shape=(1,)) | |||
#input_seq_g = Input(shape = discriminator_input_shape2) | |||
discriminator_layers_for_generator = discriminator([generator_layers, generator_input]) | |||
generator_model = Model(inputs=[generator_input, enable_train], | |||
outputs=[generator_layers, discriminator_layers_for_generator]) | |||
# We use the Adam paramaters from Gulrajani et al. | |||
#global generator_recunstruction_loss_new | |||
#generator_recunstruction_loss_new = partial(generator_recunstruction_loss, enableTrain = enable_train) | |||
#generator_recunstruction_loss_new.__name__ = 'generator_RLN' | |||
loss = [generator_recunstruction_loss_new, wasserstein_loss] | |||
loss_weights = [30, 1] | |||
generator_model.compile(optimizer=Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08), | |||
loss=loss, loss_weights=loss_weights) | |||
# Now that the generator_model is compiled, we can make the discriminator layers trainable. | |||
set_trainable_state(discriminator, True) | |||
set_trainable_state(generator, False) | |||
# The discriminator_model is more complex. It takes both real image samples and random noise seeds as input. | |||
# The noise seed is run through the generator model to get generated images. Both real and generated images | |||
# are then run through the discriminator. Although we could concatenate the real and generated images into a | |||
# single tensor, we don't (see model compilation for why). | |||
real_samples = Input(shape=discriminator_input_shape) | |||
input_seq = Input(shape = discriminator_input_shape2) | |||
generator_input_for_discriminator = Input(shape=generator_input_shape) | |||
generated_samples_for_discriminator = generator(generator_input_for_discriminator) | |||
discriminator_output_from_generator = discriminator([generated_samples_for_discriminator, generator_input_for_discriminator] ) | |||
discriminator_output_from_real_samples = discriminator([real_samples, input_seq]) | |||
# We also need to generate weighted-averages of real and generated samples, to use for the gradient norm penalty. | |||
averaged_samples = RandomWeightedAverage()([real_samples, generated_samples_for_discriminator]) | |||
average_seq = RandomWeightedAverage()([input_seq, generator_input_for_discriminator]) | |||
# We then run these samples through the discriminator as well. Note that we never really use the discriminator | |||
# output for these samples - we're only running them to get the gradient norm for the gradient penalty loss. | |||
#print('hehehe') | |||
#print(averaged_samples) | |||
averaged_samples_out = discriminator([averaged_samples, average_seq] ) | |||
# The gradient penalty loss function requires the input averaged samples to get gradients. However, | |||
# Keras loss functions can only have two arguments, y_true and y_pred. We get around this by making a partial() | |||
# of the function with the averaged samples here. | |||
partial_gp_loss = partial(gradient_penalty_loss, | |||
averaged_samples=averaged_samples, | |||
gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT) | |||
partial_gp_loss.__name__ = 'gradient_penalty' # Functions need names or Keras will throw an error | |||
# Keras requires that inputs and outputs have the same number of samples. This is why we didn't concatenate the | |||
# real samples and generated samples before passing them to the discriminator: If we had, it would create an | |||
# output with 2 * BATCH_SIZE samples, while the output of the "averaged" samples for gradient penalty | |||
# would have only BATCH_SIZE samples. | |||
# If we don't concatenate the real and generated samples, however, we get three outputs: One of the generated | |||
# samples, one of the real samples, and one of the averaged samples, all of size BATCH_SIZE. This works neatly! | |||
discriminator_model = Model(inputs=[real_samples, generator_input_for_discriminator,input_seq], | |||
outputs=[discriminator_output_from_real_samples, | |||
discriminator_output_from_generator, | |||
averaged_samples_out]) | |||
# We use the Adam paramaters from Gulrajani et al. We use the Wasserstein loss for both the real and generated | |||
# samples, and the gradient penalty loss for the averaged samples. | |||
discriminator_model.compile(optimizer=Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08), | |||
loss=[wasserstein_loss, | |||
wasserstein_loss, | |||
partial_gp_loss]) | |||
# set_trainable_state(discriminator, True) | |||
# set_trainable_state(generator, True) | |||
return generator_model, discriminator_model |