# Source code for Tars.distributions.estimate_kl

import math
import theano
import theano.tensor as T

from ..utils import epsilon, tolist
from .distribution_samples import (
    UnitGaussianSample,
    UnitBernoulliSample,
    UnitCategoricalSample,
    UnitBetaSample,
    UnitDirichletSample,
    UnitGammaSample,
)


def analytical_kl(q1, q2, given, deterministic=False):
    """Return a Theano expression for the analytical KL divergence
    KL(q1 || q2), dispatching on the class names of the two distributions.

    Parameters
    ----------
    q1 : distribution object (the approximate posterior); forward-propagated
        with ``x1``.
    q2 : distribution object (the prior); forward-propagated with ``x2``
        where applicable.
    given : list of length 2, ``[x1, x2]`` — the conditioning inputs of
        q1 and q2 respectively.
    deterministic : bool, forwarded to each ``fprop`` call.

    Returns
    -------
    Theano expression of shape (batch,): the KL summed over the feature axis.

    Raises
    ------
    ValueError
        If ``given`` does not contain exactly 2 elements.
    Exception
        If the (q1, q2) class pair is not supported.
    """
    # Explicit length check instead of the previous bare `except:`, which
    # silently swallowed unrelated errors before re-raising ValueError.
    if len(given) != 2:
        raise ValueError("The length of given list must be 2, "
                         "got %d" % len(given))
    x1, x2 = given
    q1_class = q1.__class__.__name__
    q2_class = q2.__class__.__name__

    if q1_class == "Gaussian" and q2_class == "UnitGaussianSample":
        mean, var = q1.fprop(x1, deterministic=deterministic)
        return gauss_unitgauss_kl(mean, var)

    elif q1_class == "Gaussian" and q2_class == "Gaussian":
        mean1, var1 = q1.fprop(x1, deterministic=deterministic)
        mean2, var2 = q2.fprop(x2, deterministic=deterministic)
        return gauss_gauss_kl(mean1, var1, mean2, var2)

    elif q1_class == "Bernoulli" and q2_class == "UnitBernoulliSample":
        # KL against a uniform Bernoulli(0.5): negative entropy + log(2).
        mean = q1.fprop(x1, deterministic=deterministic)
        output = mean * (T.log(mean + epsilon()) + T.log(2)) +\
            (1 - mean) * (T.log(1 - mean + epsilon()) + T.log(2))
        return T.sum(output, axis=1)

    elif q1_class == "Categorical" and q2_class == "UnitCategoricalSample":
        # KL against a uniform categorical over q1.k classes.
        mean = q1.fprop(x1, deterministic=deterministic)
        output = mean * (T.log(mean + epsilon()) + T.log(q1.k))
        return T.sum(output, axis=1)

    elif q1_class == "Kumaraswamy" and q2_class == "UnitBetaSample":
        # [Nalisnick+ 2016] Deep Generative Models with Stick-Breaking
        # Priors: KL(Kumaraswamy(a, b) || Beta(alpha, beta)), with the
        # infinite series truncated after M Taylor terms.
        M = 10
        euler_gamma = 0.57721
        a, b = q1.fprop(x1, deterministic=deterministic)

        def taylor(m, a, b):
            return 1. / (m + a * b) * q2._beta_func(m / a, b)

        kl, _ = theano.scan(fn=taylor,
                            sequences=T.arange(1, M + 1),
                            non_sequences=[a, b])
        kl = T.sum(kl, axis=0)
        kl *= (q2.beta - 1) * b
        # Fixed precedence bug: epsilon() must guard the denominator,
        # i.e. (a - alpha) / (a + eps), not ((a - alpha) / a) + eps.
        kl += ((a - q2.alpha) / (a + epsilon())) *\
            (-euler_gamma - psi(b) - 1. / (b + epsilon()))
        kl += T.log(a * b + epsilon()) +\
            T.log(q2._beta_func(q2.alpha, q2.beta) + epsilon())
        kl += -(b - 1) / (b + epsilon())
        return T.sum(kl, axis=1)

    elif q1_class == "Gamma" and q2_class == "UnitGammaSample":
        # KL between two Gammas with the unit Gamma(1, 1) as prior.
        # https://arxiv.org/pdf/1611.01437.pdf
        alpha1, beta1 = q1.fprop(x1, deterministic=deterministic)
        alpha2 = T.ones_like(alpha1)
        beta2 = T.ones_like(beta1)
        output = (alpha1 - alpha2) * psi(alpha1)
        output += -T.gammaln(alpha1) + T.gammaln(alpha2)
        output += alpha2 * (T.log(beta1 + epsilon()) -
                            T.log(beta2 + epsilon()))
        output += alpha1 * (beta2 - beta1) / (beta1 + epsilon())
        return T.sum(output, axis=1)

    elif q1_class == "Beta" and q2_class == "UnitBetaSample":
        # Closed-form Beta-Beta KL.
        # http://bariskurt.com/kullback-leibler-divergence\
        # -between-two-dirichlet-and-beta-distributions/
        alpha1, beta1 = q1.fprop(x1, deterministic=deterministic)
        alpha2 = T.ones_like(alpha1) * q2.alpha
        beta2 = T.ones_like(beta1) * q2.beta
        output = T.gammaln(alpha1 + beta1) -\
            T.gammaln(alpha2 + beta2) -\
            (T.gammaln(alpha1) + T.gammaln(beta1)) +\
            (T.gammaln(alpha2) + T.gammaln(beta2)) +\
            (alpha1 - alpha2) * (psi(alpha1) - psi(alpha1 + beta1)) +\
            (beta1 - beta2) * (psi(beta1) - psi(alpha1 + beta1))
        return T.sum(output, axis=1)

    elif q1_class == "Dirichlet" and q2_class == "UnitDirichletSample":
        # Closed-form Dirichlet-Dirichlet KL.
        # http://bariskurt.com/kullback-leibler-divergence\
        # -between-two-dirichlet-and-beta-distributions/
        alpha1 = q1.fprop(x1, deterministic=deterministic)
        # Integer (floor) division for the symbolic reshape extent; `/`
        # would yield a float under true division and break the reshape.
        alpha1 = alpha1.reshape((alpha1.shape[0],
                                 alpha1.shape[1] // q1.k, q1.k))
        alpha2 = T.ones_like(alpha1) * q2.alpha
        output = T.gammaln(T.sum(alpha1, axis=-1)) -\
            T.gammaln(T.sum(alpha2, axis=-1)) -\
            T.sum(T.gammaln(alpha1), axis=-1) +\
            T.sum(T.gammaln(alpha2), axis=-1) +\
            T.sum((alpha1 - alpha2) *
                  (psi(alpha1) -
                   psi(T.sum(alpha1, axis=-1, keepdims=True))), axis=-1)
        return T.sum(output, axis=1)

    elif (q1_class == "MultiDistributions") and (
            q2_class == "MultiPriorDistributions"):
        # PixelVAE (https://arxiv.org/abs/1611.05013): accumulate the KL
        # of each intermediate layer, pairing q layers bottom-up with
        # prior layers top-down, then add the top layer vs. the prior.
        all_kl = 0
        for i, q, p in zip(range(len(q1.distributions[:-1])),
                           q1.distributions[:-1],
                           reversed(q2.distributions)):
            if i == 0:
                _x = x1
            else:
                _x = q1.sample_mean_given_x(x1, layer_id=i - 1)[-1]
            z = q1.sample_given_x(x1, layer_id=i + 1)[-1]
            kl = analytical_kl(q, p, given=[tolist(_x), tolist(z)])
            all_kl += kl
        _x = q1.sample_mean_given_x(x1, layer_id=-2)[-1]
        kl = analytical_kl(q1.distributions[-1], q2.prior,
                           given=[tolist(_x), None])
        all_kl += kl
        return all_kl

    elif q1_class == "MultiDistributions":
        # Only the last layer of a stacked q is compared to the prior;
        # its conditioning input is sampled through the earlier layers.
        if len(q1.distributions) >= 2:
            _x1 = q1.sample_given_x(x1, layer_id=-2)[-1]
        else:
            _x1 = x1
        return analytical_kl(q1.distributions[-1], q2,
                             given=[tolist(_x1), x2],
                             deterministic=deterministic)

    raise Exception("You cannot use this distribution as q or prior, "
                    "got %s and %s." % (q1_class, q2_class))
def gaussian_like(x, mean, var):
    """Element-wise Gaussian log-likelihood log N(x; mean, var).

    A small epsilon is added to the variance so the log and the
    division stay finite when var is (numerically) zero.
    """
    safe_var = var + epsilon()  # avoid NaN
    log_norm_const = -0.5 * math.log(2 * math.pi)
    squared_error = (x - mean) ** 2
    return log_norm_const - 0.5 * T.log(safe_var) \
        - squared_error / (2 * safe_var)
def gauss_unitgauss_kl(mean, var):
    """KL(N(mean, var) || N(0, I)), summed over the feature axis (axis=1).

    Closed form: -0.5 * sum(1 + log(var) - mean^2 - var).
    """
    # Guard the log with epsilon(), consistent with gaussian_like and
    # gauss_gauss_kl; previously T.log(var) produced -inf/NaN at var == 0.
    _var = var + epsilon()  # avoid NaN
    return -0.5 * T.sum(1 + T.log(_var) - mean**2 - var, axis=1)
def gauss_gauss_kl(mean1, var1, mean2, var2):
    """KL(N(mean1, var1) || N(mean2, var2)), summed over axis=1.

    Closed form:
    0.5 * sum(log(var2) - log(var1) + (var1 + (mean1 - mean2)^2) / var2 - 1).
    """
    # Bug fix: _var2 was computed "to avoid NaN" but T.log(var2) used the
    # raw value, and T.log(var1) was unguarded; both logs now use the
    # epsilon-shifted variances.
    _var1 = var1 + epsilon()  # avoid NaN
    _var2 = var2 + epsilon()  # avoid NaN
    _kl = T.log(_var2) - T.log(_var1) \
        + (var1 + (mean1 - mean2)**2) / _var2 - 1
    return 0.5 * T.sum(_kl, axis=1)
def psi(b):
    """Approximate the digamma function via its asymptotic expansion,
    psi(b) ~= log(b) - 1/(2b) - 1/(12b^2), because T.psi is not
    implemented in Theano yet. Each term is epsilon-guarded to stay
    finite at b == 0.
    """
    log_term = T.log(b + epsilon())
    first_correction = 1. / (2 * b + epsilon())
    second_correction = 1. / (12 * b**2 + epsilon())
    return log_term - first_correction - second_correction
def get_prior(q):
    """Return the unit prior sample distribution matching q's family.

    For stacked distributions the prior corresponds to the top (last)
    layer. Raises Exception for unsupported families.
    """
    q_class = q.__class__.__name__
    if q_class == "MultiDistributions":
        # The prior is defined by the last distribution in the stack.
        return get_prior(q.distributions[-1])

    # Family name -> prior factory; deferred via lambdas so only the
    # matching prior is instantiated.
    factories = {
        "Gaussian": lambda: UnitGaussianSample(),
        "Bernoulli": lambda: UnitBernoulliSample(),
        "Categorical": lambda: UnitCategoricalSample(q.k),
        "Kumaraswamy": lambda: UnitBetaSample(),
        "Beta": lambda: UnitBetaSample(),
        "Dirichlet": lambda: UnitDirichletSample(q.k),
        "Gamma": lambda: UnitGammaSample(),
    }
    factory = factories.get(q_class)
    if factory is None:
        raise Exception("You cannot use this distribution as q, "
                        "got %s." % q_class)
    return factory()