CycleGAN
4

CycleGAN

This model implements the CycleGAN conditional GAN architecture introduced by Zhu et al. (arXiv:1703.10593) for unpaired image-to-image translation. The Horse to Zebra GIF is reproduced from the authors' project. The code is a port of the authors' implementation (the original link was not preserved in this copy).

Architectures have been implemented for the following image dimensions: (256, 256).

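This model treats the dataset attributes `self.data.X` and `self.data.Y` as the two image domains (A and B). Compatible datasets must preprocess their images and store them in these attributes, and must also provide the `image_shape` and `n_color_channels` attributes that the model reads. See the Horse to Zebra dataset for an example template.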

Training

To train with default parameters:

mantra train cyclegan --dataset my_images_data --image-dim 256 256

Importing

To import this model to your project, run:

mantra import RJT1990/models/cyclegan

Model
from __future__ import division import time import tensorflow as tf import itertools import tensorflow.contrib.slim as slim import matplotlib.pyplot as plt import numpy as np from .utils import * from mantraml.models import MantraModel from mantraml.models.tensorflow.summary import FileWriter from mantraml.models.tensorflow.callbacks import ModelCheckpoint, EvaluateTask, StoreTrial, SavePlot def abs_criterion(in_, target): return tf.reduce_mean(tf.abs(in_ - target)) def mae_criterion(in_, target): return tf.reduce_mean((in_-target)**2) def sce_criterion(logits, labels): return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)) def instance_norm(input, name="instance_norm"): with tf.variable_scope(name): depth = input.get_shape()[3] scale = tf.get_variable("scale", [depth], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)) offset = tf.get_variable("offset", [depth], initializer=tf.constant_initializer(0.0)) mean, variance = tf.nn.moments(input, axes=[1,2], keep_dims=True) epsilon = 1e-5 inv = tf.rsqrt(variance + epsilon) normalized = (input-mean)*inv return scale*normalized + offset def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name="conv2d"): with tf.variable_scope(name): return slim.conv2d(input_, output_dim, ks, s, padding=padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev), biases_initializer=None) def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"): with tf.variable_scope(name): return slim.conv2d_transpose(input_, output_dim, ks, s, padding='SAME', activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev), biases_initializer=None) def lrelu(x, leak=0.2, name="lrelu"): return tf.maximum(x, leak*x) def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): with tf.variable_scope(scope or "Linear"): matrix = tf.get_variable("Matrix", [input_.get_shape()[-1], 
output_size], tf.float32, tf.random_normal_initializer(stddev=stddev)) bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(bias_start)) if with_w: return tf.matmul(input_, matrix) + bias, matrix, bias else: return tf.matmul(input_, matrix) + bias class CycleGAN(MantraModel): """ Implements CycleGAN - https://arxiv.org/pdf/1703.10593.pdf """ model_name = "CycleGAN" model_image = "horse2zebra.gif" model_tags = ['cycle', 'gan'] model_arxiv_id = '1703.10593' def __init__(self, data=None, task=None, **kwargs): self.data = data self.task = task # Configure GPU config = tf.ConfigProto() config.gpu_options.allow_growth = kwargs.get('allow_growth', True) config.gpu_options.per_process_gpu_memory_fraction = kwargs.get('per_process_gpu_memory_fraction', 0.90) if tf.get_default_session() is None: self.sess = tf.InteractiveSession(config=config) else: self.sess = tf.get_default_session() self.l1_lambda = kwargs.get('l1_lambda', 10.0) self.epoch_step = kwargs.get('epoch_step', 100) self.lr = kwargs.get('lr', 0.0001) self.use_resnet = kwargs.get('use_resnet', True) self.use_lsgan = kwargs.get('use_lsgan', True) self.beta1 = kwargs.get('beta1', 0.5) self.max_size = kwargs.get('max_size', 50) if self.use_lsgan: self.criterionGAN = mae_criterion else: self.criterionGAN = sce_criterion def generator(self, image, training=True, reuse=False, name='generator'): """ This implements the Generator Architecture - using residual blocks Parameters ----------- image - tf.placeholder Containing the images training - bool For batch normalization; if we are training, this should be True; else should be False reuse - bool We reuse the variable scope if we call this method twice Returns ----------- tf.Tensor - a Tensor representing the generated (fake) image """ gf_dim = 64 with tf.variable_scope(name, reuse=reuse): def residule_block(x, dim, ks=3, s=1, name='res'): p = int((ks - 1) / 2) y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT") y = 
instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c1'), name+'_bn1') y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT") y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c2'), name+'_bn2') return y + x # Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/ # The network with 9 blocks consists of: c7s1-32, d64, d128, R128, R128, R128, # R128, R128, R128, R128, R128, R128, u64, u32, c7s1-3 c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") c1 = tf.nn.relu(instance_norm(conv2d(c0, gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn')) c2 = tf.nn.relu(instance_norm(conv2d(c1, gf_dim*2, 3, 2, name='g_e2_c'), 'g_e2_bn')) c3 = tf.nn.relu(instance_norm(conv2d(c2, gf_dim*4, 3, 2, name='g_e3_c'), 'g_e3_bn')) # define G network with 9 resnet blocks r1 = residule_block(c3, gf_dim*4, name='g_r1') r2 = residule_block(r1, gf_dim*4, name='g_r2') r3 = residule_block(r2, gf_dim*4, name='g_r3') r4 = residule_block(r3, gf_dim*4, name='g_r4') r5 = residule_block(r4, gf_dim*4, name='g_r5') r6 = residule_block(r5, gf_dim*4, name='g_r6') r7 = residule_block(r6, gf_dim*4, name='g_r7') r8 = residule_block(r7, gf_dim*4, name='g_r8') r9 = residule_block(r8, gf_dim*4, name='g_r9') d1 = deconv2d(r9, gf_dim*2, 3, 2, name='g_d1_dc') d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn')) d2 = deconv2d(d1, gf_dim, 3, 2, name='g_d2_dc') d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn')) d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") pred = tf.nn.tanh(conv2d(d2, self.data.n_color_channels, 7, 1, padding='VALID', name='g_pred_c')) return pred def discriminator(self, image, training=True, reuse=False, name='discriminator'): """ This implements the Discriminator Architecture Parameters ----------- x - tf.Tensor Containing the fake (or real) image data training - bool For batch normalization; if we are training, this should be True; else should be False reuse - bool We reuse the variable scope if we call this 
method twice Returns ----------- tf.Tensor, tf.Tensor - representing the probability the image is true/fake, and the logits (unnormalized probabilities) """ df_dim = 64 with tf.variable_scope(name, reuse=reuse): # image is 256 x 256 x input_c_dim h0 = lrelu(conv2d(image, df_dim, name='d_h0_conv')) # h0 is (128 x 128 x self.df_dim) h1 = lrelu(instance_norm(conv2d(h0, df_dim*2, name='d_h1_conv'), 'd_bn1')) # h1 is (64 x 64 x self.df_dim*2) h2 = lrelu(instance_norm(conv2d(h1, df_dim*4, name='d_h2_conv'), 'd_bn2')) # h2 is (32x 32 x self.df_dim*4) h3 = lrelu(instance_norm(conv2d(h2, df_dim*8, s=1, name='d_h3_conv'), 'd_bn3')) # h3 is (32 x 32 x self.df_dim*8) h4 = conv2d(h3, 1, s=1, name='d_h3_pred') # h4 is (32 x 32 x 1) return h4 def build_model(self): self.real_A = tf.placeholder(tf.float32, [None, self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels], name='real_A_images') self.real_B = tf.placeholder(tf.float32, [None, self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels], name='real_B_images') self.fake_B = self.generator(self.real_A, reuse=False, name="generatorA2B") self.fake_A_ = self.generator(self.fake_B, reuse=False, name="generatorB2A") self.fake_A = self.generator(self.real_B, reuse=True, name="generatorB2A") self.fake_B_ = self.generator(self.fake_A, reuse=True, name="generatorA2B") self.DB_fake = self.discriminator(self.fake_B, reuse=False, name="discriminatorB") self.DA_fake = self.discriminator(self.fake_A, reuse=False, name="discriminatorA") self.g_loss_a2b = self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \ + self.l1_lambda * abs_criterion(self.real_A, self.fake_A_) \ + self.l1_lambda * abs_criterion(self.real_B, self.fake_B_) self.g_loss_b2a = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \ + self.l1_lambda * abs_criterion(self.real_A, self.fake_A_) \ + self.l1_lambda * abs_criterion(self.real_B, self.fake_B_) self.g_loss = self.criterionGAN(self.DA_fake, 
tf.ones_like(self.DA_fake)) \ + self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \ + self.l1_lambda * abs_criterion(self.real_A, self.fake_A_) \ + self.l1_lambda * abs_criterion(self.real_B, self.fake_B_) self.fake_A_sample = tf.placeholder(tf.float32, [None, self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels], name='fake_A_sample') self.fake_B_sample = tf.placeholder(tf.float32, [None, self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels], name='fake_B_sample') self.DB_real = self.discriminator(self.real_B, reuse=True, name="discriminatorB") self.DA_real = self.discriminator(self.real_A, reuse=True, name="discriminatorA") self.DB_fake_sample = self.discriminator(self.fake_B_sample, reuse=True, name="discriminatorB") self.DA_fake_sample = self.discriminator(self.fake_A_sample, reuse=True, name="discriminatorA") self.db_loss_real = self.criterionGAN(self.DB_real, tf.ones_like(self.DB_real)) self.db_loss_fake = self.criterionGAN(self.DB_fake_sample, tf.zeros_like(self.DB_fake_sample)) self.db_loss = (self.db_loss_real + self.db_loss_fake) / 2 self.da_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real)) self.da_loss_fake = self.criterionGAN(self.DA_fake_sample, tf.zeros_like(self.DA_fake_sample)) self.da_loss = (self.da_loss_real + self.da_loss_fake) / 2 self.d_loss = self.da_loss + self.db_loss self.g_loss_a2b_sum = tf.summary.scalar("g_loss_a2b", self.g_loss_a2b) self.g_loss_b2a_sum = tf.summary.scalar("g_loss_b2a", self.g_loss_b2a) self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss) self.g_sum = tf.summary.merge([self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum]) self.db_loss_sum = tf.summary.scalar("db_loss", self.db_loss) self.da_loss_sum = tf.summary.scalar("da_loss", self.da_loss) self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) self.db_loss_real_sum = tf.summary.scalar("db_loss_real", self.db_loss_real) self.db_loss_fake_sum = 
tf.summary.scalar("db_loss_fake", self.db_loss_fake) self.da_loss_real_sum = tf.summary.scalar("da_loss_real", self.da_loss_real) self.da_loss_fake_sum = tf.summary.scalar("da_loss_fake", self.da_loss_fake) self.d_sum = tf.summary.merge( [self.da_loss_sum, self.da_loss_real_sum, self.da_loss_fake_sum, self.db_loss_sum, self.db_loss_real_sum, self.db_loss_fake_sum, self.d_loss_sum] ) self.test_A = tf.placeholder(tf.float32, [None, self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels], name='test_A') self.test_B = tf.placeholder(tf.float32, [None, self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels], name='test_B') self.testB = self.generator(self.test_A, reuse=True, name="generatorA2B") self.testA = self.generator(self.test_B, reuse=True, name="generatorB2A") t_vars = tf.trainable_variables() self.d_vars = [var for var in t_vars if 'discriminator' in var.name] self.g_vars = [var for var in t_vars if 'generator' in var.name] self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=self.beta1) \ .minimize(self.d_loss, var_list=self.d_vars) self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=self.beta1) \ .minimize(self.g_loss, var_list=self.g_vars) def init_model(self): """ This is a wrapper function for initiatilising the model, for example initialisation weights, or loading weights from a past checkpoint Returns ----------- void - updates the model with initialisation variables """ tf.global_variables_initializer().run() self.writer = FileWriter(mantra_model=self) self.summary = tf.summary.merge_all() self.writer.add_graph(self.sess.graph) def run(self): """ Runs the training. 
""" # Build and initialize self.build_model() self.init_model() # Ready data self.batches_per_epoch = min(self.data.X.shape[0], self.data.Y.shape[0]) // self.batch_size # Results Dict np.random.seed(int(time.time())) # random seed for training for epoch in range(self.epochs): self.epoch_start_time = time.time() for iter in range(self.batches_per_epoch): self.gradient_update(iter) self.end_of_epoch_update(epoch) ModelCheckpoint(mantra_model=self, session=self.sess) if self.task: EvaluateTask(mantra_model=self) StoreTrial(mantra_model=self, epoch=epoch) self.end_of_epoch_message(epoch=epoch, message=str(time.time() - self.epoch_start_time)) self.end_of_training_update() def gradient_update(self, iter): """ Updates the parameters with a single gradient update Parameters ---------- iter - int The iteration number Returns ---------- void - updates parameters """ #lr = self.lr if iter < self.epoch_step else self.lr*(self.epochs-iter)/(self.epochs-self.epoch_step) # Discriminator Update x = self.data.X[iter*self.batch_size:(iter+1)*self.batch_size] y = self.data.Y[iter*self.batch_size:(iter+1)*self.batch_size] # Update G network and record fake outputs fake_A, fake_B, _, summary_str = self.sess.run( [self.fake_A, self.fake_B, self.g_optim, self.g_sum], feed_dict={self.real_A: x, self.real_B: y}) self.writer.add_summary(summary_str, iter) # Update D network _, summary_str = self.sess.run( [self.d_optim, self.d_sum], feed_dict={self.real_A: x, self.real_B: y, self.fake_A_sample: fake_A, self.fake_B_sample: fake_B}) self.writer.add_summary(summary_str, iter) def end_of_epoch_update(self, epoch): fake_A, fake_B = self.sess.run( [self.fake_A, self.fake_B], feed_dict={self.real_A: self.data.X[:9], self.real_B: self.data.Y[:9]} ) size_figure_grid = 3 fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(32, 32)) for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)): ax[i, j].get_xaxis().set_visible(False) ax[i, 
j].get_yaxis().set_visible(False) for k in range(size_figure_grid*size_figure_grid): i = k // size_figure_grid j = k % size_figure_grid ax[i, j].cla() ax[i, j].imshow(np.reshape(((self.data.X[k]+1)*127.5).astype(np.uint8), (self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels))) label = 'Epoch {0}'.format(epoch) fig.text(0.5, 0.04, label, ha='center') SavePlot(mantra_model=self, plt=plt, plt_name='real_A_%s.png' % epoch) plt.close() size_figure_grid = 3 fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(32, 32)) for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)): ax[i, j].get_xaxis().set_visible(False) ax[i, j].get_yaxis().set_visible(False) for k in range(size_figure_grid*size_figure_grid): i = k // size_figure_grid j = k % size_figure_grid ax[i, j].cla() ax[i, j].imshow(np.reshape(((self.data.Y[k]+1)*127.5).astype(np.uint8), (self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels))) label = 'Epoch {0}'.format(epoch) fig.text(0.5, 0.04, label, ha='center') SavePlot(mantra_model=self, plt=plt, plt_name='real_B_%s.png' % epoch) plt.close() size_figure_grid = 3 fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(32, 32)) for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)): ax[i, j].get_xaxis().set_visible(False) ax[i, j].get_yaxis().set_visible(False) for k in range(size_figure_grid*size_figure_grid): i = k // size_figure_grid j = k % size_figure_grid ax[i, j].cla() ax[i, j].imshow(np.reshape(((fake_A[k]+1)*127.5).astype(np.uint8), (self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels))) label = 'Epoch {0}'.format(epoch) fig.text(0.5, 0.04, label, ha='center') SavePlot(mantra_model=self, plt=plt, plt_name='A_%s.png' % epoch) plt.close() size_figure_grid = 3 fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(32, 32)) for i, j in itertools.product(range(size_figure_grid), 
range(size_figure_grid)): ax[i, j].get_xaxis().set_visible(False) ax[i, j].get_yaxis().set_visible(False) for k in range(size_figure_grid*size_figure_grid): i = k // size_figure_grid j = k % size_figure_grid ax[i, j].cla() ax[i, j].imshow(np.reshape(((fake_B[k]+1)*127.5).astype(np.uint8), (self.data.image_shape[0], self.data.image_shape[0], self.data.n_color_channels))) label = 'Epoch {0}'.format(epoch) fig.text(0.5, 0.04, label, ha='center') SavePlot(mantra_model=self, plt=plt, plt_name='B_%s.png' % epoch) plt.close()
Code
cyclegan / notebook.ipynb
1 lines | 830 bytes