You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

ssl_vae.py 1.9KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. ###
  2. '''
  3. Similar to M1 from https://github.com/dpkingma/nips14-ssl
  4. Original Author: S. Saemundsson
  5. Edited by: Z. Rahaie
  6. '''
  7. ###
  8. from vae import VariationalAutoencoder
  9. import numpy as np
  10. import data.asd as asd
  11. if __name__ == '__main__':
  12. num_batches = 0
  13. dim_z = 0
  14. epochs = 0
  15. learning_rate = 1e-4
  16. l2_loss = 1e-5
  17. seed = 12345
  18. hidden_layers_px = [ sqrt(input_size), sqrt(hidden_layers_px[0]), sqrt(hidden_layers_px[1])]
  19. hidden_layers_qz = [ sqrt(input_size), sqrt(hidden_layers_px[0]), sqrt(hidden_layers_px[2]) ]
  20. asd_path = ['nds/AutDB_ASD_cnv_dataset.txt']
  21. #Uses anglpy module exists kingma github (linked before) to divide the dataset
  22. train_x, train_y, valid_x, valid_y, test_x, test_y = asd.load_numpy(asd_path, binarize_y=True)
  23. x_train, y_train = train_x.T, train_y.T
  24. x_valid, y_valid = valid_x.T, valid_y.T
  25. x_test, y_test = test_x.T, test_y.T
  26. dim_x = x_train.shape[1]
  27. dim_y = y_train.shape[1]
  28. VAE = VariationalAutoencoder( dim_x = dim_x,
  29. dim_z = dim_z,
  30. hidden_layers_px = hidden_layers_px,
  31. hidden_layers_qz = hidden_layers_qz,
  32. l2_loss = l2_loss )
  33. #every n iterations (set to 0 to disable)
  34. VAE.train( x = x_train, x_valid = x_valid, epochs = epochs, num_batches = num_batches,
  35. learning_rate = learning_rate, seed = seed, stop_iter = 30, print_every = 10, draw_img = 0 )
  36. weights_as_numpy = VAE.get_weights()
  37. output_weights_to_label = weights_as_numpy[6]
  38. scores = weights_as_numpy[0] * weights_as_numpy[1] * weights_as_numpy [2] * output_weights_to_label
  39. genes_index = [i for i,v in enumerate(scores) if v > 0]
  40. print(genes_index)