from __future__ import division
import time
import tensorflow as tf
import itertools
import tensorflow.contrib.slim as slim
import matplotlib.pyplot as plt
import numpy as np
from .utils import *
from mantraml.models import MantraModel
from mantraml.models.tensorflow.summary import FileWriter
from mantraml.models.tensorflow.callbacks import ModelCheckpoint, EvaluateTask, StoreTrial, SavePlot
def abs_criterion(in_, target):
    """Return the mean absolute (L1) difference between ``in_`` and ``target``, averaged over every element."""
    abs_diff = tf.abs(in_ - target)
    return tf.reduce_mean(abs_diff)
def mae_criterion(in_, target):
    """Return the mean *squared* difference between ``in_`` and ``target``.

    Despite its name this squares the error (it is the least-squares GAN
    criterion); the name is kept because callers refer to it.
    """
    squared_diff = (in_ - target) ** 2
    return tf.reduce_mean(squared_diff)
def sce_criterion(logits, labels):
    """Return sigmoid cross-entropy between ``logits`` and ``labels``, averaged over every element."""
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
    return tf.reduce_mean(per_element)
def instance_norm(input, name="instance_norm"):
    """Instance normalization with a learned per-channel scale and offset.

    Statistics are computed over axes 1 and 2 of ``input`` independently for
    every sample and channel; axis 3 is treated as the channel axis (its size
    sets the length of the learned parameters). Creates the variables
    ``scale`` and ``offset`` under the variable scope ``name``.
    """
    eps = 1e-5
    with tf.variable_scope(name):
        channels = input.get_shape()[3]
        gamma = tf.get_variable("scale", [channels],
                                initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        beta = tf.get_variable("offset", [channels],
                               initializer=tf.constant_initializer(0.0))
        mu, var = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
        inv_std = tf.rsqrt(var + eps)
        return gamma * ((input - mu) * inv_std) + beta
def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name="conv2d"):
    """Bias-free 2-D convolution with no activation, built on ``slim.conv2d``.

    ``ks`` is the kernel size and ``s`` the stride. Weights are initialised from
    a truncated normal with standard deviation ``stddev``, and the layer is
    created under the variable scope ``name``.
    """
    init = tf.truncated_normal_initializer(stddev=stddev)
    with tf.variable_scope(name):
        return slim.conv2d(input_, output_dim, ks, s,
                           padding=padding,
                           activation_fn=None,
                           weights_initializer=init,
                           biases_initializer=None)
def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"):
    """Bias-free transposed convolution ('SAME' padding, no activation) via ``slim.conv2d_transpose``.

    ``ks`` is the kernel size and ``s`` the stride. Weights are initialised from
    a truncated normal with standard deviation ``stddev``, and the layer is
    created under the variable scope ``name``.
    """
    init = tf.truncated_normal_initializer(stddev=stddev)
    with tf.variable_scope(name):
        return slim.conv2d_transpose(input_, output_dim, ks, s,
                                     padding='SAME',
                                     activation_fn=None,
                                     weights_initializer=init,
                                     biases_initializer=None)
def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU computed as ``max(x, leak * x)``, which is correct for 0 <= leak <= 1.

    ``name`` is accepted for call-site compatibility but is not used.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """Fully connected layer computing ``matmul(input_, Matrix) + bias``.

    Creates ``Matrix`` (last dim of ``input_`` x ``output_size``, random-normal
    init with ``stddev``) and ``bias`` (constant ``bias_start``) under the
    variable scope ``scope`` (``"Linear"`` when not given).

    Returns the output tensor, or ``(output, Matrix, bias)`` when ``with_w`` is true.
    """
    scope_name = scope or "Linear"
    in_features = input_.get_shape()[-1]
    with tf.variable_scope(scope_name):
        weights = tf.get_variable("Matrix", [in_features, output_size], tf.float32,
                                  tf.random_normal_initializer(stddev=stddev))
        biases = tf.get_variable("bias", [output_size],
                                 initializer=tf.constant_initializer(bias_start))
        output = tf.matmul(input_, weights) + biases
    if not with_w:
        return output
    return output, weights, biases
