Self-Attention GAN
3

Self-Attention GAN

This model implements the Self-Attention GAN (SAGAN) architecture introduced by Zhang et al. (arXiv:1805.08318). The generated ImageNet images shown are reproduced from the paper. The code is a port of the authors' implementation (the link to it was not preserved in this copy).

An architecture is currently implemented only for 128 × 128 images (`--image-dim 128 128`).

Training

To train with default parameters:

mantra train sagan --dataset my_images_data --image-dim 128 128

To choose the type of GAN loss, use one of `gan`, `lsgan`, `wgan-gp`, `wgan-lp`, `dragan` (the default) or `hinge`:

mantra train sagan --dataset my_images_data --image-dim 128 128 --gan-type lsgan

To change the gradient-penalty weight λ (default 10.0; used by the `wgan-gp`, `wgan-lp` and `dragan` losses):

mantra train sagan --dataset my_images_data --image-dim 128 128 --ld 5.0

To change the number of critic (discriminator) iterations, `n_critic` (default 1):

mantra train sagan --dataset my_images_data --image-dim 128 128 --n-critic 5

Importing

To import this model to your project, run:

mantra import RJT1990/models/sagan

Model
import matplotlib.pyplot as plt
import numpy as np
import itertools
import time
import tensorflow as tf

from mantraml.models import MantraModel
from mantraml.models.tensorflow.summary import FileWriter
from mantraml.models.tensorflow.callbacks import ModelCheckpoint, EvaluateTask, StoreTrial, SavePlot

from .ops import *


