''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import numpy as np
from typing import Tuple
import torch
from torch.functional import F
from torch.distributions import constraints
from torch.distributions import Normal, Categorical, OneHotCategorical
import pyro
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions.kl import register_kl, kl_divergence
from itertools import groupby
[docs]
def all_equal(iterable):
    """Return ``True`` when every element of ``iterable`` compares equal.

    ``itertools.groupby`` collapses consecutive equal elements into runs, so
    the elements are all equal exactly when there is at most one run.  An
    empty iterable has no run at all and therefore also yields ``True``.
    """
    runs = groupby(iterable)
    if next(runs, None) is None:
        # No run at all: the iterable was empty.
        return True
    # A second run exists only if some element differs from its predecessor.
    return next(runs, None) is None
[docs]
def exportable_broadcast(tensor1 : torch.Tensor, tensor2 : torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Broadcast two tensors to a common shape using only ``Tensor.expand``.

    The shorter shape is left-padded with ones so that both shapes have the
    same rank, then every dimension is resolved independently: a dimension of
    size 1 takes the size of the other tensor, otherwise it keeps its own size.
    Unlike a plain ``tensor2.expand(*tensor1.shape)`` this also handles the
    case where the higher-rank tensor itself has size-1 dimensions that must
    grow (e.g. shapes ``(S, B, 1)`` and ``(B, K)`` -> ``(S, B, K)``).

    :param tensor1: First tensor.
    :param tensor2: Second tensor.
    :return: ``(tensor1, tensor2)`` in the same order, both expanded to the
        common shape.
    :raises RuntimeError: raised by ``Tensor.expand`` when the shapes are not
        broadcast-compatible.
    """
    shape1 = list(tensor1.shape)
    shape2 = list(tensor2.shape)
    rank = max(len(shape1), len(shape2))
    if rank == 0:
        # Two scalars already share the (empty) shape; nothing to expand.
        return tensor1, tensor2

    # Align the shapes on their trailing dimensions by padding with ones.
    shape1 = [1] * (rank - len(shape1)) + shape1
    shape2 = [1] * (rank - len(shape2)) + shape2

    # Per-dimension rule instead of ``max`` so that a size-1 dimension paired
    # with a size-0 dimension resolves to 0 rather than 1.
    final_shape = [s2 if s1 == 1 else s1 for s1, s2 in zip(shape1, shape2)]

    tensor1 = tensor1.expand(*final_shape)
    tensor2 = tensor2.expand(*final_shape)
    return tensor1, tensor2
[docs]
class ReviveDistributionMixin:
    """Extra API that REVIVE expects from its distribution classes."""

    @property
    def mode(self):
        """Most likely sample of the distribution; must be overridden."""
        raise NotImplementedError

    @property
    def std(self):
        """Standard deviation of the distribution; must be overridden."""
        raise NotImplementedError

    def sample_with_logprob(self, sample_shape=torch.Size()):
        """Draw a sample and evaluate its log-probability in one call.

        ``rsample`` is used when the distribution advertises
        ``has_rsample``, otherwise ``sample`` is used.

        :param sample_shape: Shape forwarded to the sampling method.
        :return: Tuple ``(sample, log_prob_of_sample)``.
        """
        draw = self.rsample if self.has_rsample else self.sample
        sample = draw(sample_shape)
        return sample, self.log_prob(sample)
[docs]
class ReviveDistribution(pyro.distributions.TorchDistribution, ReviveDistributionMixin):
    """Base class joining pyro's ``TorchDistribution`` with ``ReviveDistributionMixin``."""
[docs]
class ExportableNormal(Normal):
    """Normal distribution whose parameters are aligned with ``exportable_broadcast``.

    :param loc: Mean of the distribution.
    :param scale: Standard deviation of the distribution.
    :param validate_args: Forwarded to the base distribution constructor.
        Defaults to ``None`` so the class can be built with the same
        two-argument call as ``torch.distributions.Normal``.
    """

    def __init__(self, loc, scale, validate_args=None):
        """ Use exportable_broadcast() aligns the shape of the tensor """
        self.loc, self.scale = exportable_broadcast(loc, scale)
        batch_shape = self.loc.size()
        # ``super(Normal, self)`` starts the lookup *after* ``Normal`` in the
        # MRO, so ``Normal.__init__`` is bypassed on purpose: ``loc`` and
        # ``scale`` have already been set above.
        super(Normal, self).__init__(batch_shape, validate_args=validate_args)
[docs]
class ExportableCategorical(Categorical):
    """Categorical distribution whose ``log_prob`` broadcasts with ``exportable_broadcast``."""

    def log_prob(self, value):
        """Return the log-probability of the class indices in ``value``.

        :param value: Tensor of class indices; it is cast to ``long`` before use.
        :return: Log-probabilities gathered from ``self.logits`` along the last axis.
        """
        if self._validate_args:
            self._validate_sample(value)
        # Give the indices a trailing axis so they line up with the class axis.
        index = value.long().unsqueeze(-1)
        index, log_pmf = exportable_broadcast(index, self.logits)
        # After broadcasting, every entry along the last axis is a copy of the
        # same index, so a single column is enough for ``gather``.
        index = index[..., :1]
        picked = log_pmf.gather(-1, index)
        return picked.squeeze(-1)
[docs]
class DiagnalNormal(ReviveDistribution):
    """Normal distribution with independent dimensions; the last axis is the event axis.

    Element-wise quantities come from an ``ExportableNormal`` held in
    ``base_dist``; ``log_prob`` and ``entropy`` sum them over the last axis.

    :param loc: Mean tensor.
    :param scale: Standard-deviation tensor, broadcast against ``loc``.
    :param validate_args: Forwarded to the base distribution constructors.
    """
    # NOTE(review): ``has_rsample`` is not set on this class even though
    # ``rsample`` is implemented, so ``sample_with_logprob`` follows whatever
    # value is inherited -- confirm that this is intended.

    def __init__(self, loc, scale, validate_args=False):
        self.base_dist = ExportableNormal(loc, scale, validate_args)
        # Derive the shapes from the *broadcast* parameters: ``loc`` alone may
        # be smaller than the common shape when ``scale`` carries extra or
        # larger dimensions.
        full_shape = self.base_dist.loc.shape
        batch_shape = torch.Size(full_shape[:-1])
        event_shape = torch.Size([full_shape[-1]])
        super(DiagnalNormal, self).__init__(batch_shape, event_shape, validate_args)

    def sample(self, sample_shape=torch.Size()):
        """Draw a sample by delegating to ``base_dist.sample``."""
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a sample by delegating to ``base_dist.rsample``."""
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, sample):
        """Return the element-wise log-probabilities summed over the last axis."""
        log_prob = self.base_dist.log_prob(sample)
        return torch.sum(log_prob, dim=-1)

    def entropy(self):
        """Return the element-wise entropies summed over the last axis."""
        entropy = self.base_dist.entropy()
        return torch.sum(entropy, dim=-1)

    def shift(self, mu_shift):
        '''shift the distribution, useful in local mode transition'''
        return DiagnalNormal(self.base_dist.loc + mu_shift, self.base_dist.scale)

    @property
    def mode(self):
        """Mean of the underlying normal distribution."""
        return self.base_dist.mean

    @property
    def std(self):
        """Scale of the underlying normal distribution."""
        return self.base_dist.scale
[docs]
class DiscreteLogistic(ReviveDistribution):
    r"""
    Model discrete variable with Logistic distribution, inspired from:
    https://github.com/openai/vdvae/blob/main/vae_helpers.py

    As far as I know, the trick was proposed in:
    Salimans, Tim, Andrej Karpathy, Xi Chen, and Diederik P. Kingma
    "Pixelcnn++: Improving the pixelcnn with discretized logistic mixture likelihood and other modifications." arXiv preprint arXiv:1701.05517 (2017).

    :param loc: Location parameter, assumed to have been normalized to [-1, 1].
    :param scale: Scale parameter.
    :param num: Number of possible values for each dimension.
    :param validate_args: Keyword-only; forwarded to the base distribution constructor.
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, num, *, validate_args=False):
        self.loc, self.scale = exportable_broadcast(loc, scale)
        # ``.to(loc)`` gives ``num`` the dtype and device of ``loc``.
        self.num = torch.tensor(num).to(loc)
        # Use the *broadcast* location for the shapes: ``loc`` alone may be
        # smaller than the common shape when ``scale`` carries extra or larger
        # dimensions.
        full_shape = self.loc.shape
        batch_shape = torch.Size(full_shape[:-1])
        event_shape = torch.Size([full_shape[-1]])
        super(DiscreteLogistic, self).__init__(batch_shape, event_shape, validate_args)

    def log_prob(self, value):
        """Return the log-probability of ``value`` summed over the last axis.

        The grid produced by ``round`` has spacing ``2 / (num - 1)``, so each
        value owns the interval ``value +/- 1 / (num - 1)``; its mass is the
        difference of the logistic CDF at the two interval edges.
        """
        mid = value - self.loc
        plus = (mid + 1 / (self.num - 1)) / self.scale
        minus = (mid - 1 / (self.num - 1)) / self.scale
        prob = torch.sigmoid(plus) - torch.sigmoid(minus)

        # log(sigmoid(plus)): all mass at or below the upper edge of the bin.
        log_prob_left_edge = plus - F.softplus(plus)
        # log(1 - sigmoid(minus)): all mass at or above the lower edge of the bin.
        log_prob_right_edge = - F.softplus(minus)

        # Log-density of the logistic distribution at ``value``, used as a
        # fallback when the bin mass is too small to take the log of directly.
        # NOTE(review): no bin-width term is added here -- confirm against the
        # referenced implementation.
        z = mid / self.scale
        log_prob_extreme = z - torch.log(self.scale) - 2 * F.softplus(z)

        return torch.where(value < - 0.999,
                           log_prob_left_edge,
                           torch.where(value > 0.999,
                                       log_prob_right_edge,
                                       torch.where(prob > 1e-5,
                                                   torch.log(prob + 1e-5),
                                                   log_prob_extreme))).sum(dim=-1)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """Draw a sample: inverse-CDF of uniform noise, rounded to the grid and clamped to [-1, 1]."""
        shape = self._extended_shape(sample_shape)
        u = self.loc.new_empty(shape).uniform_()
        value = self.icdf(u)
        round_value = self.round(value)
        return torch.clamp(round_value, -1, 1)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a sample like ``sample`` but keep a gradient path to the parameters.

        The returned value equals the rounded, clamped sample; adding
        ``value - value.detach()`` contributes zero to the value while letting
        gradients flow through the un-rounded ``value``.
        """
        shape = self._extended_shape(sample_shape)
        u = self.loc.new_empty(shape).uniform_()
        value = self.icdf(u)
        round_value = self.round(value)
        return torch.clamp(round_value, -1, 1) + value - value.detach()

    def cdf(self, value):
        """Logistic CDF: ``sigmoid((value - loc) / scale)``."""
        z = (value - self.loc) / self.scale
        return torch.sigmoid(z)

    def icdf(self, value):
        """Inverse of ``cdf``; ``torch.logit`` is called with ``eps=1e-5``."""
        return self.loc + self.scale * torch.logit(value, eps=1e-5)

    def round(self, value):
        """Snap ``value`` to the nearest of ``num`` evenly spaced points spanning [-1, 1]."""
        value = (value + 1) / 2 * (self.num - 1)
        return torch.round(value) / (self.num - 1) * 2 - 1

    @property
    def mode(self):
        """``loc`` rounded to the grid and clamped, with gradients passed straight through to ``loc``."""
        return torch.clamp(self.round(self.loc), -1, 1) + self.loc - self.loc.detach()

    @property
    def std(self):
        """``scale * pi / sqrt(3)``."""
        return self.scale * np.pi / 3 ** 0.5

    def entropy(self):
        """``log(scale) + 2`` summed over the last axis."""
        return torch.sum(torch.log(self.scale) + 2, dim=-1)
[docs]
class Onehot(OneHotCategorical, TorchDistributionMixin, ReviveDistributionMixin):
    """Differentiable Onehot Distribution

    :param logits: Unnormalised log-probabilities; the last axis indexes the classes.
    :param validate_args: Forwarded to the base distribution constructor. The
        internal categorical is always built with ``validate_args=False``.
    """

    has_rsample = True
    _validate_args = False

    def __init__(self, logits=None, validate_args=False):
        self._categorical = ExportableCategorical(logits=logits, validate_args=False)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        # ``super(OneHotCategorical, self)`` bypasses ``OneHotCategorical.__init__``
        # because ``_categorical`` has already been created above.
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a one-hot sample with a straight-through gradient to ``probs``."""
        # Implement straight-through estimator
        # Bengio et.al. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
        sample = self.sample(sample_shape)
        return sample + self.probs - self.probs.detach()

    @property
    def mode(self):
        """One-hot vector of the arg-max class, with a straight-through gradient to ``probs``."""
        index = torch.argmax(self.logits, dim=-1)
        num_classes = self.event_shape[0]
        if torch.is_tensor(num_classes):
            # ``F.one_hot`` is given a plain Python number.
            num_classes = num_classes.item()
        sample = F.one_hot(index, num_classes)
        return sample + self.probs - self.probs.detach()

    @property
    def std(self):
        """Standard deviation: the square root of the inherited ``variance``."""
        # Previously this returned ``self.variance`` itself, which contradicts
        # the ``std`` contract of ``ReviveDistributionMixin``.
        return self.variance.sqrt()
