You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train_classifier.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. ###
  2. '''
  3. Similar to M1 from https://github.com/dpkingma/nips14-ssl
  4. Original Author: S. Saemundsson
  5. Edited by: Z. Rahaie
  6. '''
  7. ###
  8. from genclass import GenerativeClassifier
  9. from vae import VariationalAutoencoder
  10. import numpy as np
  11. import data.asd as asd
  12. def encode_dataset( model_path, min_std = 0.0 ):
  13. VAE = VariationalAutoencoder( dim_x, dim_z )
  14. with VAE.session:
  15. VAE.saver.restore( VAE.session, VAE_model_path )
  16. enc_x_lab_mean, enc_x_lab_var = VAE.encode( x_lab )
  17. enc_x_ulab_mean, enc_x_ulab_var = VAE.encode( x_ulab )
  18. enc_x_valid_mean, enc_x_valid_var = VAE.encode( x_valid )
  19. enc_x_test_mean, enc_x_test_var = VAE.encode( x_test )
  20. id_x_keep = np.std( enc_x_ulab_mean, axis = 0 ) > min_std
  21. enc_x_lab_mean, enc_x_lab_var = enc_x_lab_mean[ :, id_x_keep ], enc_x_lab_var[ :, id_x_keep ]
  22. enc_x_ulab_mean, enc_x_ulab_var = enc_x_ulab_mean[ :, id_x_keep ], enc_x_ulab_var[ :, id_x_keep ]
  23. enc_x_valid_mean, enc_x_valid_var = enc_x_valid_mean[ :, id_x_keep ], enc_x_valid_var[ :, id_x_keep ]
  24. enc_x_test_mean, enc_x_test_var = enc_x_test_mean[ :, id_x_keep ], enc_x_test_var[ :, id_x_keep ]
  25. data_lab = np.hstack( [ enc_x_lab_mean, enc_x_lab_var ] )
  26. data_ulab = np.hstack( [ enc_x_ulab_mean, enc_x_ulab_var ] )
  27. data_valid = np.hstack( [enc_x_valid_mean, enc_x_valid_var] )
  28. data_test = np.hstack( [enc_x_test_mean, enc_x_test_var] )
  29. return data_lab, data_ulab, data_valid, data_test
  30. if __name__ == '__main__':
  31. num_batches = 100 #Number of minibatches in a single epoch
  32. epochs = 1001 #Number of epochs through the full dataset
  33. learning_rate = 1e-4 #Learning rate of ADAM
  34. alpha = 0.1 #Discriminatory factor (see equation (9) of http://arxiv.org/pdf/1406.5298v2.pdf)
  35. seed = 12345 #Seed for RNG
  36. ####################
  37. ''' Load Dataset '''
  38. ####################
  39. asd_path = ['nds/asd_case.txt', 'nds/asd_control.txt']
  40. #Uses anglpy module from original paper (linked at top) to split the dataset for semi-supervised training
  41. train_x, train_y, valid_x, valid_y, test_x, test_y = mnist.load_numpy_split(mnist_path, binarize_y=True)
  42. x_l, y_l, x_u, y_u = mnist.create_semisupervised(train_x, train_y, num_lab)
  43. x_lab, y_lab = x_l.T, y_l.T
  44. x_ulab, y_ulab = x_u.T, y_u.T
  45. x_valid, y_valid = valid_x.T, valid_y.T
  46. x_test, y_test = test_x.T, test_y.T
  47. ################
  48. ''' Load VAE '''
  49. ################
  50. VAE_model_path = 'temp/mid_training_0.cpkt'
  51. min_std = 0.1 #Dimensions with std < min_std are removed before training with GC
  52. data_lab, data_ulab, data_valid, data_test = encode_dataset( VAE_model_path, min_std )
  53. dim_x = data_lab.shape[1] / 2
  54. dim_y = y_lab.shape[1]
  55. num_examples = data_lab.shape[0] + data_ulab.shape[0]
  56. ###################################
  57. ''' Train Generative Classifier '''
  58. ###################################
  59. GC = GenerativeClassifier( dim_x, dim_z, dim_y,
  60. num_examples, num_lab, num_batches,
  61. hidden_layers_px = hidden_layers_px,
  62. hidden_layers_qz = hidden_layers_qz,
  63. hidden_layers_qy = hidden_layers_qy,
  64. alpha = alpha )
  65. GC.train( x_labelled = data_lab, y = y_lab, x_unlabelled = data_ulab,
  66. x_valid = data_valid, y_valid = y_valid,
  67. epochs = epochs,
  68. learning_rate = learning_rate,
  69. seed = seed,
  70. print_every = 10,
  71. load_path = None )
  72. ############################
  73. ''' Evaluate on Test Set '''
  74. ############################
  75. GC_eval = GenerativeClassifier( dim_x, dim_z, dim_y, num_examples, num_lab, num_batches )
  76. with GC_eval.session:
  77. GC_eval.saver.restore( GC_eval.session, GC.save_path )
  78. GC_eval.predict_labels( data_test, y_test )