class SAGAN(MantraModel):
    """
    This class implements the SAGAN paper https://arxiv.org/pdf/1805.08318.pdf

    Only 128 x 128 images are supported (enforced in ``build_model``).
    Hyperparameters are read from ``**kwargs`` in ``__init__``. The attributes
    ``batch_size``, ``epochs`` and ``trial`` are read by this class but never
    set here. NOTE(review): presumably they are provided by ``MantraModel`` or
    the mantra CLI; confirm.

    Layer helpers (``conv``, ``deconv``, ``batch_norm``, ``relu``, ``lrelu``,
    ``tanh``, ``up_sample``, ``hw_flatten``, ``flatten``) and the loss
    functions (``discriminator_loss``, ``generator_loss``) come from ``.ops``.
    """

    model_name = 'Self-Attention GAN'
    model_image = 'imagenet.png'
    model_tags = ['self-attention', 'gan']
    model_arxiv_id = '1805.08318'

    def __init__(self, data=None, task=None, **kwargs):
        """
        Parameters
        -----------
        data - object
            Training data. This class reads ``data.X``, ``data.image_shape``,
            ``data.n_color_channels`` and ``len(data)``.
        task - object or None
            Optional evaluation task. When truthy, ``EvaluateTask`` runs each epoch.
        **kwargs
            Hyperparameters; see the inline comments below for keys and defaults.
        """
        self.data = data
        self.task = task

        # Configure GPU
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = kwargs.get('allow_growth', True)
        config.gpu_options.per_process_gpu_memory_fraction = kwargs.get('per_process_gpu_memory_fraction', 0.90)

        # Reuse an existing default session if one is installed; otherwise create one.
        if tf.get_default_session() is None:
            self.session = tf.InteractiveSession(config=config)
        else:
            self.session = tf.get_default_session()

        # Architecture Information
        self.z_shape = kwargs.get('z_shape', 128)  # length of the noise vector z
        self.n_strides = kwargs.get('n_strides', 2)  # not read anywhere else in this class

        # 'learning_rate' sets both networks, as before. 'learning_rate_gen' and
        # 'learning_rate_dis' optionally override it per network. Previously both
        # attributes read the same key, so they could never differ.
        learning_rate = kwargs.get('learning_rate', 1e-4)
        self.learning_rate_gen = kwargs.get('learning_rate_gen', learning_rate)
        self.learning_rate_dis = kwargs.get('learning_rate_dis', learning_rate)

        self.beta_1 = kwargs.get('beta_1', 0.0)  # Adam beta1
        self.beta_2 = kwargs.get('beta_2', 0.9)  # Adam beta2
        self.up_sample = kwargs.get('up_sample', True)  # whether to use upsample-conv (else transposed conv)
        self.sn = kwargs.get('sn', True)  # Whether to use spectral normalization
        self.gan_type = kwargs.get('gan_type', 'dragan')  # One of [gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]; the gan loss
        self.ld = kwargs.get('ld', 10.0)  # The gradient penalty lambda

        # Discriminator updates per generator update. Coerced to a positive int
        # because the CLI may pass e.g. 5.0, and 0 would make `iter % n_critic` fail.
        self.n_critic = max(1, int(kwargs.get('n_critic', 1)))

    def init_model(self):
        """
        Initialise graph variables and set up summary logging.

        Returns
        -----------
        void - runs the global variable initialiser (in the default session),
        creates ``self.writer``, stores the merged summaries in ``self.summary``
        (``None`` if no summaries were defined) and writes the graph
        """
        tf.global_variables_initializer().run()
        self.writer = FileWriter(mantra_model=self)
        self.summary = tf.summary.merge_all()
        self.writer.add_graph(self.session.graph)

    def _generator_up_block(self, x, ch, i, training):
        """Upsample ``x`` by 2 and map it to ``ch // 2`` channels (generator block ``i``)."""
        if self.up_sample:
            x = up_sample(x, scale_factor=2)
            x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))
        else:
            x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))
        x = batch_norm(x, training, scope='batch_norm_' + str(i))
        return relu(x)

    def generator(self, z, training=True, reuse=False):
        """
        This implements the SAGAN Generator Architecture - can overload this method if desired

        Parameters
        -----------
        z - tf.placeholder
            Random noise with which to generate the sample; ``build_model`` feeds
            shape (batch_size, 1, 1, z_shape)
        training - bool or tf.Tensor
            For batch normalization; True while training, False otherwise
        reuse - bool
            Reuse the variable scope if this method is called twice

        Returns
        -----------
        tf.Tensor - the generated (fake) image, passed through tanh
        """
        with tf.variable_scope("generator", reuse=reuse):
            ch = 1024
            x = deconv(z, channels=ch, kernel=4, stride=1, padding='VALID', use_bias=False, sn=self.sn, scope='deconv')
            x = batch_norm(x, training, scope='batch_norm')
            x = relu(x)

            for i in range(self.layer_num // 2):
                x = self._generator_up_block(x, ch, i, training)
                ch = ch // 2

            # Self Attention, inserted halfway through the upsampling stack
            x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)

            for i in range(self.layer_num // 2, self.layer_num):
                x = self._generator_up_block(x, ch, i, training)
                ch = ch // 2

            n_channels = self.data.image_shape[2]
            if self.up_sample:
                x = up_sample(x, scale_factor=2)
                x = conv(x, channels=n_channels, kernel=3, stride=1, pad=1, sn=self.sn, scope='G_conv_logit')
            else:
                x = deconv(x, channels=n_channels, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='G_deconv_logit')

            return tanh(x)

    def _discriminator_down_block(self, x, ch, i, training):
        """Strided conv to ``ch * 2`` channels + batch norm + leaky ReLU (discriminator block ``i``)."""
        x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))
        # NOTE: scope is 'batch_norm<i>' (no underscore), unlike the generator; kept for checkpoint compatibility.
        x = batch_norm(x, training, scope='batch_norm' + str(i))
        return lrelu(x, 0.2)

    def discriminator(self, x, training=True, reuse=False):
        """
        This implements the SAGAN Discriminator Architecture - can overload this method if desired

        Parameters
        -----------
        x - tf.Tensor
            Containing the fake (or real) image data
        training - bool
            For batch normalization; True while training, False otherwise
        reuse - bool
            Reuse the variable scope if this method is called twice

        Returns
        -----------
        tf.Tensor - the raw (un-normalised) discriminator logits; no sigmoid is applied
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = 64
            x = conv(x, channels=ch, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv')
            x = lrelu(x, 0.2)

            for i in range(self.layer_num // 2):
                x = self._discriminator_down_block(x, ch, i, training)
                ch = ch * 2

            # Self Attention
            x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)

            for i in range(self.layer_num // 2, self.layer_num):
                x = self._discriminator_down_block(x, ch, i, training)
                ch = ch * 2

            x = conv(x, channels=4, stride=1, sn=self.sn, use_bias=False, scope='D_logit')

            return x

    def attention(self, x, ch, sn=False, scope='attention', reuse=False):
        """
        Self-attention layer: ``x + gamma * attention_output``.

        ``gamma`` is a learned scalar initialised to 0, so the layer starts as the
        identity. The final reshape uses ``x.shape``, so ``x`` needs a fully static
        shape; this holds here because the placeholders fix the batch size.

        Parameters
        -----------
        x - tf.Tensor
            Feature map, [bs, h, w, ch]
        ch - int
            Number of channels of ``x``
        sn - bool
            Whether to use spectral normalization in the 1x1 convolutions
        scope - str
            Variable scope name
        reuse - bool
            Reuse the variable scope

        Returns
        -----------
        tf.Tensor - same shape as ``x``
        """
        with tf.variable_scope(scope, reuse=reuse):
            f = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='f_conv')  # [bs, h, w, c']
            g = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='g_conv')  # [bs, h, w, c']
            h = conv(x, ch, kernel=1, stride=1, sn=sn, scope='h_conv')  # [bs, h, w, c]

            # N = h * w
            s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]
            beta = tf.nn.softmax(s, axis=-1)  # attention map

            o = tf.matmul(beta, hw_flatten(h))  # [bs, N, C]
            gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

            o = tf.reshape(o, shape=x.shape)  # [bs, h, w, C]
            x = gamma * o + x

        return x

    def gradient_penalty(self, real, fake):
        """
        Gradient penalty term for the 'wgan-gp', 'wgan-lp' and 'dragan' losses.

        For 'dragan' the penalty is taken at perturbed real samples. Otherwise it
        is taken at random interpolations between real and fake samples.

        Returns
        -----------
        tf.Tensor (or 0 for other gan types) - ``self.ld`` times the penalty
        """
        if self.gan_type == 'dragan':
            shape = tf.shape(real)
            eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region
            noise = 0.5 * x_std * eps  # delta in paper

            # Author suggested U[0,1] in original paper, but he admitted it is bug in github
            # (https://github.com/kodalinaveen3/DRAGAN). It should be two-sided.
            alpha = tf.random_uniform(shape=[shape[0], 1, 1, 1], minval=-1., maxval=1.)
            # x_hat should be in the space of X; the clip assumes images are scaled to [-1, 1]
            interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.)
        else:
            alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
            interpolated = alpha * real + (1. - alpha) * fake

        logit = self.discriminator(interpolated, reuse=True)
        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

        if self.gan_type == 'wgan-lp':
            # WGAN-LP: one-sided penalty, only gradients with norm > 1 are penalised
            return self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
        if self.gan_type in ('wgan-gp', 'dragan'):
            return self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))
        return 0

    def show_result(self, num_epoch, path='result.png'):
        """
        Full Credit: https://github.com/znxlwm/tensorflow-MNIST-GAN-DCGAN
        Code modified to have multiple channels.

        Generate one batch of samples from fresh noise and save a grid of them
        (5 x 5, or smaller if the batch is smaller) via ``SavePlot``.

        Parameters
        -----------
        num_epoch - int
            Epoch number used in the figure label and the saved file name
        path - str
            Unused; kept for backwards compatibility
        """
        fixed_z_ = np.random.normal(0, 1, (self.batch_size, 1, 1, self.z_shape))
        test_images = self.session.run(self.x_fake, {self.z: fixed_z_, self.training: False})

        # The grid used to be a fixed 5 x 5, which raised IndexError for batches
        # smaller than 25. Shrink it to fit; batches of 25 or more still get 5 x 5.
        size_figure_grid = max(1, min(5, int(np.sqrt(len(test_images)))))
        # squeeze=False keeps `ax` 2-D even for a 1 x 1 grid.
        fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(32, 32), squeeze=False)
        for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
            ax[i, j].get_xaxis().set_visible(False)
            ax[i, j].get_yaxis().set_visible(False)

        # Height and width are taken separately; the old code used image_shape[0] for both.
        image_size = (self.data.image_shape[0], self.data.image_shape[1], self.data.n_color_channels)
        for k in range(size_figure_grid * size_figure_grid):
            i, j = divmod(k, size_figure_grid)
            ax[i, j].cla()
            # Map tanh output in [-1, 1] to pixel values in [0, 255]
            ax[i, j].imshow(np.reshape(((test_images[k] + 1) * 127.5).astype(np.uint8), image_size))

        label = 'Epoch {0}'.format(num_epoch)
        fig.text(0.5, 0.04, label, ha='center')
        SavePlot(mantra_model=self, plt=plt, plt_name='faces_epoch_%s.png' % num_epoch)
        plt.close(fig)

    def create_loss_function(self):
        """
        This method creates the loss function selected by ``self.gan_type``,
        adding a gradient penalty for the WGAN variants and DRAGAN.

        Returns
        -----------
        void - updates instance with loss function variables self.d_loss and self.g_loss
        """
        real_logits = self.discriminator(self.x_real)
        fake_logits = self.discriminator(self.x_fake, reuse=True)

        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.x_real, fake=self.x_fake)
        else:
            GP = 0

        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    def define_logs(self):
        """
        Define terms to log here: scalar summaries of both losses

        Returns
        ----------
        void - registers summaries in the default graph
        """
        tf.summary.scalar("d_loss", self.d_loss)
        tf.summary.scalar("g_loss", self.g_loss)

    def create_optimizers(self):
        """
        This method creates the optimizers for the model - here we use ADAM optimizers.

        Returns
        -----------
        void - updates instance with optimizer variables self.d_optimizer and self.g_optimizer
        """
        # Differentiable Variables, split by the top-level variable scope
        theta = tf.trainable_variables()
        theta_d = [var for var in theta if var.name.startswith('discriminator')]
        theta_g = [var for var in theta if var.name.startswith('generator')]

        # Run pending UPDATE_OPS (e.g. batch-norm statistics) with each step
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optimizer = tf.train.AdamOptimizer(
                self.learning_rate_dis, beta1=self.beta_1, beta2=self.beta_2
            ).minimize(self.d_loss, var_list=theta_d)
            self.g_optimizer = tf.train.AdamOptimizer(
                self.learning_rate_gen, beta1=self.beta_1, beta2=self.beta_2
            ).minimize(self.g_loss, var_list=theta_g)

    def build_model(self):
        """
        This method constructs the model, including the loss function and optimization routine

        Raises
        -----------
        ValueError - if the image dimensions are not (128, 128)

        Returns
        -----------
        void - constructs model objects that are stored to the model instance
        """
        height, width = self.data.image_shape[0], self.data.image_shape[1]
        # Check both dimensions: the generator always emits square images, and a
        # 128 x N input previously got past this check only to fail on shapes later.
        if (height, width) != (128, 128):
            raise ValueError('Your entered image dimension is (%s, %s). Only (128, 128) is supported!' % (height, width))

        # Number of strided blocks; 4 for 128 px images
        self.layer_num = int(np.log2(height)) - 3

        # Input Variables
        self.x_real = tf.placeholder(
            tf.float32,
            shape=(self.batch_size, height, width, self.data.image_shape[2]),
            name='real_images')  # for use of real image input
        self.z = tf.placeholder(tf.float32, shape=(self.batch_size, 1, 1, self.z_shape), name='z')  # the random noise used to generate the images
        self.training = tf.placeholder(dtype=tf.bool)  # boolean for batchnorm denoting that training is activated
        self.x_fake = self.generator(self.z, self.training)  # fake input generated by the Generator Network

        # Discriminator and Generator Loss Functions
        self.create_loss_function()

        # Create Optimizers for Training
        self.create_optimizers()

    def gradient_update(self, iter):
        """
        Runs one discriminator update and, every ``self.n_critic`` iterations, one generator update

        Parameters
        ----------
        iter - int
            The iteration number within the epoch; also selects the data batch

        Returns
        ----------
        void - updates parameters
        """
        # Discriminator Update
        x = self.data.X[iter * self.batch_size:(iter + 1) * self.batch_size]
        z = np.random.normal(0, 1, (self.batch_size, 1, 1, self.z_shape))
        self.session.run([self.d_loss, self.d_optimizer], {self.x_real: x, self.z: z, self.training: True})

        # Generator Update. Previously n_critic was ignored and G was updated every
        # iteration. With the default n_critic=1 the behaviour is unchanged.
        if iter % self.n_critic != 0:
            return

        z = np.random.normal(0, 1, (self.batch_size, 1, 1, self.z_shape))
        feed = {self.z: z, self.x_real: x, self.training: True}
        if self.summary is None:
            # merge_all() returns None when define_logs() was skipped (run() only calls
            # it when self.trial is set). Fetching None would raise, so skip logging.
            self.session.run([self.g_loss, self.g_optimizer], feed)
        else:
            summary, _, _ = self.session.run([self.summary, self.g_loss, self.g_optimizer], feed)
            self.writer.add_summary(summary, iter)

    def end_of_epoch_update(self, epoch):
        """
        Update to apply at the end of the epoch: save a grid of generated samples
        """
        self.show_result(epoch + 1, path='')

    def end_of_training_update(self):
        """
        Update to apply at the end of training: close the TensorFlow session
        """
        self.session.close()

    def run(self):
        """
        Runs the training: build the graph, initialise it, then train for
        ``self.epochs`` epochs with per-epoch checkpointing and bookkeeping.
        """
        # Build and initialize
        self.build_model()
        if self.trial:
            self.define_logs()
        self.init_model()

        # Ready data; a trailing partial batch is dropped
        self.batches_per_epoch = len(self.data) // self.batch_size

        np.random.seed(int(time.time()))  # random seed for training

        for epoch in range(self.epochs):
            self.epoch_start_time = time.time()

            for batch_index in range(self.batches_per_epoch):
                self.gradient_update(batch_index)

            self.end_of_epoch_update(epoch)
            ModelCheckpoint(mantra_model=self, session=self.session)

            if self.task:
                EvaluateTask(mantra_model=self)

            StoreTrial(mantra_model=self, epoch=epoch)
            self.end_of_epoch_message(epoch=epoch, message=str(time.time() - self.epoch_start_time))

        self.end_of_training_update()
Code
sagan / __init__.py
1 lines | 49 bytes