class CycleGAN(MantraModel):
    """
    Implements CycleGAN - https://arxiv.org/pdf/1703.10593.pdf

    Two generators (A->B and B->A) are trained adversarially against two
    discriminators, with an L1 cycle-consistency penalty tying them together.
    """

    model_name = "CycleGAN"
    model_image = "horse2zebra.gif"
    model_tags = ['cycle', 'gan']
    model_arxiv_id = '1703.10593'

    def __init__(self, data=None, task=None, **kwargs):
        """
        Parameters
        -----------
        data - dataset object
            Read by this class via its X, Y (the two image domains), image_shape
            and n_color_channels attributes
        task - optional
            When truthy, EvaluateTask is run at the end of every epoch
        **kwargs
            Hyperparameters (l1_lambda, epoch_step, lr, use_resnet, use_lsgan,
            beta1, max_size) and GPU options (allow_growth,
            per_process_gpu_memory_fraction)
        """
        self.data = data
        self.task = task

        # Configure GPU
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = kwargs.get('allow_growth', True)
        config.gpu_options.per_process_gpu_memory_fraction = kwargs.get('per_process_gpu_memory_fraction', 0.90)

        # Reuse an existing default session if there is one; otherwise create an
        # InteractiveSession (which installs itself as the default session).
        if tf.get_default_session() is None:
            self.sess = tf.InteractiveSession(config=config)
        else:
            self.sess = tf.get_default_session()

        self.l1_lambda = kwargs.get('l1_lambda', 10.0)  # weight of the cycle-consistency (L1) terms
        self.epoch_step = kwargs.get('epoch_step', 100)  # NOTE(review): currently unused (no LR decay implemented)
        self.lr = kwargs.get('lr', 0.0001)
        self.use_resnet = kwargs.get('use_resnet', True)  # NOTE(review): currently unused; generator is always residual
        self.use_lsgan = kwargs.get('use_lsgan', True)
        self.beta1 = kwargs.get('beta1', 0.5)
        self.max_size = kwargs.get('max_size', 50)  # NOTE(review): currently unused; no fake-image pool is implemented

        # LSGAN -> least-squares adversarial loss; otherwise sigmoid cross-entropy on logits.
        if self.use_lsgan:
            self.criterionGAN = mae_criterion
        else:
            self.criterionGAN = sce_criterion

    def _image_hw(self):
        """
        Returns (height, width) of the dataset images.

        Assumes data.image_shape starts with (height, width, ...); a length-1 shape
        is treated as square.
        """
        shape = self.data.image_shape
        height = shape[0]
        width = shape[1] if len(shape) > 1 else shape[0]
        return height, width

    def _image_placeholder(self, name):
        """Creates a float32 (batch, height, width, channels) placeholder sized to the dataset images."""
        height, width = self._image_hw()
        return tf.placeholder(tf.float32,
                              [None, height, width, self.data.n_color_channels],
                              name=name)

    def generator(self, image, training=True, reuse=False, name='generator'):
        """
        This implements the Generator Architecture - using residual blocks

        Parameters
        -----------
        image - tf.Tensor
            Batch of input images (a placeholder or another network's output)
        training - bool
            Accepted for API compatibility; not used (instance norm is used instead of batch norm)
        reuse - bool
            We reuse the variable scope if we call this method twice
        name - str
            Variable scope holding this generator's weights

        Returns
        -----------
        tf.Tensor - the generated (fake) image, tanh-activated so values lie in [-1, 1]
        """
        gf_dim = 64
        with tf.variable_scope(name, reuse=reuse):

            def residual_block(x, dim, ks=3, s=1, name='res'):
                # Reflection padding keeps the spatial size unchanged through the 'VALID' convs
                p = int((ks - 1) / 2)
                y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
                y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name + '_c1'), name + '_bn1')
                y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
                y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name + '_c2'), name + '_bn2')
                return y + x

            # Based on Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
            # With gf_dim=64 the network is: c7s1-64, d128, d256, R256 x 9, u128, u64, c7s1-C
            # (C = number of colour channels).
            c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
            c1 = tf.nn.relu(instance_norm(conv2d(c0, gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn'))
            c2 = tf.nn.relu(instance_norm(conv2d(c1, gf_dim * 2, 3, 2, name='g_e2_c'), 'g_e2_bn'))
            c3 = tf.nn.relu(instance_norm(conv2d(c2, gf_dim * 4, 3, 2, name='g_e3_c'), 'g_e3_bn'))

            # 9 residual blocks, scopes g_r1 .. g_r9
            r = c3
            for block_idx in range(1, 10):
                r = residual_block(r, gf_dim * 4, name='g_r%d' % block_idx)

            d1 = deconv2d(r, gf_dim * 2, 3, 2, name='g_d1_dc')
            d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn'))
            d2 = deconv2d(d1, gf_dim, 3, 2, name='g_d2_dc')
            d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn'))
            d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
            pred = tf.nn.tanh(conv2d(d2, self.data.n_color_channels, 7, 1, padding='VALID', name='g_pred_c'))
            return pred

    def discriminator(self, image, training=True, reuse=False, name='discriminator'):
        """
        This implements the Discriminator Architecture

        Parameters
        -----------
        image - tf.Tensor
            Containing the fake (or real) image data
        training - bool
            Accepted for API compatibility; not used (instance norm is used instead of batch norm)
        reuse - bool
            We reuse the variable scope if we call this method twice
        name - str
            Variable scope holding this discriminator's weights

        Returns
        -----------
        tf.Tensor - a single-channel map of unnormalized scores (logits); no activation is applied
        """
        df_dim = 64
        with tf.variable_scope(name, reuse=reuse):
            # Shape comments below assume a 256 x 256 input
            h0 = lrelu(conv2d(image, df_dim, name='d_h0_conv'))
            # h0 is (128 x 128 x df_dim)
            h1 = lrelu(instance_norm(conv2d(h0, df_dim * 2, name='d_h1_conv'), 'd_bn1'))
            # h1 is (64 x 64 x df_dim*2)
            h2 = lrelu(instance_norm(conv2d(h1, df_dim * 4, name='d_h2_conv'), 'd_bn2'))
            # h2 is (32 x 32 x df_dim*4)
            h3 = lrelu(instance_norm(conv2d(h2, df_dim * 8, s=1, name='d_h3_conv'), 'd_bn3'))
            # h3 is (32 x 32 x df_dim*8)
            h4 = conv2d(h3, 1, s=1, name='d_h3_pred')
            # h4 is (32 x 32 x 1)
            return h4

    def build_model(self):
        """
        Builds the graph: input placeholders, both generators and discriminators,
        the losses, TensorBoard summaries, test-time generators and the two Adam optimizers.
        """
        self.real_A = self._image_placeholder('real_A_images')
        self.real_B = self._image_placeholder('real_B_images')

        # Both cycles A -> B -> A and B -> A -> B; the second call of each generator reuses its variables
        self.fake_B = self.generator(self.real_A, reuse=False, name="generatorA2B")
        self.fake_A_ = self.generator(self.fake_B, reuse=False, name="generatorB2A")
        self.fake_A = self.generator(self.real_B, reuse=True, name="generatorB2A")
        self.fake_B_ = self.generator(self.fake_A, reuse=True, name="generatorA2B")

        self.DB_fake = self.discriminator(self.fake_B, reuse=False, name="discriminatorB")
        self.DA_fake = self.discriminator(self.fake_A, reuse=False, name="discriminatorA")

        # Generator losses: fool the discriminators + cycle-consistency (shared by all three losses)
        cycle_loss = self.l1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.l1_lambda * abs_criterion(self.real_B, self.fake_B_)
        adv_loss_b = self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake))
        adv_loss_a = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake))
        self.g_loss_a2b = adv_loss_b + cycle_loss
        self.g_loss_b2a = adv_loss_a + cycle_loss
        self.g_loss = adv_loss_a + adv_loss_b + cycle_loss

        # Discriminators are trained on fakes fed in from numpy (see gradient_update)
        self.fake_A_sample = self._image_placeholder('fake_A_sample')
        self.fake_B_sample = self._image_placeholder('fake_B_sample')

        self.DB_real = self.discriminator(self.real_B, reuse=True, name="discriminatorB")
        self.DA_real = self.discriminator(self.real_A, reuse=True, name="discriminatorA")
        self.DB_fake_sample = self.discriminator(self.fake_B_sample, reuse=True, name="discriminatorB")
        self.DA_fake_sample = self.discriminator(self.fake_A_sample, reuse=True, name="discriminatorA")

        self.db_loss_real = self.criterionGAN(self.DB_real, tf.ones_like(self.DB_real))
        self.db_loss_fake = self.criterionGAN(self.DB_fake_sample, tf.zeros_like(self.DB_fake_sample))
        self.db_loss = (self.db_loss_real + self.db_loss_fake) / 2
        self.da_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real))
        self.da_loss_fake = self.criterionGAN(self.DA_fake_sample, tf.zeros_like(self.DA_fake_sample))
        self.da_loss = (self.da_loss_real + self.da_loss_fake) / 2
        self.d_loss = self.da_loss + self.db_loss

        self.g_loss_a2b_sum = tf.summary.scalar("g_loss_a2b", self.g_loss_a2b)
        self.g_loss_b2a_sum = tf.summary.scalar("g_loss_b2a", self.g_loss_b2a)
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.g_sum = tf.summary.merge([self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum])

        self.db_loss_sum = tf.summary.scalar("db_loss", self.db_loss)
        self.da_loss_sum = tf.summary.scalar("da_loss", self.da_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.db_loss_real_sum = tf.summary.scalar("db_loss_real", self.db_loss_real)
        self.db_loss_fake_sum = tf.summary.scalar("db_loss_fake", self.db_loss_fake)
        self.da_loss_real_sum = tf.summary.scalar("da_loss_real", self.da_loss_real)
        self.da_loss_fake_sum = tf.summary.scalar("da_loss_fake", self.da_loss_fake)
        self.d_sum = tf.summary.merge(
            [self.da_loss_sum, self.da_loss_real_sum, self.da_loss_fake_sum,
             self.db_loss_sum, self.db_loss_real_sum, self.db_loss_fake_sum,
             self.d_loss_sum]
        )

        # Test-time translation in both directions, sharing the trained weights
        self.test_A = self._image_placeholder('test_A')
        self.test_B = self._image_placeholder('test_B')
        self.testB = self.generator(self.test_A, reuse=True, name="generatorA2B")
        self.testA = self.generator(self.test_B, reuse=True, name="generatorB2A")

        # Split trainable variables by scope name so each optimizer only updates its own network
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=self.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=self.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)

    def init_model(self):
        """
        This is a wrapper function for initialising the model: it initialises all
        global variables in this model's session and sets up the summary writer.

        Returns
        -----------
        void - updates the model with initialisation variables
        """
        self.sess.run(tf.global_variables_initializer())
        self.writer = FileWriter(mantra_model=self)
        self.summary = tf.summary.merge_all()
        self.writer.add_graph(self.sess.graph)

    def run(self):
        """
        Runs the training.

        NOTE(review): self.batch_size and self.epochs are read here but not set in
        __init__; presumably they are supplied by the mantraml framework - confirm.
        """
        # Build and initialize
        self.build_model()
        self.init_model()

        # One epoch walks over the shorter of the two (unpaired) domains
        self.batches_per_epoch = min(self.data.X.shape[0], self.data.Y.shape[0]) // self.batch_size
        # Monotonic step used for TensorBoard summaries (see gradient_update)
        self._summary_step = 0

        np.random.seed(int(time.time()))  # random seed for training

        for epoch in range(self.epochs):
            self.epoch_start_time = time.time()
            for batch_idx in range(self.batches_per_epoch):
                self.gradient_update(batch_idx)
            self.end_of_epoch_update(epoch)
            ModelCheckpoint(mantra_model=self, session=self.sess)
            if self.task:
                EvaluateTask(mantra_model=self)
            StoreTrial(mantra_model=self, epoch=epoch)
            self.end_of_epoch_message(epoch=epoch, message=str(time.time() - self.epoch_start_time))

        self.end_of_training_update()

    def gradient_update(self, iter):
        """
        Performs one generator update followed by one discriminator update on batch `iter`.

        Parameters
        ----------
        iter - int
            Index of the batch within the current epoch

        Returns
        ----------
        void - updates parameters and writes summaries
        """
        # NOTE(review): epoch_step was intended for a linear learning-rate decay, which is
        # not implemented; both optimizers use the constant self.lr.
        start = iter * self.batch_size
        stop = start + self.batch_size
        x = self.data.X[start:stop]
        y = self.data.Y[start:stop]

        # `iter` restarts at 0 each epoch, so it cannot be used as the summary step without
        # later epochs overwriting earlier points; use a counter that keeps increasing.
        step = getattr(self, '_summary_step', 0)

        # Update G network and record fake outputs
        fake_A, fake_B, _, g_summary = self.sess.run(
            [self.fake_A, self.fake_B, self.g_optim, self.g_sum],
            feed_dict={self.real_A: x, self.real_B: y})
        self.writer.add_summary(g_summary, step)

        # Update D network on real images and the fakes produced in the generator step
        _, d_summary = self.sess.run(
            [self.d_optim, self.d_sum],
            feed_dict={self.real_A: x,
                       self.real_B: y,
                       self.fake_A_sample: fake_A,
                       self.fake_B_sample: fake_B})
        self.writer.add_summary(d_summary, step)

        self._summary_step = step + 1

    def _save_image_grid(self, images, plt_name, epoch, grid_size=3):
        """
        Plots up to grid_size**2 images on a square grid and saves it via SavePlot.

        Pixel values are mapped from [-1, 1] to [0, 255] (clipped) before display.
        Cells without an image are left blank.
        """
        height, width = self._image_hw()
        channels = self.data.n_color_channels
        fig, ax = plt.subplots(grid_size, grid_size, figsize=(32, 32), squeeze=False)
        cells = itertools.product(range(grid_size), range(grid_size))
        for k, (i, j) in enumerate(cells):
            axis = ax[i, j]
            axis.get_xaxis().set_visible(False)
            axis.get_yaxis().set_visible(False)
            if k >= len(images):
                continue
            pixels = np.clip((images[k] + 1) * 127.5, 0, 255).astype(np.uint8)
            if channels == 1:
                # imshow rejects (H, W, 1); show a 2-D array with a fixed 0-255 grey scale
                axis.imshow(np.reshape(pixels, (height, width)), cmap='gray', vmin=0, vmax=255)
            else:
                axis.imshow(np.reshape(pixels, (height, width, channels)))
        fig.text(0.5, 0.04, 'Epoch {0}'.format(epoch), ha='center')
        SavePlot(mantra_model=self, plt=plt, plt_name=plt_name)
        plt.close(fig)

    def end_of_epoch_update(self, epoch):
        """
        Saves 3x3 image grids of real A, real B, fake A (B2A applied to real B)
        and fake B (A2B applied to real A) for the given epoch.
        """
        grid_size = 3
        # Use the same number of samples from both domains so every fake has a slot
        n_samples = min(grid_size * grid_size, len(self.data.X), len(self.data.Y))
        real_A = self.data.X[:n_samples]
        real_B = self.data.Y[:n_samples]

        fake_A, fake_B = self.sess.run(
            [self.fake_A, self.fake_B],
            feed_dict={self.real_A: real_A, self.real_B: real_B}
        )

        self._save_image_grid(real_A, 'real_A_%s.png' % epoch, epoch, grid_size)
        self._save_image_grid(real_B, 'real_B_%s.png' % epoch, epoch, grid_size)
        self._save_image_grid(fake_A, 'A_%s.png' % epoch, epoch, grid_size)
        self._save_image_grid(fake_B, 'B_%s.png' % epoch, epoch, grid_size)