###
'''
Similar to M1 from https://github.com/dpkingma/nips14-ssl
Original Author: S. Saemundsson
Edited by: Z. Rahaie
'''
###
-
from genclass import GenerativeClassifier
from vae import VariationalAutoencoder
import numpy as np
import data.asd as asd
-
def encode_dataset( model_path, min_std = 0.0 ):
    """Encode the labelled/unlabelled/valid/test splits with a trained VAE.

    Restores a VariationalAutoencoder from ``model_path`` and encodes the
    module-level arrays ``x_lab``, ``x_ulab``, ``x_valid`` and ``x_test``
    (each of shape ``(n_examples, dim_x)``) into latent means and variances.
    Latent dimensions whose mean has a standard deviation <= ``min_std``
    across the unlabelled set are dropped before training the classifier.

    NOTE(review): this function also reads the module-level globals ``dim_x``
    and ``dim_z`` to build the VAE — the caller must define them (matching
    the checkpoint) before calling. TODO: pass them as parameters.

    Parameters
    ----------
    model_path : str
        Path to the saved VAE checkpoint to restore.
    min_std : float
        Keep only latent dims with std(mean) > min_std over the unlabelled set.

    Returns
    -------
    tuple of np.ndarray
        ``(data_lab, data_ulab, data_valid, data_test)``; each is the kept
        latent means horizontally stacked with the kept latent variances.
    """
    VAE = VariationalAutoencoder( dim_x, dim_z )
    with VAE.session:
        # Bug fix: restore from the model_path argument. The original
        # restored from the global VAE_model_path, silently ignoring its
        # own parameter.
        VAE.saver.restore( VAE.session, model_path )

        enc_x_lab_mean, enc_x_lab_var = VAE.encode( x_lab )
        enc_x_ulab_mean, enc_x_ulab_var = VAE.encode( x_ulab )
        enc_x_valid_mean, enc_x_valid_var = VAE.encode( x_valid )
        enc_x_test_mean, enc_x_test_var = VAE.encode( x_test )

        # Boolean mask of informative latent dims, measured on the
        # unlabelled means only, then applied uniformly to every split.
        id_x_keep = np.std( enc_x_ulab_mean, axis = 0 ) > min_std

        enc_x_lab_mean, enc_x_lab_var = enc_x_lab_mean[ :, id_x_keep ], enc_x_lab_var[ :, id_x_keep ]
        enc_x_ulab_mean, enc_x_ulab_var = enc_x_ulab_mean[ :, id_x_keep ], enc_x_ulab_var[ :, id_x_keep ]
        enc_x_valid_mean, enc_x_valid_var = enc_x_valid_mean[ :, id_x_keep ], enc_x_valid_var[ :, id_x_keep ]
        enc_x_test_mean, enc_x_test_var = enc_x_test_mean[ :, id_x_keep ], enc_x_test_var[ :, id_x_keep ]

        # Each row becomes [means | variances] for the kept dimensions.
        data_lab = np.hstack( [ enc_x_lab_mean, enc_x_lab_var ] )
        data_ulab = np.hstack( [ enc_x_ulab_mean, enc_x_ulab_var ] )
        data_valid = np.hstack( [ enc_x_valid_mean, enc_x_valid_var ] )
        data_test = np.hstack( [ enc_x_test_mean, enc_x_test_var ] )

    return data_lab, data_ulab, data_valid, data_test
-
if __name__ == '__main__':

    num_batches = 100       # Number of minibatches in a single epoch
    epochs = 1001           # Number of epochs through the full dataset
    learning_rate = 1e-4    # Learning rate of ADAM
    alpha = 0.1             # Discriminatory factor (see equation (9) of http://arxiv.org/pdf/1406.5298v2.pdf)
    seed = 12345            # Seed for RNG

    # Bug fix: the following names were referenced below but never defined,
    # so the original script raised NameError. Values follow the defaults of
    # the referenced M1/M2 code — TODO(review): confirm against experiments.
    num_lab = 100                    # Number of labelled training examples
    dim_z = 50                       # VAE latent dimension; must match the trained checkpoint
    hidden_layers_px = [ 500, 500 ]  # Hidden layers of p(x|z,y)
    hidden_layers_qz = [ 500, 500 ]  # Hidden layers of q(z|x,y)
    hidden_layers_qy = [ 500, 500 ]  # Hidden layers of q(y|x)

    ####################
    ''' Load Dataset '''
    ####################

    asd_path = ['nds/asd_case.txt', 'nds/asd_control.txt']
    # Bug fix: the original called the undefined names `mnist` / `mnist_path`;
    # this file imports data.asd as asd and defines asd_path, so use those.
    # Assumes asd exposes the same anglpy-style API as the original mnist
    # loader (load_numpy_split / create_semisupervised) — TODO confirm.
    train_x, train_y, valid_x, valid_y, test_x, test_y = asd.load_numpy_split(asd_path, binarize_y=True)
    x_l, y_l, x_u, y_u = asd.create_semisupervised(train_x, train_y, num_lab)

    # Loader returns column-major (features x examples); transpose to
    # (examples x features) as expected by the VAE / classifier.
    x_lab, y_lab = x_l.T, y_l.T
    x_ulab, y_ulab = x_u.T, y_u.T
    x_valid, y_valid = valid_x.T, valid_y.T
    x_test, y_test = test_x.T, test_y.T

    ################
    ''' Load VAE '''
    ################

    # Bug fix: encode_dataset() builds the VAE from the globals dim_x/dim_z,
    # so dim_x must be set to the raw input dimensionality before the call
    # (the original never defined it here).
    dim_x = x_lab.shape[1]

    VAE_model_path = 'temp/mid_training_0.cpkt'
    min_std = 0.1  # Dimensions with std < min_std are removed before training with GC

    data_lab, data_ulab, data_valid, data_test = encode_dataset( VAE_model_path, min_std )

    # Bug fix: floor division — shape[1] / 2 is a float under Python 3, but
    # a layer dimension must be an int. Each row is [means | variances], so
    # half the width is the number of kept latent dims.
    dim_x = data_lab.shape[1] // 2
    dim_y = y_lab.shape[1]
    num_examples = data_lab.shape[0] + data_ulab.shape[0]

    ###################################
    ''' Train Generative Classifier '''
    ###################################

    GC = GenerativeClassifier( dim_x, dim_z, dim_y,
                               num_examples, num_lab, num_batches,
                               hidden_layers_px = hidden_layers_px,
                               hidden_layers_qz = hidden_layers_qz,
                               hidden_layers_qy = hidden_layers_qy,
                               alpha = alpha )

    GC.train( x_labelled = data_lab, y = y_lab, x_unlabelled = data_ulab,
              x_valid = data_valid, y_valid = y_valid,
              epochs = epochs,
              learning_rate = learning_rate,
              seed = seed,
              print_every = 10,
              load_path = None )

    ############################
    ''' Evaluate on Test Set '''
    ############################

    # A fresh classifier instance restores the best weights saved by train().
    GC_eval = GenerativeClassifier( dim_x, dim_z, dim_y, num_examples, num_lab, num_batches )

    with GC_eval.session:
        GC_eval.saver.restore( GC_eval.session, GC.save_path )
        GC_eval.predict_labels( data_test, y_test )