from functools import partial

import numpy as np
from keras import backend as K
from keras.layers import Input
from keras.layers.merge import _Merge
from keras.models import Model
from keras.optimizers import Adam, SGD
from keras.losses import binary_crossentropy


def wasserstein_loss(y_true, y_pred):
    """Calculates the Wasserstein loss for a sample batch.

    The Wasserstein loss is very simple to calculate. In a standard GAN, the discriminator has a
    sigmoid output representing the probability that samples are real or generated. In a Wasserstein
    GAN the output is linear, with no activation function: instead of being constrained to [0, 1],
    the discriminator tries to make the gap between its outputs for real and generated samples as
    large as possible.

    The most natural way to achieve this is to label generated samples -1 and real samples 1
    (instead of the 0 and 1 used in normal GANs), so that the loss is simply the negated mean of
    the labels multiplied by the discriminator outputs.

    Note that by its nature this loss can be (and frequently will be) less than 0."""
    return K.mean(-y_true * y_pred)
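

# A minimal, illustrative sketch (hypothetical helper, not part of the original training code):
# with real samples labelled +1 and generated samples labelled -1, wasserstein_loss reduces to
# -mean(label * critic_output). It assumes a graph-mode Keras backend where K.eval works on
# constant-derived tensors.
def _demo_wasserstein_loss():
    y_true = K.constant([[1.0], [1.0], [-1.0], [-1.0]])  # +1 = real, -1 = generated
    y_pred = K.constant([[0.8], [1.2], [-0.5], [-1.1]])  # unbounded critic outputs
    # -mean([1*0.8, 1*1.2, -1*-0.5, -1*-1.1]) = -0.9
    return K.eval(wasserstein_loss(y_true, y_pred))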


def generator_recunstruction_loss(y_true, y_pred, enableTrain):
    # Labels live in [-1, 1]; rescale them back to [0, 1] before applying binary cross-entropy.
    def rescale_back_labels(labels):
        rescaled_back_labels = (labels / 2) + 0.5
        return rescaled_back_labels

    # The enableTrain gate is currently disabled (the `* enableTrain` factor is commented out).
    return binary_crossentropy(rescale_back_labels(y_true), rescale_back_labels(y_pred))  # * enableTrain
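

# A minimal, illustrative sketch (hypothetical helper, not part of the original code) of the label
# rescaling used above: values in [-1, 1] map back to [0, 1], so -1 -> 0.0, 0 -> 0.5 and 1 -> 1.0.
def _demo_reconstruction_rescaling():
    labels = K.constant([[-1.0, 0.0, 1.0]])
    return K.eval((labels / 2) + 0.5)  # -> [[0.0, 0.5, 1.0]]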


# Module-level objects shared with WGAN_wrapper below: enable_train is an extra (currently unused)
# model input, and the partial() is given a __name__ because Keras requires loss functions to be named.
enable_train = Input(shape=(1,))
generator_recunstruction_loss_new = partial(generator_recunstruction_loss, enableTrain=enable_train)
generator_recunstruction_loss_new.__name__ = 'generator_recunstruction_loss_new'


def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
    """Calculates the gradient penalty loss for a batch of "averaged" samples.

    In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the loss function
    that penalizes the network if the gradient norm moves away from 1. However, it is impossible to
    evaluate this function at all points in the input space. The compromise used in the paper is to
    choose random points on the lines between real and generated samples, and check the gradients at
    these points. Note that it is the gradient w.r.t. the input averaged samples, not the weights of
    the discriminator, that we're penalizing!

    In order to evaluate the gradients, we must first run samples through the generator and evaluate
    the loss. Then we get the gradients of the discriminator w.r.t. the input averaged samples.
    The L2 norm and penalty can then be calculated for this gradient.

    Note that this loss function requires the original averaged samples as input, but Keras only
    supports passing y_true and y_pred to loss functions. To get around this, we make a partial() of
    the function with the averaged_samples argument, and use that for model training."""
    # First get the gradients, assuming:
    #   - y_pred has dimensions (batch_size, 1)
    #   - averaged_samples has dimensions (batch_size, nbr_features)
    # gradients then has dimensions (batch_size, nbr_features): a list of
    # nbr_features-dimensional gradient vectors.
    gradients = K.gradients(y_pred, averaged_samples)[0]
    # Compute the Euclidean norm by squaring ...
    gradients_sqr = K.square(gradients)
    # ... summing over the rows ...
    gradients_sqr_sum = K.sum(gradients_sqr,
                              axis=np.arange(1, len(gradients_sqr.shape)))
    # ... and taking the square root.
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # Compute lambda * (1 - ||grad||)^2 for each sample ...
    gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
    # ... and return the mean over the batch as the loss.
    return K.mean(gradient_penalty)
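

# A minimal, illustrative sketch (hypothetical helper, not part of the original code) of the gradient
# penalty on a toy linear "critic" f(x) = 3 * x: its gradient norm w.r.t. x is 3 everywhere, so the
# per-sample penalty is weight * (1 - 3)^2 = 4 * weight. Assumes a graph-mode Keras backend where
# K.gradients can differentiate through ordinary tensor ops.
def _demo_gradient_penalty(weight=10.0):
    x = K.constant([[0.2], [0.7]])  # stand-in for the "averaged" samples
    y = 3.0 * x                     # toy critic output with a known gradient of 3
    return K.eval(gradient_penalty_loss(None, y, x, gradient_penalty_weight=weight))  # -> 40.0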


def WGAN_wrapper(generator, discriminator, generator_input_shape, discriminator_input_shape,
                 discriminator_input_shape2, batch_size, gradient_penalty_weight):
    """Wraps a generator and a discriminator into the two compiled models needed for WGAN-GP
    training: a generator_model (discriminator frozen) and a discriminator_model (generator frozen)."""
    BATCH_SIZE = batch_size
    GRADIENT_PENALTY_WEIGHT = gradient_penalty_weight

    def set_trainable_state(model, state):
        # Toggle trainability of a whole model, layer by layer.
        for layer in model.layers:
            layer.trainable = state
        model.trainable = state

    class RandomWeightedAverage(_Merge):
        """Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a
        random point on the line between each pair of input points.

        Inheriting from _Merge is a little messy, but it was the quickest solution I could think of.
        Improvements appreciated."""

        def _merge_function(self, inputs):
            weights = K.random_uniform((BATCH_SIZE, 1))
            return (weights * inputs[0]) + ((1 - weights) * inputs[1])

    # The generator_model is used when we want to train the generator layers. As such, we ensure
    # that the discriminator layers are not trainable. Note that once we compile this model,
    # updating .trainable will have no effect within it, so it won't cause problems if we later set
    # discriminator.trainable = True for the discriminator_model, as long as we compile the
    # generator_model first.
    set_trainable_state(discriminator, False)
    set_trainable_state(generator, True)

    generator_input = Input(shape=generator_input_shape)
    generator_layers = generator(generator_input)
    discriminator_layers_for_generator = discriminator([generator_layers, generator_input])
    # enable_train is the module-level extra input; it must be fed at training time even though the
    # reconstruction loss does not currently use it.
    generator_model = Model(inputs=[generator_input, enable_train],
                            outputs=[generator_layers, discriminator_layers_for_generator])

    # We use the Adam parameters from Gulrajani et al. With loss_weights = [0, 1], only the
    # Wasserstein term drives the generator update; the reconstruction loss is still reported
    # but contributes nothing to the gradient.
    loss = [generator_recunstruction_loss_new, wasserstein_loss]
    loss_weights = [0, 1]
    generator_model.compile(optimizer=Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
                            loss=loss, loss_weights=loss_weights)

    # Now that the generator_model is compiled, we can make the discriminator layers trainable.
    set_trainable_state(discriminator, True)
    set_trainable_state(generator, False)

    # The discriminator_model is more complex. It takes both real samples and random noise seeds as
    # input. The noise seed is run through the generator model to get generated samples. Both real
    # and generated samples are then run through the discriminator. Although we could concatenate
    # the real and generated samples into a single tensor, we don't (see model compilation for why).
    real_samples = Input(shape=discriminator_input_shape)
    input_seq = Input(shape=discriminator_input_shape2)

    generator_input_for_discriminator = Input(shape=generator_input_shape)
    generated_samples_for_discriminator = generator(generator_input_for_discriminator)
    discriminator_output_from_generator = discriminator(
        [generated_samples_for_discriminator, generator_input_for_discriminator])
    discriminator_output_from_real_samples = discriminator([real_samples, input_seq])

    # We also need to generate weighted averages of real and generated samples, to use for the
    # gradient norm penalty. Both discriminator inputs (the samples and the paired input sequences)
    # are interpolated.
    averaged_samples = RandomWeightedAverage()([real_samples, generated_samples_for_discriminator])
    average_seq = RandomWeightedAverage()([input_seq, generator_input_for_discriminator])
    # We then run these samples through the discriminator as well. Note that we never really use
    # the discriminator output for these samples - we only run them to get the gradient norm for
    # the gradient penalty loss.
    averaged_samples_out = discriminator([averaged_samples, average_seq])

    # The gradient penalty loss function requires the input averaged samples to get gradients.
    # However, Keras loss functions can only take two arguments, y_true and y_pred. We get around
    # this by making a partial() of the function with the averaged samples here.
    partial_gp_loss = partial(gradient_penalty_loss,
                              averaged_samples=averaged_samples,
                              gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    partial_gp_loss.__name__ = 'gradient_penalty'  # Functions need names or Keras will throw an error

    # Keras requires that inputs and outputs have the same number of samples. This is why we didn't
    # concatenate the real samples and generated samples before passing them to the discriminator:
    # if we had, it would create an output with 2 * BATCH_SIZE samples, while the output of the
    # "averaged" samples for the gradient penalty would have only BATCH_SIZE samples.
    #
    # If we don't concatenate the real and generated samples, however, we get three outputs: one for
    # the generated samples, one for the real samples, and one for the averaged samples, all of size
    # BATCH_SIZE. This works neatly!
    discriminator_model = Model(inputs=[real_samples, generator_input_for_discriminator, input_seq],
                                outputs=[discriminator_output_from_real_samples,
                                         discriminator_output_from_generator,
                                         averaged_samples_out])
    loss_weights2 = [1, 1, 1]
    # We use the Wasserstein loss for both the real and generated samples, and the gradient penalty
    # loss for the averaged samples. The discriminator is optimised here with clipped SGD; the Adam
    # settings from Gulrajani et al. are kept as a commented-out alternative.
    discriminator_model.compile(optimizer=SGD(clipvalue=0.01),  # Adam(lr=4E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
                                loss=[wasserstein_loss,
                                      wasserstein_loss,
                                      partial_gp_loss],
                                loss_weights=loss_weights2)

    return generator_model, discriminator_model
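

# A minimal usage sketch (hypothetical: the toy shapes, toy models and random data below are
# illustrative, not part of the original project; it assumes the TF1-era multi-backend Keras this
# module targets). It wires two tiny models through WGAN_wrapper and runs one train_on_batch step
# for each side, with the usual WGAN-GP targets: +1 for real, -1 for generated, and a dummy target
# for the gradient-penalty output.
if __name__ == '__main__':
    from keras.layers import Dense, Concatenate

    feat_dim, batch = 8, 4

    # Toy generator: maps an 8-d input to an 8-d "label" vector in [-1, 1].
    g_in = Input(shape=(feat_dim,))
    toy_generator = Model(g_in, Dense(feat_dim, activation='tanh')(g_in))

    # Toy critic: scores a (labels, sequence) pair with a single linear output.
    d_in1 = Input(shape=(feat_dim,))
    d_in2 = Input(shape=(feat_dim,))
    toy_discriminator = Model([d_in1, d_in2], Dense(1)(Concatenate()([d_in1, d_in2])))

    gen_model, disc_model = WGAN_wrapper(toy_generator, toy_discriminator,
                                         generator_input_shape=(feat_dim,),
                                         discriminator_input_shape=(feat_dim,),
                                         discriminator_input_shape2=(feat_dim,),
                                         batch_size=batch,
                                         gradient_penalty_weight=10.0)

    noise = np.random.rand(batch, feat_dim).astype('float32')
    real = (np.random.rand(batch, feat_dim).astype('float32') * 2) - 1
    seqs = np.random.rand(batch, feat_dim).astype('float32')
    positive_y = np.ones((batch, 1), dtype='float32')
    negative_y = -positive_y
    dummy_y = np.zeros((batch, 1), dtype='float32')  # placeholder target for the gradient-penalty output
    flag = np.ones((batch, 1), dtype='float32')      # value fed to the (unused) enable_train input

    # One critic update and one generator update.
    print(disc_model.train_on_batch([real, noise, seqs], [positive_y, negative_y, dummy_y]))
    print(gen_model.train_on_batch([noise, flag], [real, positive_y]))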