[docs]
class GaussianMixture(pyro.distributions.MixtureOfDiagNormals, ReviveDistributionMixin):
    """Mixture of diagonal normals exposing the REVIVE distribution API.

    :param locs: Component means, ``K x D`` (or ``B x K x D`` when batched).
    :param coord_scale: Component scales, same shape as ``locs``.
    :param component_logits: Mixture logits, ``K`` (or ``B x K`` when batched).
    """

    def __init__(self, locs, coord_scale, component_logits):
        self.batch_mode = (locs.dim() > 2)
        assert(coord_scale.shape == locs.shape)
        assert(self.batch_mode or locs.dim() == 2), \
            "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or B x K x D if doing batches)"
        if self.batch_mode:
            assert(coord_scale.dim() > 2), \
                "The coord_scale parameter in MixtureOfDiagNormals should be B x K x D dimensional"
            assert(component_logits.dim() > 1), \
                "The component_logits parameter in MixtureOfDiagNormals should be B x K dimensional"
            assert(component_logits.size(-1) == locs.size(-2))
            # Everything in front of the (K, D) axes is batch.
            batch_shape = tuple(locs.shape[:-2])
        else:
            assert(coord_scale.dim() == 2), \
                "The coord_scale parameter in MixtureOfDiagNormals should be K x D dimensional"
            assert(component_logits.dim() == 1), \
                "The component_logits parameter in MixtureOfDiagNormals should be K dimensional"
            assert(component_logits.size(-1) == locs.size(-2))
            batch_shape = ()

        self.locs = locs
        self.coord_scale = coord_scale
        self.component_logits = component_logits
        self.dim = locs.size(-1)
        self.categorical = ExportableCategorical(logits=component_logits)
        self.probs = self.categorical.probs
        ReviveDistribution.__init__(
            self,
            batch_shape=torch.Size(batch_shape),
            event_shape=torch.Size((self.dim,)),
        )

    @property
    def mode(self):
        """Mean of the component with the largest logit."""
        # NOTE: this is only an approximate mode
        _, which = self.categorical.logits.max(dim=-1)
        # Add the (K, D) axes back so the index can be used with ``gather``.
        which = which[..., None, None]
        target_shape = tuple(which.shape[:-1] + (self.locs.shape[-1],))
        selector = which.expand(target_shape)
        picked = torch.gather(self.locs, -2, selector)
        return picked.squeeze(-2)

    @property
    def std(self):
        """Component scales averaged with the mixture probabilities."""
        weights = self.categorical.probs.unsqueeze(-1)
        return (self.coord_scale * weights).sum(dim=-2)

    def shift(self, mu_shift):
        '''shift the distribution, useful in local mode transition'''
        moved_locs = self.locs + mu_shift.unsqueeze(dim=-2)
        return GaussianMixture(moved_locs, self.coord_scale, self.component_logits)

    def entropy(self):
        """Per-component ``DiagnalNormal`` entropies averaged with the mixture probabilities."""
        weights = self.categorical.probs
        component_entropy = DiagnalNormal(self.locs, self.coord_scale).entropy()
        return (weights * component_entropy).sum(dim=-1)
