import matplotlib.pyplot as plt
import numpy as np
import time
import itertools
import tensorflow as tf
from mantraml.models import MantraModel
from mantraml.models.tensorflow.summary import FileWriter
from mantraml.models.tensorflow.callbacks import ModelCheckpoint, EvaluateTask, StoreTrial, SavePlot
from .ops import *
class SAGAN(MantraModel):
"""
This class implements the SAGAN paper https://arxiv.org/pdf/1805.08318.pdf
"""
model_name = 'Self-Attention GAN'
model_image = 'imagenet.png'
model_tags = ['self-attention', 'gan']
model_arxiv_id = '1805.08318'
def __init__(self, data=None, task=None, **kwargs):
self.data = data
self.task = task
# Configure GPU
config = tf.ConfigProto()
config.gpu_options.allow_growth = kwargs.get('allow_growth', True)
config.gpu_options.per_process_gpu_memory_fraction = kwargs.get('per_process_gpu_memory_fraction', 0.90)
if tf.get_default_session() is None:
self.session = tf.InteractiveSession(config=config)
else:
self.session = tf.get_default_session()
# Architecture Information
        self.z_shape = kwargs.get('z_shape', 128)  # dimensionality of the latent noise vector z
        self.n_strides = kwargs.get('n_strides', 2)
        self.learning_rate_gen = kwargs.get('learning_rate', 1e-4)  # generator learning rate (reads the 'learning_rate' kwarg)
        self.learning_rate_dis = kwargs.get('learning_rate', 1e-4)  # discriminator learning rate (reads the same kwarg)
        self.beta_1 = kwargs.get('beta_1', 0.0)
        self.beta_2 = kwargs.get('beta_2', 0.9)
        self.up_sample = kwargs.get('up_sample', True)  # whether to use upsample + conv instead of deconv
        self.sn = kwargs.get('sn', True)  # whether to use spectral normalization
        self.gan_type = kwargs.get('gan_type', 'dragan')  # the GAN loss; one of [gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]
        self.ld = kwargs.get('ld', 10.0)  # the gradient penalty lambda
        self.n_critic = kwargs.get('n_critic', 1)  # number of discriminator (critic) updates per generator update
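        # Illustrative only (not part of the original source): all of the hyperparameters above are
        # overridable through keyword arguments, so a hypothetical instantiation might look like
        #     model = SAGAN(data=my_dataset, task=my_task, gan_type='hinge', sn=True, n_critic=1)
        # The exact way mantraml constructs and feeds the model (batch_size, epochs, trial, etc.
        # come from the MantraModel base class) is assumed here rather than defined in this file.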
def init_model(self):
"""
        This is a wrapper function for initialising the model, for example initialising weights or loading
        weights from a past checkpoint
Returns
-----------
void - updates the model with initialisation variables
"""
tf.global_variables_initializer().run()
self.writer = FileWriter(mantra_model=self)
self.summary = tf.summary.merge_all()
self.writer.add_graph(self.session.graph)
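        # Note: tf.summary.merge_all() collects whatever scalar summaries define_logs() registered
        # (it returns None if none were registered), and add_graph writes the graph for inspection
        # in TensorBoard through mantraml's FileWriter (presumably a wrapper around tf.summary.FileWriter).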
def generator(self, z, training=True, reuse=False):
"""
        This implements the Generator architecture. This method can be overridden if desired; the default is a
        DCGAN-inspired architecture
Parameters
-----------
z - tf.placeholder
Containing the random noise with which to generate the sample
training - bool
            For batch normalization; set to True when training and False otherwise
reuse - bool
We reuse the variable scope if we call this method twice
Returns
-----------
tf.Tensor - a Tensor representing the generated (fake) image
"""
with tf.variable_scope("generator", reuse=reuse):
ch = 1024
x = deconv(z, channels=ch, kernel=4, stride=1, padding='VALID', use_bias=False, sn=self.sn, scope='deconv')
x = batch_norm(x, training, scope='batch_norm')
x = relu(x)
for i in range(self.layer_num // 2):
if self.up_sample:
x = up_sample(x, scale_factor=2)
x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))
x = batch_norm(x, training, scope='batch_norm_' + str(i))
x = relu(x)
else:
x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))
x = batch_norm(x, training, scope='batch_norm_' + str(i))
x = relu(x)
ch = ch // 2
# Self Attention
x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)
for i in range(self.layer_num // 2, self.layer_num):
if self.up_sample:
x = up_sample(x, scale_factor=2)
x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))
x = batch_norm(x, training, scope='batch_norm_' + str(i))
x = relu(x)
else:
x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))
x = batch_norm(x, training, scope='batch_norm_' + str(i))
x = relu(x)
ch = ch // 2
if self.up_sample:
x = up_sample(x, scale_factor=2)
x = conv(x, channels=self.data.image_shape[2], kernel=3, stride=1, pad=1, sn=self.sn, scope='G_conv_logit')
x = tanh(x)
else:
x = deconv(x, channels=self.data.image_shape[2], kernel=4, stride=2, use_bias=False, sn=self.sn, scope='G_deconv_logit')
x = tanh(x)
return x
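    # Generator shape trace for reference, assuming 128x128 images (layer_num = 4) and that
    # ops.deconv / ops.conv use 'same'-style padding so a stride-2 layer doubles the spatial size:
    #   z [bs, 1, 1, 128] -> deconv 4x4 (VALID) -> [bs, 4, 4, 1024]
    #   2 up-sampling blocks -> [bs, 16, 16, 256] -> self-attention
    #   2 more up-sampling blocks -> [bs, 64, 64, 64]
    #   final up-sample + conv (or deconv) + tanh -> [bs, 128, 128, image channels]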
def discriminator(self, x, training=True, reuse=False):
"""
        This implements the Discriminator architecture. This method can be overridden if desired; the default is a DCGAN-inspired architecture
Parameters
-----------
x - tf.Tensor
Containing the fake (or real) image data
training - bool
            For batch normalization; set to True when training and False otherwise
reuse - bool
We reuse the variable scope if we call this method twice
        Returns
        -----------
        tf.Tensor - the discriminator logits (unnormalized scores) for the input batch
        """
with tf.variable_scope("discriminator", reuse=reuse):
ch = 64
x = conv(x, channels=ch, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv')
x = lrelu(x, 0.2)
for i in range(self.layer_num // 2):
x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))
x = batch_norm(x, training, scope='batch_norm' + str(i))
x = lrelu(x, 0.2)
ch = ch * 2
# Self Attention
x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)
for i in range(self.layer_num // 2, self.layer_num):
x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))
x = batch_norm(x, training, scope='batch_norm' + str(i))
x = lrelu(x, 0.2)
ch = ch * 2
x = conv(x, channels=4, stride=1, sn=self.sn, use_bias=False, scope='D_logit')
return x
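    # Discriminator shape trace under the same assumptions as the generator trace above:
    #   x [bs, 128, 128, C] -> conv stride 2 -> [bs, 64, 64, 64]
    #   2 stride-2 blocks -> [bs, 16, 16, 256] -> self-attention
    #   2 more stride-2 blocks -> [bs, 4, 4, 1024]
    #   final conv (channels=4, stride 1) -> [bs, 4, 4, 4] logits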
def attention(self, x, ch, sn=False, scope='attention', reuse=False):
with tf.variable_scope(scope, reuse=reuse):
f = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='f_conv') # [bs, h, w, c']
g = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='g_conv') # [bs, h, w, c']
h = conv(x, ch, kernel=1, stride=1, sn=sn, scope='h_conv') # [bs, h, w, c]
# N = h * w
            s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]
beta = tf.nn.softmax(s, axis=-1) # attention map
o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C]
gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
o = tf.reshape(o, shape=x.shape) # [bs, h, w, C]
x = gamma * o + x
return x
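    # Self-attention math from the SAGAN paper, written out for reference:
    #   s_ij    = f(x_i)^T g(x_j)           f, g: 1x1 convs to C/8 channels
    #   beta_ji = softmax_i(s_ij)           attention over all N = h*w locations
    #   o_j     = sum_i beta_ji * h(x_i)    h: 1x1 conv back to C channels
    #   y       = gamma * o + x             gamma is learned and initialised to 0, so the block
    #                                       starts out as an identity mapping
    # (Later revisions of the paper apply one more 1x1 conv v(.) to o; this implementation goes
    # straight from o to gamma * o + x.)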
def gradient_penalty(self, real, fake):
        if self.gan_type == 'dragan':
shape = tf.shape(real)
eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)
x_mean, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region
noise = 0.5 * x_std * eps # delta in paper
            # The author suggested U[0, 1] in the original paper, but later acknowledged on GitHub
            # (https://github.com/kodalinaveen3/DRAGAN) that this was a bug: the perturbation should be two-sided.
alpha = tf.random_uniform(shape=[shape[0], 1, 1, 1], minval=-1., maxval=1.)
interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.) # x_hat should be in the space of X
        else:
alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
interpolated = alpha*real + (1. - alpha)*fake
logit = self.discriminator(interpolated, reuse=True)
grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated)
grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
GP = 0
# WGAN - LP
if self.gan_type == 'wgan-lp':
GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))
return GP
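    # Gradient penalty terms computed above, for reference (x_hat is the interpolated or
    # perturbed sample and ld is the lambda coefficient):
    #   wgan-gp, dragan:  GP = ld * E[ (||grad_xhat D(x_hat)||_2 - 1)^2 ]
    #   wgan-lp:          GP = ld * E[ max(0, ||grad_xhat D(x_hat)||_2 - 1)^2 ]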
    def show_result(self, num_epoch, path='result.png'):
"""
Full Credit: https://github.com/znxlwm/tensorflow-MNIST-GAN-DCGAN
Code modified to have multiple channels.
"""
fixed_z_ = np.random.normal(0, 1, (self.batch_size, 1, 1, self.z_shape))
test_images = self.session.run(self.x_fake, {self.z: fixed_z_, self.training: False})
size_figure_grid = 5
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(32, 32))
for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
ax[i, j].get_xaxis().set_visible(False)
ax[i, j].get_yaxis().set_visible(False)
for k in range(size_figure_grid*size_figure_grid):
i = k // size_figure_grid
j = k % size_figure_grid
ax[i, j].cla()
            ax[i, j].imshow(np.reshape(((test_images[k] + 1) * 127.5).astype(np.uint8), (self.data.image_shape[0], self.data.image_shape[1], self.data.n_color_channels)))
label = 'Epoch {0}'.format(num_epoch)
fig.text(0.5, 0.04, label, ha='center')
SavePlot(mantra_model=self, plt=plt, plt_name='faces_epoch_%s.png' % num_epoch)
plt.close()
def create_loss_function(self):
"""
This method creates the loss function for the model - here we use a RaGAN
Returns
-----------
void - updates instance with loss function variables self.d_loss and self.g_loss
"""
real_logits = self.discriminator(self.x_real)
fake_logits = self.discriminator(self.x_fake, reuse=True)
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.x_real, fake=self.x_fake)
        else:
            GP = 0
self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP
self.g_loss = generator_loss(self.gan_type, fake=fake_logits)
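    # The concrete loss terms live in ops.discriminator_loss / ops.generator_loss and are
    # dispatched on self.gan_type. As a sketch for reference, the 'hinge' variant used in the
    # SAGAN paper is conventionally:
    #   L_D = E[max(0, 1 - D(x_real))] + E[max(0, 1 + D(G(z)))]
    #   L_G = -E[D(G(z))]
    # (For the exact implementations of all six options, see ops.py.)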
def define_logs(self):
"""
Define terms to log here
Returns
----------
void - updates parameters
"""
tf.summary.scalar("d_loss", self.d_loss)
tf.summary.scalar("g_loss", self.g_loss)
def create_optimizers(self):
"""
        This method creates the optimizers for the model - here we use Adam optimizers.
Returns
-----------
void - updates instance with optimizer variables self.d_optimizer and self.g_optimizer
"""
# Differentiable Variables
theta = tf.trainable_variables()
theta_d = [var for var in theta if var.name.startswith('discriminator')]
theta_g = [var for var in theta if var.name.startswith('generator')]
# optimizer for each network
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
self.d_optimizer = tf.train.AdamOptimizer(self.learning_rate_dis, beta1=self.beta_1, beta2=self.beta_2).minimize(self.d_loss, var_list=theta_d)
self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate_gen, beta1=self.beta_1, beta2=self.beta_2).minimize(self.g_loss, var_list=theta_g)
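        # Note: the default beta_1 = 0.0, beta_2 = 0.9 match the Adam settings reported in the
        # SAGAN paper, and wrapping minimize() in the UPDATE_OPS control dependency ensures the
        # batch-norm moving statistics are refreshed as part of every optimizer step.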
def build_model(self):
"""
This method constructs the model, including the loss function and optimization routine
Returns
-----------
void - constructs model objects that are stored to the model instance
"""
        if self.data.image_shape[0] != 128:
            raise ValueError('Your input image dimensions are (%s, %s); only (128, 128) is supported.' % (self.data.image_shape[0], self.data.image_shape[1]))
self.layer_num = int(np.log2(self.data.image_shape[0])) - 3
# Input Variables
        self.x_real = tf.placeholder(tf.float32, shape=(self.batch_size, self.data.image_shape[0], self.data.image_shape[1], self.data.image_shape[2]), name='real_images')  # placeholder for the batch of real images
        self.z = tf.placeholder(tf.float32, shape=(self.batch_size, 1, 1, self.z_shape), name='z')  # the random noise used to generate the images
        self.training = tf.placeholder(dtype=tf.bool)  # flag for batch normalization indicating whether we are training
self.x_fake = self.generator(self.z, self.training) # fake input generated by the Generator Network
# Discriminator and Generator Loss Functions
self.create_loss_function()
        # Create Optimizers for Training
self.create_optimizers()
def gradient_update(self, iter):
"""
Updates the parameters with a single gradient update
Parameters
----------
iter - int
The iteration number
Returns
----------
void - updates parameters
"""
        # Discriminator Update
        x = self.data.X[iter*self.batch_size:(iter+1)*self.batch_size]
        z = np.random.normal(0, 1, (self.batch_size, 1, 1, self.z_shape))
        d_loss_value, _ = self.session.run([self.d_loss, self.d_optimizer], {self.x_real: x, self.z: z, self.training: True})
        # Generator Update
        z = np.random.normal(0, 1, (self.batch_size, 1, 1, self.z_shape))
        if self.summary is not None:
            summary, g_loss_value, _ = self.session.run([self.summary, self.g_loss, self.g_optimizer], {self.z: z, self.x_real: x, self.training: True})
            self.writer.add_summary(summary, iter)
        else:  # no summaries were registered, e.g. when running outside a trial
            g_loss_value, _ = self.session.run([self.g_loss, self.g_optimizer], {self.z: z, self.x_real: x, self.training: True})
def end_of_epoch_update(self, epoch):
"""
Update to apply at the end of the epoch
"""
        self.show_result(epoch + 1)
def end_of_training_update(self):
"""
Update to apply at the end of training
"""
self.session.close()
def run(self):
"""
Runs the training.
"""
# Build and initialize
self.build_model()
if self.trial:
self.define_logs()
self.init_model()
        # Prepare batching
        self.batches_per_epoch = len(self.data) // self.batch_size
        # Seed the random number generator for training
        np.random.seed(int(time.time()))
for epoch in range(self.epochs):
self.epoch_start_time = time.time()
for iter in range(self.batches_per_epoch):
self.gradient_update(iter)
self.end_of_epoch_update(epoch)
ModelCheckpoint(mantra_model=self, session=self.session)
if self.task:
EvaluateTask(mantra_model=self)
StoreTrial(mantra_model=self, epoch=epoch)
            self.end_of_epoch_message(epoch=epoch, message='%.2fs' % (time.time() - self.epoch_start_time))
self.end_of_training_update()