
conditional_wgan_wrapper_exp2.py

from functools import partial

import numpy as np
from keras import backend as K
from keras.layers import Input, RepeatVector, Multiply, Lambda, Permute
from keras.layers.merge import _Merge
from keras.models import Model
from keras.optimizers import Adam, SGD
from keras.losses import binary_crossentropy


def wasserstein_loss(y_true, y_pred):
    """Calculates the Wasserstein loss for a sample batch.

    The Wasserstein loss function is very simple to calculate. In a standard GAN, the
    discriminator has a sigmoid output, representing the probability that samples are
    real or generated. In Wasserstein GANs, however, the output is linear with no
    activation function! Instead of being constrained to [0, 1], the discriminator
    wants to make the distance between its output for real and generated samples as
    large as possible.

    The most natural way to achieve this is to label generated samples -1 and real
    samples 1, instead of the 0 and 1 used in normal GANs, so that multiplying the
    outputs by the labels gives you the loss immediately.

    Note that the nature of this loss means that it can be (and frequently will be)
    less than 0.
    """
    return K.mean(-y_true * y_pred)
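

# A worked example of the sign convention above (added for illustration; the numbers
# are hypothetical): with labels y_true = [1, -1] (real, generated) and discriminator
# outputs y_pred = [0.7, -0.3], the loss is
#     mean(-[1, -1] * [0.7, -0.3]) = mean([-0.7, -0.3]) = -0.5,
# i.e. the better the discriminator separates real from generated samples, the more
# negative this loss becomes.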


def generator_recunstruction_loss(y_true, y_pred, enableTrain):
    """Binary cross-entropy between labels rescaled from [-1, 1] back to [0, 1]."""
    def rescale_back_labels(labels):
        rescaled_back_labels = (labels / 2) + 0.5
        return rescaled_back_labels

    return binary_crossentropy(rescale_back_labels(y_true), rescale_back_labels(y_pred))  # * enableTrain


global enable_train
enable_train = Input(shape=(1,))
global generator_recunstruction_loss_new
generator_recunstruction_loss_new = partial(generator_recunstruction_loss, enableTrain=enable_train)
generator_recunstruction_loss_new.__name__ = 'generator_recunstruction_loss_new'


def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
    """Calculates the gradient penalty loss for a batch of "averaged" samples.

    In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the loss
    function that penalizes the network if the gradient norm moves away from 1. However, it
    is impossible to evaluate this function at all points in the input space. The compromise
    used in the paper is to choose random points on the lines between real and generated
    samples, and check the gradients at these points. Note that it is the gradient w.r.t.
    the input averaged samples, not the weights of the discriminator, that we're penalizing!

    In order to evaluate the gradients, we must first run samples through the generator and
    evaluate the loss. Then we get the gradients of the discriminator w.r.t. the input
    averaged samples. The L2 norm and penalty can then be calculated for this gradient.

    Note that this loss function requires the original averaged samples as input, but Keras
    only supports passing y_true and y_pred to loss functions. To get around this, we make a
    partial() of the function with the averaged_samples argument, and use that for model
    training.
    """
    # First get the gradients, assuming:
    #   - y_pred has dimensions (batch_size, 1)
    #   - averaged_samples has dimensions (batch_size, nbr_features)
    # gradients afterwards has dimension (batch_size, nbr_features), basically
    # a list of nbr_features-dimensional gradient vectors.
    gradients = K.gradients(y_pred, averaged_samples)[0]
    # compute the Euclidean norm by squaring ...
    gradients_sqr = K.square(gradients)
    # ... summing over the rows ...
    gradients_sqr_sum = K.sum(gradients_sqr,
                              axis=np.arange(1, len(gradients_sqr.shape)))
    # ... and taking the square root
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # compute lambda * (1 - ||grad||)^2, still for each single sample
    gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
    # return the mean as loss over all the batch samples
    return K.mean(gradient_penalty)
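

# To make the penalty term concrete (an illustrative calculation, not from the original
# code): with gradient_penalty_weight = 10 and a per-sample gradient norm of 1.3, that
# sample contributes 10 * (1 - 1.3)**2 = 0.9 to the loss; a norm of exactly 1 contributes 0.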


def WGAN_wrapper(generator, discriminator, generator_input_shape, discriminator_input_shape,
                 discriminator_input_shape2, batch_size, gradient_penalty_weight):
    BATCH_SIZE = batch_size
    GRADIENT_PENALTY_WEIGHT = gradient_penalty_weight

    def set_trainable_state(model, state):
        for layer in model.layers:
            layer.trainable = state
        model.trainable = state

    class RandomWeightedAverage(_Merge):
        """Takes a randomly-weighted average of two tensors. In geometric terms, this outputs
        a random point on the line between each pair of input points.

        Inheriting from _Merge is a little messy but it was the quickest solution I could
        think of. Improvements appreciated."""

        def _merge_function(self, inputs):
            weights = K.random_uniform((BATCH_SIZE, 1))
            return (weights * inputs[0]) + ((1 - weights) * inputs[1])

    # The generator_model is used when we want to train the generator layers.
    # As such, we ensure that the discriminator layers are not trainable.
    # Note that once we compile this model, updating .trainable will have no effect within it.
    # As such, it won't cause problems if we later set discriminator.trainable = True for the
    # discriminator_model, as long as we compile the generator_model first.
    set_trainable_state(discriminator, False)
    set_trainable_state(generator, True)

    # enable_train = Input(shape=(1,))
    generator_input = Input(shape=generator_input_shape)
    generator_layers = generator(generator_input)
    # masked_generator_layers = Lambda(mul)(generator_layers)
    # discriminator_noise_in = Input(shape=(1,))
    # input_seq_g = Input(shape=discriminator_input_shape2)
    discriminator_layers_for_generator = discriminator([generator_layers, generator_input])
    generator_model = Model(inputs=[generator_input, enable_train],
                            outputs=[generator_layers, discriminator_layers_for_generator])
    # We use the Adam parameters from Gulrajani et al.
    # global generator_recunstruction_loss_new
    # generator_recunstruction_loss_new = partial(generator_recunstruction_loss, enableTrain=enable_train)
    # generator_recunstruction_loss_new.__name__ = 'generator_RLN'
    loss = [generator_recunstruction_loss_new, wasserstein_loss]
    loss_weights = [0, 1]
    generator_model.compile(optimizer=Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
                            loss=loss, loss_weights=loss_weights)

    # Now that the generator_model is compiled, we can make the discriminator layers trainable.
    set_trainable_state(discriminator, True)
    set_trainable_state(generator, False)

    # The discriminator_model is more complex. It takes both real image samples and random noise
    # seeds as input. The noise seed is run through the generator model to get generated images.
    # Both real and generated images are then run through the discriminator. Although we could
    # concatenate the real and generated images into a single tensor, we don't (see model
    # compilation for why).
    real_samples = Input(shape=discriminator_input_shape)
    input_seq = Input(shape=discriminator_input_shape2)
    generator_input_for_discriminator = Input(shape=generator_input_shape)
    generated_samples_for_discriminator = generator(generator_input_for_discriminator)
    # masked_real_samples = Lambda(mul)(real_samples)
    # masked_generated_samples_for_discriminator = Lambda(mul)(generated_samples_for_discriminator)
    discriminator_output_from_generator = discriminator(
        [generated_samples_for_discriminator, generator_input_for_discriminator])
    discriminator_output_from_real_samples = discriminator([real_samples, input_seq])

    # We also need to generate weighted averages of real and generated samples, to use for the
    # gradient norm penalty.
    averaged_samples = RandomWeightedAverage()([real_samples, generated_samples_for_discriminator])
    average_seq = RandomWeightedAverage()([input_seq, generator_input_for_discriminator])
    # We then run these samples through the discriminator as well. Note that we never really use
    # the discriminator output for these samples - we're only running them to get the gradient
    # norm for the gradient penalty loss.
    # masked_averaged_samples = Lambda(mul)(averaged_samples)
    averaged_samples_out = discriminator([averaged_samples, average_seq])

    # The gradient penalty loss function requires the input averaged samples to get gradients.
    # However, Keras loss functions can only have two arguments, y_true and y_pred. We get around
    # this by making a partial() of the function with the averaged samples here.
    partial_gp_loss = partial(gradient_penalty_loss,
                              averaged_samples=averaged_samples,
                              gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    partial_gp_loss.__name__ = 'gradient_penalty'  # Functions need names or Keras will throw an error

    # Keras requires that inputs and outputs have the same number of samples. This is why we didn't
    # concatenate the real samples and generated samples before passing them to the discriminator:
    # if we had, it would create an output with 2 * BATCH_SIZE samples, while the output of the
    # "averaged" samples for the gradient penalty would have only BATCH_SIZE samples.
    #
    # If we don't concatenate the real and generated samples, however, we get three outputs: one
    # for the generated samples, one for the real samples, and one for the averaged samples, all
    # of size BATCH_SIZE. This works neatly!
    discriminator_model = Model(inputs=[real_samples, generator_input_for_discriminator, input_seq],
                                outputs=[discriminator_output_from_real_samples,
                                         discriminator_output_from_generator,
                                         averaged_samples_out])
    loss_weights2 = [1, 1, 1]
    # We use the Adam parameters from Gulrajani et al. We use the Wasserstein loss for both the
    # real and generated samples, and the gradient penalty loss for the averaged samples.
    discriminator_model.compile(optimizer=SGD(clipvalue=0.01),  # Adam(lr=4E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
                                loss=[wasserstein_loss,
                                      wasserstein_loss,
                                      partial_gp_loss],
                                loss_weights=loss_weights2)
    # set_trainable_state(discriminator, True)
    # set_trainable_state(generator, True)

    return generator_model, discriminator_model
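

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the original
# experiment). The toy generator and discriminator below are hypothetical
# stand-ins; all they assume is the calling convention used above: the
# discriminator takes [samples, condition] as inputs and has a linear
# (no-activation) output, as the Wasserstein loss docstring requires.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from keras.layers import Dense, Concatenate

    latent_dim, n_features, batch = 32, 64, 16

    # Hypothetical generator: latent/condition vector -> feature vector in [-1, 1].
    g_in = Input(shape=(latent_dim,))
    g_out = Dense(n_features, activation='tanh')(g_in)
    toy_generator = Model(g_in, g_out)

    # Hypothetical conditional discriminator: [samples, condition] -> linear score.
    d_sample_in = Input(shape=(n_features,))
    d_cond_in = Input(shape=(latent_dim,))
    d_hidden = Dense(32, activation='relu')(Concatenate()([d_sample_in, d_cond_in]))
    d_out = Dense(1)(d_hidden)
    toy_discriminator = Model([d_sample_in, d_cond_in], d_out)

    generator_model, discriminator_model = WGAN_wrapper(
        toy_generator, toy_discriminator,
        generator_input_shape=(latent_dim,),
        discriminator_input_shape=(n_features,),
        discriminator_input_shape2=(latent_dim,),
        batch_size=batch,
        gradient_penalty_weight=10)

    # Typical WGAN-GP training targets, following the label convention in the
    # wasserstein_loss docstring: +1 for real samples, -1 for generated samples,
    # and a dummy vector for the averaged-samples output (only the gradient
    # penalty term actually uses that output).
    positive_y = np.ones((batch, 1), dtype=np.float32)
    negative_y = -positive_y
    dummy_y = np.zeros((batch, 1), dtype=np.float32)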