[docs]
class MixDistribution(ReviveDistribution):
    """Collection of multiple distributions

    The member distributions are laid out side by side along the last axis:
    samples, modes and stds are concatenated, log-probabilities and entropies
    are summed.

    :param dists: Sequence of distributions sharing one batch shape, each with
        a one-dimensional event shape.
    """
    # NOTE(review): ``has_rsample`` is not set on this class even though
    # ``rsample`` is implemented, so ``sample_with_logprob`` follows whatever
    # value is inherited -- confirm that this is intended.

    arg_constraints = {}

    def __init__(self, dists):
        assert all_equal([dist.batch_shape for dist in dists]), "the batch shape of all distributions should be equal"
        # ``all`` rather than ``all_equal``: the latter would also accept the
        # case where *every* distribution has a wrong event rank.
        assert all(len(dist.event_shape) == 1 for dist in dists), "the event shape of all distributions should have length 1"
        self.dists = dists
        # Width of each member along the last axis; used to split inputs.
        self.sizes = [dist.event_shape[0] for dist in self.dists]
        batch_shape = self.dists[0].batch_shape
        event_shape = torch.Size((sum(self.sizes),))
        super(MixDistribution, self).__init__(batch_shape, event_shape)

    def sample(self, num=torch.Size()):
        """Sample every member and concatenate the results along the last axis."""
        samples = [dist.sample(num) for dist in self.dists]
        return torch.cat(samples, dim=-1)

    def rsample(self, num=torch.Size()):
        """Like ``sample`` but using each member's ``rsample``."""
        samples = [dist.rsample(num) for dist in self.dists]
        return torch.cat(samples, dim=-1)

    def entropy(self):
        """Sum of the members' entropies."""
        return sum([dist.entropy() for dist in self.dists])

    def log_prob(self, x):
        """Return the log-probability of ``x``.

        :param x: Either a tensor whose last axis is split according to
            ``self.sizes``, or a list holding one value per member.
        :return: For a tensor, the sum of the members' log-probabilities; for
            a list, the list of per-member log-probabilities (not summed).
        """
        if isinstance(x, list):
            return [self.dists[i].log_prob(part) for i, part in enumerate(x)]
        # manually split the tensor
        parts = torch.split(x, self.sizes, dim=-1)
        return sum([dist.log_prob(part) for dist, part in zip(self.dists, parts)])

    @property
    def mode(self):
        """Members' modes concatenated along the last axis."""
        modes = [dist.mode for dist in self.dists]
        return torch.cat(modes, dim=-1)

    @property
    def std(self):
        """Members' stds concatenated along the last axis."""
        stds = [dist.std for dist in self.dists]
        return torch.cat(stds, dim=-1)

    @property
    def dists_type(self):
        """List with the type of every member distribution."""
        return [type(dist) for dist in self.dists]

    def shift(self, mu_shift):
        '''shift the distribution, useful in local mode transition'''
        assert all([type(dist) in [DiagnalNormal, GaussianMixture] for dist in self.dists]), \
            "all the distributions should have method `shift`"
        # NOTE(review): the same ``mu_shift`` is handed to every member without
        # being split by ``self.sizes`` -- confirm the expected shape of
        # ``mu_shift`` when there is more than one member.
        return MixDistribution([dist.shift(mu_shift) for dist in self.dists])
