Source code for Tars.models.gan

import numpy as np
import theano
import theano.tensor as T
import lasagne
from progressbar import ProgressBar

from ..models.model import Model
from ..distributions.distribution_samples import mean_sum_samples
from ..utils import epsilon


class GAN(Model):
    def __init__(self, p, d, n_batch=100,
                 p_optimizer=lasagne.updates.adam,
                 d_optimizer=lasagne.updates.adam,
                 p_optimizer_params={},
                 d_optimizer_params={},
                 p_critic=lambda gt: -T.log(gt + epsilon()),
                 d_critic=lambda t, gt: -T.log(t + epsilon())
                 - T.log(1 - gt + epsilon()),
                 p_clip_param=None, d_clip_param=None,
                 p_clip_grad=None, d_clip_grad=None,
                 p_max_norm_constraint=None,
                 d_max_norm_constraint=None,
                 l1_lambda=0, seed=1234):
        super(GAN, self).__init__(n_batch=n_batch, seed=seed)
        self.p = p  # generator distribution p(x|z,...)
        self.d = d  # discriminator distribution d(t|x,...)
        self.hidden_dim = self.p.get_input_shape()[0][1:]
        self.l1_lambda = l1_lambda  # for pix2pix

        # set inputs
        z = self.p.inputs
        x = self.d.inputs

        # set critics (the defaults implement the standard GAN losses:
        # -log D(G(z)) for the generator and
        # -log D(x) - log(1 - D(G(z))) for the discriminator)
        self.p_critic = p_critic
        self.d_critic = d_critic

        # training
        inputs = z[:1] + x
        loss, params = self._loss(z, x, False)

        p_updates = self._get_updates(loss[0], params[0],
                                      p_optimizer, p_optimizer_params,
                                      p_clip_param, p_clip_grad,
                                      p_max_norm_constraint)
        d_updates = self._get_updates(loss[1], params[1],
                                      d_optimizer, d_optimizer_params,
                                      d_clip_param, d_clip_grad,
                                      d_max_norm_constraint)

        self.p_train = theano.function(inputs=inputs, outputs=loss,
                                       updates=p_updates,
                                       on_unused_input='ignore')
        self.d_train = theano.function(inputs=inputs, outputs=loss,
                                       updates=d_updates,
                                       on_unused_input='ignore')

        # test
        inputs = z[:1] + x
        loss, _ = self._loss(z, x, True)
        self.test = theano.function(inputs=inputs, outputs=loss,
                                    on_unused_input='ignore')

    def _loss(self, z, x, deterministic=False):
        # gx ~ p(x|z,y,...)
        gx = self.p.sample_mean_given_x(
            z, deterministic=deterministic)[-1]

        # t ~ d(t|x,y,...)
        t = self.d.sample_mean_given_x(
            x, deterministic=deterministic)[-1]

        # gt ~ d(t|gx,y,...)
        gt = self.d.sample_mean_given_x(
            [gx] + x[1:], deterministic=deterministic)[-1]

        p_loss = mean_sum_samples(self.p_critic(gt)).mean()
        d_loss = mean_sum_samples(self.d_critic(t, gt)).mean()

        if deterministic is False and len(z) > 1:
            # conditional case: add the pix2pix L1 reconstruction term
            p_loss += \
                self.l1_lambda * mean_sum_samples(T.abs_(x[0] - gx)).mean()

        p_params = self.p.get_params()
        d_params = self.d.get_params()

        return [p_loss, d_loss], [p_params, d_params]
    def train(self, train_set, freq=1, verbose=False):
        # note: `freq` is accepted but not used in this implementation
        n_x = len(train_set[0])
        nbatches = n_x // self.n_batch
        z_dim = (self.n_batch,) + self.hidden_dim
        loss_all = []

        if verbose:
            pbar = ProgressBar(maxval=nbatches).start()
        for i in range(nbatches):
            start = i * self.n_batch
            end = start + self.n_batch
            batch_x = [_x[start:end] for _x in train_set]
            batch_z = \
                self.rng.uniform(-1., 1.,
                                 size=z_dim).astype(batch_x[0].dtype)
            _x = [batch_z] + batch_x

            # update the generator, then the discriminator; the losses
            # recorded are those returned after the discriminator update
            self.p_train(*_x)
            loss = self.d_train(*_x)
            loss_all.append(np.array(loss))

            if verbose:
                pbar.update(i)

        loss_all = np.mean(loss_all, axis=0)
        return loss_all
    def gan_test(self, test_set, n_batch=None, verbose=False):
        if n_batch is None:
            n_batch = self.n_batch

        n_x = test_set[0].shape[0]
        nbatches = n_x // n_batch
        z_dim = (n_batch,) + self.hidden_dim
        loss_all = []

        if verbose:
            pbar = ProgressBar(maxval=nbatches).start()
        for i in range(nbatches):
            start = i * n_batch
            end = start + n_batch
            batch_x = [_x[start:end] for _x in test_set]
            batch_z = \
                self.rng.uniform(-1., 1.,
                                 size=z_dim).astype(batch_x[0].dtype)
            _x = [batch_z] + batch_x

            loss = self.test(*_x)
            loss_all.append(np.array(loss))

            if verbose:
                pbar.update(i)

        loss_all = np.mean(loss_all, axis=0)
        return loss_all
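
# ---------------------------------------------------------------------
# Usage sketch (not part of the original module). `p` and `d` stand for
# generator and discriminator distributions built elsewhere: any objects
# exposing `inputs`, `sample_mean_given_x`, `get_params`, and
# `get_input_shape` as used above. `train_x` and `test_x` are
# hypothetical numpy arrays of training and test data.
#
#     model = GAN(p, d, n_batch=100)
#     for epoch in range(100):
#         p_loss, d_loss = model.train([train_x])
#         test_p_loss, test_d_loss = model.gan_test([test_x])
#
# `train` and `gan_test` both return `[p_loss, d_loss]` averaged over
# minibatches, so the unpacking above is safe.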
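
# Because both critics are injected as lambdas over Theano tensors, other
# GAN objectives can be swapped in without modifying the class. As an
# illustrative assumption (not something this library ships), the
# least-squares GAN objective (Mao et al., 2017) fits the same interface:
#
#     lsgan = GAN(p, d,
#                 p_critic=lambda gt: 0.5 * (gt - 1) ** 2,
#                 d_critic=lambda t, gt: 0.5 * ((t - 1) ** 2 + gt ** 2))
#
# Each lambda operates elementwise on the discriminator outputs;
# `mean_sum_samples` and `.mean()` in `_loss` handle the aggregation.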