Source code for Tars.models.jmvae

from copy import copy

import theano
import theano.tensor as T
import lasagne

from ..utils import tolist, log_mean_exp
from ..distributions.estimate_kl import analytical_kl
from . import VAE


class JMVAE(VAE):
    def __init__(self, q, p, prior=None, n_batch=100,
                 optimizer=lasagne.updates.adam, optimizer_params={},
                 clip_grad=None, max_norm_constraint=None,
                 train_iw=False, test_iw=True, seed=1234):
        super(JMVAE, self).__init__(q, p, prior=prior, n_batch=n_batch,
                                    optimizer=optimizer,
                                    optimizer_params=optimizer_params,
                                    clip_grad=clip_grad,
                                    max_norm_constraint=max_norm_constraint,
                                    train_iw=train_iw, test_iw=test_iw,
                                    iw_alpha=0, seed=seed)

    def _set_test(self, type_p="marginal", missing=False, index=[0],
                  sampling_n=1, missing_resample=False):
        # set inputs
        x = self.q.inputs
        l = T.iscalar("l")
        k = T.iscalar("k")

        if type_p == "joint":
            if self.test_iw:
                inputs = x + [l, k]
                lower_bound, _, _ = self._vr_bound(x, l, k, 0, True)
            else:
                inputs = x + [l]
                lower_bound, _, _ = self._elbo(x, l, 1, True)
                lower_bound = T.sum(lower_bound, axis=1)
        else:
            inputs = x + [l, k]
            lower_bound = self._vr_bound_test(
                x, l, k, index, type_p, missing, sampling_n,
                missing_resample)

        self.lower_bound_test = theano.function(inputs=inputs,
                                                outputs=lower_bound,
                                                on_unused_input='ignore')
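
    # Illustrative note (not part of the original source): depending on
    # `type_p`, `_set_test` compiles different Theano test functions. For a
    # hypothetical bimodal model with inputs [x0, x1], the compiled lower
    # bound estimates:
    #
    #   type_p="joint"                  -> log p(x0, x1)
    #   type_p="marginal", index=[0]    -> log p(x0)
    #   type_p="conditional", index=[0] -> log p(x0 | x1)
    #
    # With missing=True, the modalities not being conditioned on are zeroed
    # out before encoding (see `_vr_bound_test` below).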

    def test(self, test_set, l=1, k=1, index=[0], sampling_n=1,
             missing_resample=False, type_p="joint", missing=False,
             n_batch=None, verbose=True):
        self._set_test(type_p, missing, index, sampling_n, missing_resample)
        return super(JMVAE, self).test(test_set, l, k, n_batch, verbose)
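
    # A minimal usage sketch (hypothetical; `q`, `p0`, `p1`, and the test
    # arrays are assumed to be built with Tars distributions elsewhere):
    #
    #   model = JMVAE(q, [p0, p1], n_batch=100)
    #   # joint test log-likelihood with k=10 importance samples
    #   ll_joint = model.test([x0_test, x1_test], l=1, k=10, type_p="joint")
    #   # conditional log-likelihood log p(x0|x1), encoding from x1 only
    #   ll_cond = model.test([x0_test, x1_test], l=1, k=10,
    #                        type_p="conditional", index=[0], missing=True)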

    def _elbo(self, x, l, annealing_beta, deterministic=False):
        kl_divergence = analytical_kl(self.q, self.prior,
                                      given=[x, None],
                                      deterministic=deterministic)
        z = self.q.sample_given_x(x, repeat=l, deterministic=deterministic)

        log_likelihood_all = []
        p_params = []
        for i, p in enumerate(self.p):
            inverse_z = self._inverse_samples(self._select_input(z, [i]))
            log_likelihood = p.log_likelihood_given_x(
                inverse_z, deterministic=deterministic)
            log_likelihood_all.append(log_likelihood)
            p_params += p.get_params()

        lower_bound = T.stack([-kl_divergence] + log_likelihood_all, axis=-1)
        loss = -T.mean(sum(log_likelihood_all)
                       - annealing_beta * kl_divergence)

        q_params = self.q.get_params()
        params = q_params + p_params

        if self.prior_mode == "MultiPrior":
            params += self.prior.get_params()

        return lower_bound, loss, params

    def _vr_bound(self, x, l, k, iw_alpha=0, deterministic=False):
        q_samples = self.q.sample_given_x(
            x, repeat=l * k, deterministic=deterministic)
        log_iw = self._log_importance_weight(q_samples,
                                             deterministic=deterministic)

        log_iw_matrix = log_iw.reshape((x[0].shape[0] * l, k))
        log_likelihood = log_mean_exp(
            log_iw_matrix, axis=1, keepdims=True)
        log_likelihood = log_likelihood.reshape((x[0].shape[0], l))
        log_likelihood = T.mean(log_likelihood, axis=1)

        loss = -T.mean(log_likelihood)

        p_params = []
        for i, p in enumerate(self.p):
            p_params += p.get_params()
        q_params = self.q.get_params()
        params = q_params + p_params

        if self.prior_mode == "MultiPrior":
            params += self.prior.get_params()

        return log_likelihood, loss, params

    def _vr_bound_test(self, x, l, k, index=[0], type_p="marginal",
                       missing=False, sampling_n=1, missing_resample=False):
        """
        Parameters
        ----------
        x : list of Theano variables
            Input modalities [x0, x1, ...].
        l : Theano int scalar
            Number of samples over which the bound is averaged.
        k : Theano int scalar
            Number of importance samples.
        type_p : {'conditional', 'marginal'}
            Specifies the type of the log-likelihood.

        Returns
        -------
        log_marginal_estimate : array, shape (n_samples,)
            Estimated log-likelihood.
        """
        n_x = x[0].shape[0]
        rep_x = [T.extra_ops.repeat(_x, l * k, axis=0) for _x in x]

        if type_p not in ['marginal', 'conditional']:
            raise ValueError("type_p must be one of {'marginal', "
                             "'conditional'}, got %s." % type_p)

        if missing:
            if type_p == "marginal":
                _rep_x = self._select_input([rep_x], index,
                                            set_zeros=True)[0]
                samples = self.q.sample_given_x(_rep_x, deterministic=True)
                samples = self._select_input(samples, inputs=rep_x)
                log_iw = self._log_mg_missing_importance_weight(
                    samples, index, deterministic=True)

            elif type_p == "conditional":
                # rep_x:[x0,x1] -> _rep_x:[0,x1]
                rv_index = self._reverse_index(index)
                _rep_x = self._select_input(
                    [rep_x], rv_index, set_zeros=True)[0]

                if missing_resample:
                    for _ in range(sampling_n - 1):
                        samples = self.q.sample_given_x(
                            _rep_x, deterministic=True)[-1]
                        for j in index:
                            _rep_x[j] = self.p[j].sample_given_x(
                                [samples])[-1]

                samples = self.q.sample_given_x(_rep_x, deterministic=True)
                samples = self._select_input(samples, inputs=rep_x)
                log_iw = self._log_cd_importance_weight(
                    samples, index, deterministic=True)
        else:
            samples = self.q.sample_given_x(rep_x, deterministic=True)
            if type_p == "marginal":
                log_iw = self._log_selected_importance_weight(
                    samples, index, deterministic=True)
            else:
                log_iw = self._log_cd_importance_weight(
                    samples, index, deterministic=True)

        log_iw_matrix = T.reshape(log_iw, (n_x * l, k))
        log_likelihood = log_mean_exp(
            log_iw_matrix, axis=1, keepdims=True)
        log_likelihood = log_likelihood.reshape((x[0].shape[0], l))
        log_likelihood = T.mean(log_likelihood, axis=1)

        return log_likelihood

    def _log_importance_weight(self, samples, deterministic=True):
        """
        Parameters
        ----------
        samples : list
            [[x0, x1, ...], z1, z2, ..., zn]

        Returns
        -------
        log_iw : array, shape (n_samples,)
            Estimated log importance weight,
            log p(x0, x1, ..., z1, z2, ..., zn)
              / q(z1, z2, ..., zn | x0, x1, ...).
        """
        log_iw = 0

        # log q(z1,z2,...,zn|x0,x1,...)
        # samples : [[x0,x1,...],z1,z2,...,zn]
        q_log_likelihood = self.q.log_likelihood_given_x(
            samples, deterministic=deterministic)

        # log p(x|z1)
        p_log_likelihood_all = []
        for i, p in enumerate(self.p):
            p_samples, prior_samples = self._inverse_samples(
                self._select_input(samples, [i]), return_prior=True)
            p_log_likelihood = p.log_likelihood_given_x(
                p_samples, deterministic=deterministic)
            p_log_likelihood_all.append(p_log_likelihood)

        log_iw += sum(p_log_likelihood_all) - q_log_likelihood

        # log p(z1,...,zn)
        if self.prior_mode == "MultiPrior":
            log_iw += self.prior.log_likelihood_given_x(prior_samples)
        else:
            log_iw += self.prior.log_likelihood(prior_samples)

        return log_iw

    def _log_selected_importance_weight(self, samples, index=[0],
                                        deterministic=True):
        """
        Parameters
        ----------
        samples : list
            [[x0, x1, ...], z1, z2, ..., zn]
        index : list
            Indices of the selected modalities.

        Returns
        -------
        log_iw : array, shape (n_samples * k,)
            Estimated log importance weight,
            log p(x[index], z1, z2, ..., zn)
              / q(z1, z2, ..., zn | x0, x1, ...).
        """
        log_iw = 0

        # log q(z1,z2,...,zn|x0,x1,...)
        # samples : [[x0,x1,...],z1,z2,...,zn]
        q_log_likelihood = self.q.log_likelihood_given_x(
            samples, deterministic=deterministic)

        # log p(x[index]|z1)
        p_log_likelihood_all = []
        for i in index:
            p_samples, prior_samples = self._inverse_samples(
                self._select_input(samples, [i]), return_prior=True)
            p_log_likelihood = self.p[i].log_likelihood_given_x(
                p_samples, deterministic=deterministic)
            p_log_likelihood_all.append(p_log_likelihood)

        log_iw += sum(p_log_likelihood_all) - q_log_likelihood

        # log p(z1,...,zn)
        if self.prior_mode == "MultiPrior":
            log_iw += self.prior.log_likelihood_given_x(prior_samples)
        else:
            log_iw += self.prior.log_likelihood(prior_samples)

        return log_iw

    def _log_mg_missing_importance_weight(self, samples, index=[0],
                                          deterministic=True):
        """
        Parameters
        ----------
        samples : list
            [[x0, x1, ...], z1, z2, ..., zn]

        Returns
        -------
        log_iw : array, shape (n_samples * k,)
            Estimated log importance weight,
            log p(x[index], z1, z2, ..., zn)
              / q(z1, z2, ..., zn | x[index]).
        """
        log_iw = 0

        # samples : [[x0,0,...],z1,z2,...,zn]
        samples = self._select_input(samples, index, set_zeros=True)

        # log q(z1,z2,...,zn|x0,0,...)
        q_log_likelihood = self.q.log_likelihood_given_x(
            samples, deterministic=deterministic)

        # log p(x[index]|z1)
        p_log_likelihood_all = []
        for i in index:
            p_samples, prior_samples = self._inverse_samples(
                self._select_input(samples, [i]), return_prior=True)
            p_log_likelihood = self.p[i].log_likelihood_given_x(
                p_samples, deterministic=deterministic)
            p_log_likelihood_all.append(p_log_likelihood)

        log_iw += sum(p_log_likelihood_all) - q_log_likelihood

        # log p(z1,...,zn)
        if self.prior_mode == "MultiPrior":
            log_iw += self.prior.log_likelihood_given_x(prior_samples)
        else:
            log_iw += self.prior.log_likelihood(prior_samples)

        return log_iw

    def _log_cd_importance_weight(self, samples, index=[0],
                                  deterministic=True):
        """
        Parameters
        ----------
        samples : list
            [[x0, x1, ...], z1, z2, ..., zn]

        Returns
        -------
        log_iw : array, shape (n_samples * k,)
            Estimated log conditional likelihood,
            log p(x[index] | z1, z2, ..., zn).
        """
        log_iw = 0

        # log p(x[index]|z1)
        p_log_likelihood_all = []
        for i in index:
            p_samples, prior_samples = self._inverse_samples(
                self._select_input(samples, [i]), return_prior=True)
            p_log_likelihood = self.p[i].log_likelihood_given_x(
                p_samples, deterministic=deterministic)
            p_log_likelihood_all.append(p_log_likelihood)

        log_iw += sum(p_log_likelihood_all)

        # log p(z1,z2,...|zn)
        if self.prior_mode == "MultiPrior":
            log_iw += self.prior.log_likelihood_given_x(prior_samples,
                                                        add_prior=False)

        return log_iw

    def _select_input(self, samples, index=[0], set_zeros=False,
                      inputs=None):
        """
        Parameters
        ----------
        samples : list
            [[x, y, ...], z1, z2, ...]
        index : list
            Indices of the inputs to select from [x, y, ...].
        set_zeros : bool
            If True, keep all inputs but replace the non-selected ones
            with zeros.
        inputs : list
            Inputs with which to replace [x, y, ...].

        Returns
        -------
        _samples : list
            if index=[0], then _samples = [[x], z1, z2, ...]
            if index=[1], then _samples = [[y], z1, z2, ...]
            if index=[0, 1], then _samples = [[x, y], z1, z2, ...]
            if index=[0] and set_zeros=True,
                then _samples = [[x, 0], z1, z2, ...]
        """
        _samples = copy(samples)
        if inputs:
            _samples[0] = tolist(inputs)
        else:
            _samples_inputs = copy(_samples[0])
            if set_zeros:
                for i in self._reverse_index(index):
                    _samples_inputs[i] = T.zeros_like(_samples_inputs[i])
                _samples[0] = _samples_inputs
            else:
                _input_samples = []
                for i in index:
                    _input_samples.append(_samples[0][i])
                _samples[0] = _input_samples
        return _samples

    def _reverse_index(self, index=[0]):
        N = len(self.p)
        return list(set(range(N)) - set(index))
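

# A self-contained sketch (hypothetical, not part of the library) of the
# indexing behaviour that `_select_input` implements on Theano variables,
# reproduced with plain Python lists; `0` stands in for T.zeros_like(...):

def _select_input_sketch(samples, index=(0,), set_zeros=False):
    # samples : [[x, y, ...], z1, z2, ...]
    _samples = copy(samples)
    if set_zeros:
        # keep every input slot, zeroing the non-selected ones
        inputs = copy(_samples[0])
        for i in set(range(len(inputs))) - set(index):
            inputs[i] = 0
        _samples[0] = inputs
    else:
        # keep only the selected inputs
        _samples[0] = [_samples[0][i] for i in index]
    return _samples

assert _select_input_sketch([["x", "y"], "z1"], index=(0,)) == [["x"], "z1"]
assert _select_input_sketch([["x", "y"], "z1"], index=(0,),
                            set_zeros=True) == [["x", 0], "z1"]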