@register_kl(DiagnalNormal, DiagnalNormal)
def _kl_diagnalnormal_diagnalnormal(p : DiagnalNormal, q : DiagnalNormal):
    """KL(p || q): the KL of the two ``base_dist`` objects summed over the last axis."""
    per_dim_kl = kl_divergence(p.base_dist, q.base_dist)
    return torch.sum(per_dim_kl, dim=-1)
@register_kl(Onehot, Onehot)
def _kl_onehot_onehot(p : Onehot, q : Onehot):
    """KL(p || q) between two one-hot distributions, summed over the class axis."""
    # Probabilities are clamped from below before taking the log.
    floor = 1e-30
    p_probs = p.probs.clamp_min(floor)
    q_probs = q.probs.clamp_min(floor)
    log_ratio = torch.log(p_probs) - torch.log(q_probs)
    return (p_probs * log_ratio).sum(dim=-1)
@register_kl(GaussianMixture, GaussianMixture)
def _kl_gmm_gmm(p : GaussianMixture, q : GaussianMixture):
    """Estimate KL(p || q) from one ``rsample`` of ``p`` as ``log p(x) - log q(x)``."""
    draw = p.rsample()
    log_p = p.log_prob(draw)
    log_q = q.log_prob(draw)
    return log_p - log_q
