You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

util.py 1.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. ###
  2. '''
  3. Similar to M1 from https://github.com/dpkingma/nips14-ssl
  4. Original Author: S. Saemundsson
  5. Edited by: Z. Rahaie
  6. '''
  7. ###
  8. import prettytensor as pt
  9. import tensorflow as tf
  10. import numpy as np
  11. logc = np.log(2.*np.pi)
  12. c = - 0.5 * np.log(2*np.pi)
  13. def tf_normal_logpdf(x, mu, log_sigma_sq):
  14. return ( - 0.5 * logc - log_sigma_sq / 2. - tf.div( tf.square( tf.sub( x, mu ) ), 2 * tf.exp( log_sigma_sq ) ) )
  15. def tf_stdnormal_logpdf(x):
  16. return ( - 0.5 * ( logc + tf.square( x ) ) )
  17. def tf_gaussian_ent(log_sigma_sq):
  18. return ( - 0.5 * ( logc + 1.0 + log_sigma_sq ) )
  19. def tf_gaussian_marg(mu, log_sigma_sq):
  20. return ( - 0.5 * ( logc + ( tf.square( mu ) + tf.exp( log_sigma_sq ) ) ) )
  21. def tf_binary_xentropy(x, y, const = 1e-10):
  22. return - ( x * tf.log ( tf.clip_by_value( y, const, 1.0 ) ) + \
  23. (1.0 - x) * tf.log( tf.clip_by_value( 1.0 - y, const, 1.0 ) ) )
  24. def feed_numpy_semisupervised(num_lab_batch, num_ulab_batch, x_lab, y, x_ulab):
  25. size = x_lab.shape[0] + x_ulab.shape[0]
  26. batch_size = num_lab_batch + num_ulab_batch
  27. count = int(size / batch_size)
  28. dim = x_lab.shape[1]
  29. for i in xrange(count):
  30. start_lab = i * num_lab_batch
  31. end_lab = start_lab + num_lab_batch
  32. start_ulab = i * num_ulab_batch
  33. end_ulab = start_ulab + num_ulab_batch
  34. yield [ x_lab[start_lab:end_lab,:dim/2], x_lab[start_lab:end_lab,dim/2:dim], y[start_lab:end_lab],
  35. x_ulab[start_ulab:end_ulab,:dim/2], x_ulab[start_ulab:end_ulab,dim/2:dim] ]
  36. def feed_numpy(batch_size, x):
  37. size = x.shape[0]
  38. count = int(size / batch_size)
  39. dim = x.shape[1]
  40. for i in xrange(count):
  41. start = i * batch_size
  42. end = start + batch_size
  43. yield x[start:end]