You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

conditional_wgan_wrapper_post.py 10.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from functools import partial
  2. import numpy as np
  3. from keras import backend as K
  4. from keras.layers import Input
  5. from keras.layers.merge import _Merge
  6. from keras.models import Model
  7. from keras.optimizers import Adam
  8. from keras.losses import binary_crossentropy
  9. def wasserstein_loss(y_true, y_pred):
  10. """Calculates the Wasserstein loss for a sample batch.
  11. The Wasserstein loss function is very simple to calculate. In a standard GAN, the discriminator
  12. has a sigmoid output, representing the probability that samples are real or generated. In Wasserstein
  13. GANs, however, the output is linear with no activation function! Instead of being constrained to [0, 1],
  14. the discriminator wants to make the distance between its output for real and generated samples as large as possible.
  15. The most natural way to achieve this is to label generated samples -1 and real samples 1, instead of the
  16. 0 and 1 used in normal GANs, so that multiplying the outputs by the labels will give you the loss immediately.
  17. Note that the nature of this loss means that it can be (and frequently will be) less than 0."""
  18. return K.mean(y_true * y_pred)
  19. def generator_recunstruction_loss(y_true, y_pred, enableTrain):
  20. return binary_crossentropy(y_true, y_pred) * enableTrain
  21. global enable_train
  22. enable_train = Input(shape = (1,))
  23. global generator_recunstruction_loss_new
  24. generator_recunstruction_loss_new = partial(generator_recunstruction_loss, enableTrain = enable_train)
  25. generator_recunstruction_loss_new.__name__ = 'generator_recunstruction_loss_new'
  26. def WGAN_wrapper(generator, discriminator, generator_input_shape, discriminator_input_shape, discriminator_input_shape2,
  27. batch_size, gradient_penalty_weight):
  28. BATCH_SIZE = batch_size
  29. GRADIENT_PENALTY_WEIGHT = gradient_penalty_weight
  30. def set_trainable_state(model, state):
  31. for layer in model.layers:
  32. layer.trainable = state
  33. model.trainable = state
  34. def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
  35. """Calculates the gradient penalty loss for a batch of "averaged" samples.
  36. In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the loss function
  37. that penalizes the network if the gradient norm moves away from 1. However, it is impossible to evaluate
  38. this function at all points in the input space. The compromise used in the paper is to choose random points
  39. on the lines between real and generated samples, and check the gradients at these points. Note that it is the
  40. gradient w.r.t. the input averaged samples, not the weights of the discriminator, that we're penalizing!
  41. In order to evaluate the gradients, we must first run samples through the generator and evaluate the loss.
  42. Then we get the gradients of the discriminator w.r.t. the input averaged samples.
  43. The l2 norm and penalty can then be calculated for this gradient.
  44. Note that this loss function requires the original averaged samples as input, but Keras only supports passing
  45. y_true and y_pred to loss functions. To get around this, we make a partial() of the function with the
  46. averaged_samples argument, and use that for model training."""
  47. # first get the gradients:
  48. # assuming: - that y_pred has dimensions (batch_size, 1)
  49. # - averaged_samples has dimensions (batch_size, nbr_features)
  50. # gradients afterwards has dimension (batch_size, nbr_features), basically
  51. # a list of nbr_features-dimensional gradient vectors
  52. gradients = K.gradients(y_pred, averaged_samples)[0]
  53. # compute the euclidean norm by squaring ...
  54. gradients_sqr = K.square(gradients)
  55. # ... summing over the rows ...
  56. gradients_sqr_sum = K.sum(gradients_sqr,
  57. axis=np.arange(1, len(gradients_sqr.shape)))
  58. # ... and sqrt
  59. gradient_l2_norm = K.sqrt(gradients_sqr_sum)
  60. # compute lambda * (1 - ||grad||)^2 still for each single sample
  61. gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
  62. # return the mean as loss over all the batch samples
  63. return K.mean(gradient_penalty)
  64. class RandomWeightedAverage(_Merge):
  65. """Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a random point on the line
  66. between each pair of input points.
  67. Inheriting from _Merge is a little messy but it was the quickest solution I could think of.
  68. Improvements appreciated."""
  69. def _merge_function(self, inputs):
  70. weights = K.random_uniform((BATCH_SIZE, 1))
  71. print(inputs[0])
  72. return (weights * inputs[0]) + ((1 - weights) * inputs[1])
  73. # The generator_model is used when we want to train the generator layers.
  74. # As such, we ensure that the discriminator layers are not trainable.
  75. # Note that once we compile this model, updating .trainable will have no effect within it. As such, it
  76. # won't cause problems if we later set discriminator.trainable = True for the discriminator_model, as long
  77. # as we compile the generator_model first.
  78. set_trainable_state(discriminator, False)
  79. set_trainable_state(generator, True)
  80. #enable_train = Input(shape = (1,))
  81. generator_input = Input(shape=generator_input_shape)
  82. generator_layers = generator(generator_input)
  83. #discriminator_noise_in = Input(shape=(1,))
  84. #input_seq_g = Input(shape = discriminator_input_shape2)
  85. discriminator_layers_for_generator = discriminator([generator_layers, generator_input])
  86. generator_model = Model(inputs=[generator_input, enable_train],
  87. outputs=[generator_layers, discriminator_layers_for_generator])
  88. # We use the Adam paramaters from Gulrajani et al.
  89. #global generator_recunstruction_loss_new
  90. #generator_recunstruction_loss_new = partial(generator_recunstruction_loss, enableTrain = enable_train)
  91. #generator_recunstruction_loss_new.__name__ = 'generator_RLN'
  92. loss = [generator_recunstruction_loss_new, wasserstein_loss]
  93. loss_weights = [30, 1]
  94. generator_model.compile(optimizer=Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
  95. loss=loss, loss_weights=loss_weights)
  96. # Now that the generator_model is compiled, we can make the discriminator layers trainable.
  97. set_trainable_state(discriminator, True)
  98. set_trainable_state(generator, False)
  99. # The discriminator_model is more complex. It takes both real image samples and random noise seeds as input.
  100. # The noise seed is run through the generator model to get generated images. Both real and generated images
  101. # are then run through the discriminator. Although we could concatenate the real and generated images into a
  102. # single tensor, we don't (see model compilation for why).
  103. real_samples = Input(shape=discriminator_input_shape)
  104. input_seq = Input(shape = discriminator_input_shape2)
  105. generator_input_for_discriminator = Input(shape=generator_input_shape)
  106. generated_samples_for_discriminator = generator(generator_input_for_discriminator)
  107. discriminator_output_from_generator = discriminator([generated_samples_for_discriminator, generator_input_for_discriminator] )
  108. discriminator_output_from_real_samples = discriminator([real_samples, input_seq])
  109. # We also need to generate weighted-averages of real and generated samples, to use for the gradient norm penalty.
  110. averaged_samples = RandomWeightedAverage()([real_samples, generated_samples_for_discriminator])
  111. average_seq = RandomWeightedAverage()([input_seq, generator_input_for_discriminator])
  112. # We then run these samples through the discriminator as well. Note that we never really use the discriminator
  113. # output for these samples - we're only running them to get the gradient norm for the gradient penalty loss.
  114. #print('hehehe')
  115. #print(averaged_samples)
  116. averaged_samples_out = discriminator([averaged_samples, average_seq] )
  117. # The gradient penalty loss function requires the input averaged samples to get gradients. However,
  118. # Keras loss functions can only have two arguments, y_true and y_pred. We get around this by making a partial()
  119. # of the function with the averaged samples here.
  120. partial_gp_loss = partial(gradient_penalty_loss,
  121. averaged_samples=averaged_samples,
  122. gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
  123. partial_gp_loss.__name__ = 'gradient_penalty' # Functions need names or Keras will throw an error
  124. # Keras requires that inputs and outputs have the same number of samples. This is why we didn't concatenate the
  125. # real samples and generated samples before passing them to the discriminator: If we had, it would create an
  126. # output with 2 * BATCH_SIZE samples, while the output of the "averaged" samples for gradient penalty
  127. # would have only BATCH_SIZE samples.
  128. # If we don't concatenate the real and generated samples, however, we get three outputs: One of the generated
  129. # samples, one of the real samples, and one of the averaged samples, all of size BATCH_SIZE. This works neatly!
  130. discriminator_model = Model(inputs=[real_samples, generator_input_for_discriminator,input_seq],
  131. outputs=[discriminator_output_from_real_samples,
  132. discriminator_output_from_generator,
  133. averaged_samples_out])
  134. # We use the Adam paramaters from Gulrajani et al. We use the Wasserstein loss for both the real and generated
  135. # samples, and the gradient penalty loss for the averaged samples.
  136. discriminator_model.compile(optimizer=Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
  137. loss=[wasserstein_loss,
  138. wasserstein_loss,
  139. partial_gp_loss])
  140. # set_trainable_state(discriminator, True)
  141. # set_trainable_state(generator, True)
  142. return generator_model, discriminator_model