@register_kl(MixDistribution, MixDistribution)
def _kl_mix_mix(p : MixDistribution, q : MixDistribution):
    """KL(p || q) as the sum of the KL divergences of the paired members.

    :raises AssertionError: if the two collections differ in length or if any
        pair of members differs in type.
    """
    # ``zip`` stops at the shorter sequence, so a length mismatch would
    # otherwise silently drop the unmatched members from the result.
    assert len(p.dists) == len(q.dists), "both distributions should have the same number of components"
    assert all([type(_p) == type(_q) for _p, _q in zip(p.dists, q.dists)])
    kl = 0
    for _p, _q in zip(p.dists, q.dists):
        kl = kl + kl_divergence(_p, _q)
    return kl
@register_kl(DiscreteLogistic, DiscreteLogistic)
def _kl_discrete_logistic_discrete_logistic(p : DiscreteLogistic, q : DiscreteLogistic):
    """Estimate KL(p || q) by averaging ``log p(x) - log q(x)`` over samples of ``p``."""
    assert torch.all(p.num == q.num)
    # NOTE: Cannot compute the kl divergence in analytical form, use 100 samples to estimate.
    draws = p.sample((100,))
    log_ratio = p.log_prob(draws) - q.log_prob(draws)
    # Average over the leading sample axis.
    return log_ratio.mean(dim=0)
