Source code for Tars.models.vae

import numpy as np
import theano
import theano.tensor as T
import lasagne
from progressbar import ProgressBar

from ..utils import log_mean_exp, tolist
from ..distributions.estimate_kl import analytical_kl, get_prior
from ..models.model import Model


class VAE(Model):
    def __init__(self, q, p, prior=None,
                 n_batch=100, optimizer=lasagne.updates.adam,
                 optimizer_params={}, clip_grad=None,
                 max_norm_constraint=None, train_iw=False, test_iw=True,
                 iw_alpha=0, seed=1234):
        super(VAE, self).__init__(n_batch=n_batch, seed=seed)
        self.q = q
        self.p = p
        if prior:
            self.prior = prior
        else:
            self.prior = get_prior(self.q)

        # set prior distribution mode
        if self.prior.__class__.__name__ == "MultiPriorDistributions":
            if self.prior.prior is None:
                self.prior.prior = get_prior(self.q.distributions[-1])
            self.prior_mode = "MultiPrior"
        else:
            self.prior_mode = "Normal"

        self.train_iw = train_iw
        self.test_iw = test_iw
        self.optimizer = optimizer
        self.optimizer_params = optimizer_params
        self.clip_grad = clip_grad
        self.max_norm_constraint = max_norm_constraint
        self.iw_alpha = iw_alpha

        # set inputs
        x = self.q.inputs
        l = T.iscalar("l")  # number of MC samples per data point
        k = T.iscalar("k")  # number of importance samples
        annealing_beta = T.fscalar("beta")  # KL annealing weight

        # training
        if self.train_iw:
            inputs = x + [l, k]
            lower_bound, loss, params = self._vr_bound(x, l, k,
                                                       self.iw_alpha, False)
        else:
            inputs = x + [l, annealing_beta]
            lower_bound, loss, params = self._elbo(x, l,
                                                   annealing_beta, False)
        lower_bound = T.mean(lower_bound, axis=0)

        updates = self._get_updates(loss, params, self.optimizer,
                                    self.optimizer_params, self.clip_grad,
                                    self.max_norm_constraint)

        self.lower_bound_train = theano.function(inputs=inputs,
                                                 outputs=lower_bound,
                                                 updates=updates,
                                                 on_unused_input='ignore')

        # test
        if self.test_iw:
            inputs = x + [l, k]
            lower_bound, _, _ = self._vr_bound(x, l, k, 0, True)
        else:
            inputs = x + [l]
            lower_bound, _, _ = self._elbo(x, l, 1, True)
            lower_bound = T.sum(lower_bound, axis=1)

        self.lower_bound_test = theano.function(inputs=inputs,
                                                outputs=lower_bound,
                                                on_unused_input='ignore')
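    # Usage sketch for the compiled functions (a hypothetical
    # single-input case; `batch_x` is a numpy minibatch):
    #
    #   model.lower_bound_train(batch_x, l, annealing_beta)  # ELBO mode
    #   model.lower_bound_train(batch_x, l, k)               # IW mode
    #
    # In ELBO mode the training output is the minibatch mean of the
    # stacked pair [-KL, log-likelihood]; in IW mode it is the minibatch
    # mean of the importance-weighted bound.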
    def train(self, train_set, l=1, k=1, annealing_beta=1, verbose=False):
        n_x = train_set[0].shape[0]
        nbatches = n_x // self.n_batch
        lower_bound_all = []

        if verbose:
            pbar = ProgressBar(maxval=nbatches).start()
        for i in range(nbatches):
            start = i * self.n_batch
            end = start + self.n_batch
            batch_x = [_x[start:end] for _x in train_set]

            if self.train_iw:
                _x = batch_x + [l, k]
            else:
                _x = batch_x + [l, annealing_beta]
            lower_bound = self.lower_bound_train(*_x)
            lower_bound_all.append(np.array(lower_bound))

            if verbose:
                pbar.update(i)

        lower_bound_all = np.mean(lower_bound_all, axis=0)
        return lower_bound_all
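    # Example (hypothetical data): one epoch of training with KL
    # annealing, where `train_x` is a numpy array of inputs.
    #
    #   beta = min(1.0, epoch / 100.0)   # warm-up schedule (illustrative)
    #   bound = model.train([train_x], l=1, annealing_beta=beta)
    #
    # `train_set` is a list of arrays, one per input of q, all sharing
    # the same first (data) dimension; any trailing incomplete minibatch
    # is dropped by the integer division above.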
    def test(self, test_set, l=1, k=1, n_batch=None, verbose=True):
        if n_batch is None:
            n_batch = self.n_batch

        n_x = test_set[0].shape[0]
        nbatches = n_x // n_batch
        lower_bound_all = []

        if verbose:
            pbar = ProgressBar(maxval=nbatches).start()
        for i in range(nbatches):
            start = i * n_batch
            end = start + n_batch
            batch_x = [_x[start:end] for _x in test_set]

            if self.test_iw:
                _x = batch_x + [l, k]
            else:
                _x = batch_x + [l]
            lower_bound = self.lower_bound_test(*_x)
            lower_bound_all = np.r_[lower_bound_all, lower_bound]

            if verbose:
                pbar.update(i)

        return lower_bound_all
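    # Example (hypothetical data): estimating log p(x) on a test set.
    #
    #   log_px = model.test([test_x], l=1, k=100)   # shape (n_test,)
    #   print(np.mean(log_px))
    #
    # With test_iw=True the estimate tightens toward log p(x) as k grows
    # [Burda+ 2015]; with test_iw=False the per-point ELBO is returned
    # instead (and k is ignored).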
    def _elbo(self, x, l, annealing_beta, deterministic=False):
        """
        The evidence lower bound (original VAE):

            ELBO = E_q(z|x)[log p(x|z)] - beta * KL(q(z|x) || p(z))

        [Kingma+ 2013] Auto-Encoding Variational Bayes
        """
        kl_divergence = analytical_kl(self.q, self.prior,
                                      given=[x, None],
                                      deterministic=deterministic)
        z = self.q.sample_given_x(x, repeat=l,
                                  deterministic=deterministic)

        inverse_z = self._inverse_samples(z)
        log_likelihood =\
            self.p.log_likelihood_given_x(inverse_z,
                                          deterministic=deterministic)

        lower_bound = T.stack([-kl_divergence, log_likelihood], axis=-1)
        loss = -T.mean(log_likelihood - annealing_beta * kl_divergence)

        q_params = self.q.get_params()
        p_params = self.p.get_params()
        params = q_params + p_params
        if self.prior_mode == "MultiPrior":
            params += self.prior.get_params()

        return lower_bound, loss, params

    def _vr_bound(self, x, l, k, iw_alpha=0, deterministic=False):
        """
        Variational Renyi bound:

            L_alpha = 1 / (1 - alpha) * log E_q[w^(1 - alpha)],

        where w = p(x, z) / q(z|x). alpha=0 recovers the IWAE bound,
        alpha=1 the (sampled) ELBO, and alpha=-inf the max over the
        k importance samples.

        [Li+ 2016] Renyi Divergence Variational Inference
        [Burda+ 2015] Importance Weighted Autoencoders
        """
        q_samples = self.q.sample_given_x(x, repeat=l * k,
                                          deterministic=deterministic)
        log_iw = self._log_importance_weight(q_samples,
                                             deterministic=deterministic)

        log_iw_matrix = log_iw.reshape((x[0].shape[0] * l, k))
        if iw_alpha == 1:
            log_likelihood = T.mean(log_iw_matrix, axis=1)
        elif iw_alpha == -np.inf:
            log_likelihood = T.max(log_iw_matrix, axis=1)
        else:
            log_iw_matrix = log_iw_matrix * (1 - iw_alpha)
            log_likelihood = log_mean_exp(
                log_iw_matrix, axis=1, keepdims=True) / (1 - iw_alpha)

        log_likelihood = log_likelihood.reshape((x[0].shape[0], l))
        log_likelihood = T.mean(log_likelihood, axis=1)

        loss = -T.mean(log_likelihood)

        q_params = self.q.get_params()
        p_params = self.p.get_params()
        params = q_params + p_params
        if self.prior_mode == "MultiPrior":
            params += self.prior.get_params()

        return log_likelihood, loss, params

    def _log_importance_weight(self, samples, deterministic=False):
        """
        inputs : [[x,y,...],z1,z2,...,zn]
        outputs : log p(x,z1,z2,...,zn|y,...) / q(z1,z2,...,zn|x,y,...)
        """
        log_iw = 0

        # log q(z1,z2,...,zn|x,y,...)
        # samples : [[x,y,...],z1,z2,...,zn]
        q_log_likelihood =\
            self.q.log_likelihood_given_x(samples,
                                          deterministic=deterministic)

        # log p(x|z1,z2,...,zn,y,...)
        # inverse_samples : [[zn,y,...],zn-1,...,x]
        p_samples, prior_samples = self._inverse_samples(
            samples, return_prior=True)
        p_log_likelihood =\
            self.p.log_likelihood_given_x(p_samples,
                                          deterministic=deterministic)

        log_iw += p_log_likelihood - q_log_likelihood

        # log p(zn) (or the chained prior terms under a MultiPrior)
        if self.prior_mode == "MultiPrior":
            log_iw += self.prior.log_likelihood_given_x(prior_samples)
        else:
            log_iw += self.prior.log_likelihood(prior_samples)

        return log_iw

    def _inverse_samples(self, samples, return_prior=False):
        """
        inputs : [[x,y],z1,z2,...,zn]
        outputs : p_samples, prior_samples
            if mode is "Normal" : [[zn,y],zn-1,...,x], zn
            elif mode is "MultiPrior" : [z1,x], [[zn,y],zn-1,...,z1]
        """
        inverse_samples = samples[::-1]
        inverse_samples[0] = [inverse_samples[0]] + inverse_samples[-1][1:]
        inverse_samples[-1] = inverse_samples[-1][0]

        if self.prior_mode == "Normal":
            p_samples = inverse_samples
            prior_samples = samples[-1]
        elif self.prior_mode == "MultiPrior":
            p_samples = [tolist(inverse_samples[-2]), inverse_samples[-1]]
            prior_samples = inverse_samples[:-1]
        else:
            raise Exception("You should set prior_mode to 'Normal' or "
                            "'MultiPrior', got %s." % self.prior_mode)

        if return_prior:
            return p_samples, prior_samples
        else:
            return p_samples
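
# A concrete trace of `_inverse_samples` (pure-Python stand-ins for the
# Theano variables; illustrative only). With two latent layers and
# prior_mode "Normal":
#
#   samples            = [["x", "y"], "z1", "z2"]
#   samples[::-1]      = ["z2", "z1", ["x", "y"]]
#   after rewriting    : [["z2", "y"], "z1", "x"]
#   -> p_samples       = [["z2", "y"], "z1", "x"]   # feeds p(x|z, y)
#      prior_samples   = "z2"                       # feeds p(zn)
#
# Under "MultiPrior" the same rewriting yields
#   p_samples = [["z1"], "x"], prior_samples = [["z2", "y"], "z1"].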
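
# Usage sketch (not part of the module): wiring a VAE with this class.
# `q` and `p` stand for Tars distribution objects, e.g. a Gaussian
# encoder and a Bernoulli decoder built on lasagne networks; their exact
# constructors are an assumption here, so see the Tars examples for the
# real signatures.
#
#   model = VAE(q, p, n_batch=100,
#               optimizer=lasagne.updates.adam,
#               optimizer_params={"learning_rate": 1e-3},
#               train_iw=False, test_iw=True)
#   for epoch in range(100):
#       beta = min(1.0, (epoch + 1) / 10.0)   # KL warm-up (illustrative)
#       model.train([train_x], l=1, annealing_beta=beta)
#   log_px = model.test([test_x], l=1, k=100)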