if __name__ == '__main__':
    # Smoke test: build one instance of every distribution, print its basic
    # properties, then print the KL divergence to a second random instance.

    def _report(label, dist):
        """Print the shapes, samples and summary statistics of ``dist``."""
        print(f'{label} batch shape', dist.batch_shape)
        print(f'{label} event shape', dist.event_shape)
        print(f'{label} sample', dist.sample())
        print(f'{label} rsample', dist.rsample())
        print(f'{label} log prob', dist.sample_with_logprob()[1])
        print(f'{label} mode', dist.mode)
        print(f'{label} std', dist.std)
        print(f'{label} entropy', dist.entropy())

    def _random_gmm():
        """Random batched mixture: 2 batches, 6 components, 4 dimensions."""
        return GaussianMixture(
            torch.rand(2, 6, 4, requires_grad=True),
            torch.rand(2, 6, 4, requires_grad=True),
            torch.rand(2, 6, requires_grad=True),
        )

    def _random_normal():
        """Random diagonal normal with shape (2, 5)."""
        return DiagnalNormal(
            torch.rand(2, 5, requires_grad=True),
            torch.rand(2, 5, requires_grad=True)
        )

    def _random_discrete_logistic():
        """Random discrete logistic with shape (2, 5) and ``loc`` in [-1, 1)."""
        return DiscreteLogistic(
            torch.rand(2, 5, requires_grad=True) * 2 - 1,
            torch.rand(2, 5, requires_grad=True),
            [5, 9, 17, 33, 65],
        )

    separator = '-' * 50

    print(separator)
    onehot = Onehot(torch.rand(2, 10, requires_grad=True))
    _report('onehot', onehot)
    _onehot = Onehot(torch.rand(2, 10, requires_grad=True))
    print('onehot kl', kl_divergence(onehot, _onehot))

    print(separator)
    mixture = _random_gmm()
    _report('gmm', mixture)
    _mixture = _random_gmm()
    print('gmm kl', kl_divergence(mixture, _mixture))

    print(separator)
    normal = _random_normal()
    _report('normal', normal)
    _normal = _random_normal()
    print('normal kl', kl_divergence(normal, _normal))

    print(separator)
    discrete_logic = _random_discrete_logistic()
    _report('discrete logistic', discrete_logic)
    _discrete_logic = _random_discrete_logistic()
    print('discrete logistic kl', kl_divergence(discrete_logic, _discrete_logic))

    print(separator)
    mix = MixDistribution([onehot, mixture, normal, discrete_logic])
    _report('mix', mix)
    _mix = MixDistribution([_onehot, _mixture, _normal, _discrete_logic])
    print('mix kl', kl_divergence(mix, _mix))