"""
    POLIXIR REVIVE, copyright (C) 2021-2025 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.
    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
import numpy as np
from typing import Tuple
import torch
from torch.nn import functional as F
from torch.distributions import constraints
from torch.distributions import Normal, Categorical, OneHotCategorical
import pyro
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions.kl import register_kl, kl_divergence
from itertools import groupby
def all_equal(iterable):
    '''Check whether all elements in an iterable are equal'''
    g = groupby(iterable)
    return next(g, True) and not next(g, False) 
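# Illustrative behaviour (hypothetical values): all_equal([(2,), (2,)]) -> True,
# all_equal([(2,), (3,)]) -> False, and all_equal([]) -> True (vacuously).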
def exportable_broadcast(tensor1 : torch.Tensor, tensor2 : torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    ''' Broadcast tensors to the same shape using onnx exportable operators '''
    if len(tensor1.shape) < len(tensor2.shape):
        tensor2, tensor1 = exportable_broadcast(tensor2, tensor1)
    else:
        shape1 = tensor1.shape
        shape2 = tensor2.shape
        if len(shape1) == len(shape2):
            final_shape = [max(s1, s2) for s1, s2 in zip(shape1, shape2)]
            tensor1 = tensor1.expand(*final_shape)
            tensor2 = tensor2.expand(*final_shape)
        else:
            tensor2 = tensor2.expand(*shape1)
    return tensor1, tensor2 
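# Illustrative shapes (hypothetical): a (3, 1) loc and a (3, 4) scale both come back as
# (3, 4); a 1-D (4,) tensor paired with a (3, 4) tensor is expanded to (3, 4). Using
# explicit expand() keeps the traced graph ONNX-exportable.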
class ReviveDistributionMixin:
    '''Define revive distribution API'''
    @property
    def mode(self,):
        '''return the most likely sample of the distributions'''
        raise NotImplementedError 
    @property
    def std(self):
        '''return the standard deviation of the distributions'''
        raise NotImplementedError
    def sample_with_logprob(self, sample_shape=torch.Size()):
        sample = self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape)
        return sample, self.log_prob(sample) 
 
class ReviveDistribution(pyro.distributions.TorchDistribution, ReviveDistributionMixin):
    pass 
class ExportableNormal(Normal):
    def __init__(self, loc, scale, validate_args):
        """ Use exportable_broadcast() aligns the shape of the tensor """
        self.loc, self.scale = exportable_broadcast(loc, scale)
        batch_shape = self.loc.size()
        super(Normal, self).__init__(batch_shape, validate_args=validate_args) 
class ExportableCategorical(Categorical):
    def log_prob(self, value):
        """ Use exportable_broadcast() aligns the shape of the tensor """
        if self._validate_args:
            self._validate_sample(value)
        value = value.long().unsqueeze(-1)
        value, log_pmf = exportable_broadcast(value, self.logits)
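        # Broadcasting expands `value` to the full logits shape; keep a single index
        # per event before gathering the log-pmf.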
        value = value[..., :1]
        return log_pmf.gather(-1, value).squeeze(-1) 
 
class DiagnalNormal(ReviveDistribution):
    def __init__(self, loc, scale, validate_args=False):
        self.base_dist = ExportableNormal(loc, scale, validate_args)
        batch_shape = torch.Size(loc.shape[:-1])
        event_shape = torch.Size([loc.shape[-1]])
        super(DiagnalNormal, self).__init__(batch_shape, event_shape, validate_args)
    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape) 
    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape) 
    def log_prob(self, sample):
        log_prob = self.base_dist.log_prob(sample)
        return torch.sum(log_prob, dim=-1) 
    def entropy(self):
        entropy = self.base_dist.entropy()
        return torch.sum(entropy, dim=-1) 
    def shift(self, mu_shift):
        '''shift the distribution, useful in local mode transition'''
        return DiagnalNormal(self.base_dist.loc + mu_shift, self.base_dist.scale) 
    @property
    def mode(self):
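        # For a Gaussian, the mode coincides with the mean.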
        return self.base_dist.mean
    @property
    def std(self):
        return self.base_dist.scale 
class DiscreteLogistic(ReviveDistribution):
    r"""
    Model discrete variable with Logistic distribution, inspired from:
    https://github.com/openai/vdvae/blob/main/vae_helpers.py
    As far as I know, the trick was proposed in:
    Salimans, Tim, Andrej Karpathy, Xi Chen, and Diederik P. Kingma
    "Pixelcnn++: Improving the pixelcnn with discretized logistic mixture likelihood and other modifications." arXiv preprint arXiv:1701.05517 (2017).
    :param loc: Location parameter, assert it have been normalized to [-1, 1]
    :param scale: Scale parameter.
    :param num: Number of possible value for each dimension.
    """
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True
    def __init__(self, loc, scale, num, *, validate_args=False):
        self.loc, self.scale = exportable_broadcast(loc, scale)
        self.num = torch.tensor(num).to(loc)
        batch_shape = torch.Size(loc.shape[:-1])
        event_shape = torch.Size([loc.shape[-1]])
        super(DiscreteLogistic, self).__init__(batch_shape, event_shape, validate_args)
    def log_prob(self, value):
        mid = value - self.loc
        plus = (mid + 1 / (self.num - 1)) / self.scale
        minus = (mid - 1 / (self.num - 1)) / self.scale
        prob = torch.sigmoid(plus) - torch.sigmoid(minus)
        log_prob_left_edge = plus - F.softplus(plus)
        log_prob_right_edge = - F.softplus(minus)
        z = mid / self.scale
        log_prob_extreme = z - torch.log(self.scale) - 2 * F.softplus(z)
        return torch.where(value < - 0.999,
                           log_prob_left_edge,
                           torch.where(value > 0.999,
                                       log_prob_right_edge,
                                       torch.where(prob > 1e-5,
                                                   torch.log(prob + 1e-5),
                                                   log_prob_extreme))).sum(dim=-1) 
    
    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        u = self.loc.new_empty(shape).uniform_()
        value = self.icdf(u)
        round_value = self.round(value)
        return torch.clamp(round_value, -1, 1) 
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        u = self.loc.new_empty(shape).uniform_()
        value = self.icdf(u)
        round_value = self.round(value)
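        # Straight-through estimator: the forward pass returns the clamped, rounded
        # sample, while gradients flow through the continuous `value`.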
        return torch.clamp(round_value, -1, 1) + value - value.detach() 
    def cdf(self, value):
        z = (value - self.loc) / self.scale
        return torch.sigmoid(z) 
    def icdf(self, value):
        return self.loc + self.scale * torch.logit(value, eps=1e-5) 
    def round(self, value):
        value = (value + 1) / 2 * (self.num - 1)
        return torch.round(value) / (self.num - 1) * 2 - 1 
    @property
    def mode(self):
        return torch.clamp(self.round(self.loc), -1, 1) + self.loc - self.loc.detach()
    @property
    def std(self):
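        # Standard deviation of a logistic distribution: scale * pi / sqrt(3).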
        return self.scale * np.pi / 3 ** 0.5
    def entropy(self):
        return torch.sum(torch.log(self.scale) + 2, dim=-1) 
 
class Onehot(OneHotCategorical, TorchDistributionMixin, ReviveDistributionMixin):
    """Differentiable Onehot Distribution"""
    has_rsample = True
    _validate_args = False
    def __init__(self, logits=None, validate_args=False):
        self._categorical = ExportableCategorical(logits=logits, validate_args=False)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)
    def rsample(self, sample_shape=torch.Size()):
        # Implement the straight-through estimator:
        # Bengio et al., "Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation"
        sample = self.sample(sample_shape)
        return sample + self.probs - self.probs.detach() 
    
    @property
    def mode(self):
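        # One-hot encoding of the argmax class, with straight-through gradients via probs.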
        index = torch.argmax(self.logits, dim=-1)
        num_classes = self.event_shape[0]
        if torch.is_tensor(num_classes):
            num_classes = num_classes.item()
        sample = F.one_hot(index, num_classes)
        return sample + self.probs - self.probs.detach()
    @property
    def std(self):
        return self.variance 
class GaussianMixture(pyro.distributions.MixtureOfDiagNormals, ReviveDistributionMixin):
    def __init__(self, locs, coord_scale, component_logits):
        self.batch_mode = (locs.dim() > 2)
        assert(coord_scale.shape == locs.shape)
        assert(self.batch_mode or locs.dim() == 2), \
            "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or B x K x D if doing batches)"
        if not self.batch_mode:
            assert(coord_scale.dim() == 2), \
                "The coord_scale parameter in MixtureOfDiagNormals should be K x D dimensional"
            assert(component_logits.dim() == 1), \
                "The component_logits parameter in MixtureOfDiagNormals should be K dimensional"
            assert(component_logits.size(-1) == locs.size(-2))
            batch_shape = ()
        else:
            assert(coord_scale.dim() > 2), \
                "The coord_scale parameter in MixtureOfDiagNormals should be B x K x D dimensional"
            assert(component_logits.dim() > 1), \
                "The component_logits parameter in MixtureOfDiagNormals should be B x K dimensional"
            assert(component_logits.size(-1) == locs.size(-2))
            batch_shape = tuple(locs.shape[:-2])
        self.locs = locs
        self.coord_scale = coord_scale
        self.component_logits = component_logits
        self.dim = locs.size(-1)
        self.categorical = ExportableCategorical(logits=component_logits)
        self.probs = self.categorical.probs
        ReviveDistribution.__init__(self, batch_shape=torch.Size(batch_shape), event_shape=torch.Size((self.dim,)))
    @property
    def mode(self):
        # NOTE: this is only an approximate mode
        which = self.categorical.logits.max(dim=-1)[1]
        which = which.unsqueeze(dim=-1).unsqueeze(dim=-1)
        which_expand = which.expand(tuple(which.shape[:-1] + (self.locs.shape[-1],)))
        loc = torch.gather(self.locs, -2, which_expand).squeeze(-2)
        return loc
    @property
    def std(self):
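        # NOTE: approximate the std by the probability-weighted average of the component scales.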
        p = self.categorical.probs
        return torch.sum(self.coord_scale * p.unsqueeze(-1), dim=-2)
    def shift(self, mu_shift):
        '''shift the distribution, useful in local mode transition'''
        return GaussianMixture(self.locs + mu_shift.unsqueeze(dim=-2), self.coord_scale, self.component_logits) 
    def entropy(self):
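        # NOTE: approximate the mixture entropy by the probability-weighted sum of the
        # component entropies; this lower bound ignores the mixing entropy.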
        p = self.categorical.probs
        normal = DiagnalNormal(self.locs, self.coord_scale)
        entropy = normal.entropy()
        return torch.sum(p * entropy, dim=-1) 
 
class MixDistribution(ReviveDistribution):
    """Collection of multiple distributions"""
    arg_constraints = {}
    
    def __init__(self, dists):
        super().__init__()
        assert all_equal([dist.batch_shape for dist in dists]), "the batch shape of all distributions should be equal"
        assert all_equal([len(dist.event_shape) == 1 for dist in dists]), "the event shape of all distributions should have length 1"
        self.dists = dists
        self.sizes = [dist.event_shape[0] for dist in self.dists]
        batch_shape = self.dists[0].batch_shape
        event_shape = torch.Size((sum(self.sizes),))
        super(MixDistribution, self).__init__(batch_shape, event_shape)
    def sample(self, num=torch.Size()):
        samples = [dist.sample(num) for dist in self.dists]
        return torch.cat(samples, dim=-1) 
    def rsample(self, num=torch.Size()): 
        samples = [dist.rsample(num) for dist in self.dists]
        return torch.cat(samples, dim=-1) 
    def entropy(self):
        return sum([dist.entropy() for dist in self.dists])   
    def log_prob(self, x):
        if type(x) == list:
            return [self.dists[i].log_prob(x[i]) for i in range(len(x))]
        # manually split the tensor
        x = torch.split(x, self.sizes, dim=-1)
        return sum([self.dists[i].log_prob(x[i]) for i in range(len(x))]) 
    @property
    def mode(self):
        modes = [dist.mode for dist in self.dists]
        return torch.cat(modes, dim=-1)
    @property
    def std(self):
        stds = [dist.std for dist in self.dists]
        return torch.cat(stds, dim=-1)
    @property
    def dists_type(self):
        return [type(dist) for dist in self.dists]
    def shift(self, mu_shift):
        '''shift the distribution, useful in local mode transition'''
        assert all([type(dist) in [DiagnalNormal, GaussianMixture] for dist in self.dists]), \
            "all the distributions should have method `shift`"
        return MixDistribution([dist.shift(mu_shift) for dist in self.dists]) 
 
@register_kl(DiagnalNormal, DiagnalNormal)
def _kl_diagnalnormal_diagnalnormal(p : DiagnalNormal, q : DiagnalNormal):
    kl = kl_divergence(p.base_dist, q.base_dist)
    kl = torch.sum(kl, dim=-1)
    return kl
@register_kl(Onehot, Onehot)
def _kl_onehot_onehot(p : Onehot, q : Onehot):
    p_probs = p.probs.clamp_min(1e-30)
    q_probs = q.probs.clamp_min(1e-30)
    kl = (p_probs * (torch.log(p_probs) - torch.log(q_probs))).sum(dim=-1)
    return kl
@register_kl(GaussianMixture, GaussianMixture)
def _kl_gmm_gmm(p : GaussianMixture, q : GaussianMixture):
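    # NOTE: there is no closed form for the KL between Gaussian mixtures; use a
    # single-sample Monte Carlo estimate log p(x) - log q(x) with x ~ p.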
    samples = p.rsample()
    log_p = p.log_prob(samples)
    log_q = q.log_prob(samples)
    return log_p - log_q
@register_kl(MixDistribution, MixDistribution)
def _kl_mix_mix(p : MixDistribution, q : MixDistribution):
    assert all([type(_p) == type(_q) for _p, _q in zip(p.dists, q.dists)])
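    # The KL of a concatenation of independent components is the sum of the component-wise KLs.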
    kl = 0
    for _p, _q in zip(p.dists, q.dists):
        kl = kl + kl_divergence(_p, _q)
    return kl
@register_kl(DiscreteLogistic, DiscreteLogistic)
def _kl_discrete_logistic_discrete_logistic(p : DiscreteLogistic, q : DiscreteLogistic):
    assert torch.all(p.num == q.num)
    # NOTE: the KL divergence has no analytical form here, so estimate it with 100 samples.
    samples = p.sample((100,))
    p_log_prob = p.log_prob(samples)
    q_log_prob = q.log_prob(samples)
    return torch.mean(p_log_prob - q_log_prob, dim=0)
     
if __name__ == '__main__':
    print('-' * 50)
    onehot = Onehot(torch.rand(2, 10, requires_grad=True))
    print('onehot batch shape', onehot.batch_shape)
    print('onehot event shape', onehot.event_shape)
    print('onehot sample', onehot.sample())
    print('onehot rsample', onehot.rsample())
    print('onehot log prob', onehot.sample_with_logprob()[1])
    print('onehot mode', onehot.mode)
    print('onehot std', onehot.std)
    print('onehot entropy', onehot.entropy())
    _onehot = Onehot(torch.rand(2, 10, requires_grad=True))
    print('onehot kl', kl_divergence(onehot, _onehot))
    print('-' * 50)
    mixture = GaussianMixture(
        torch.rand(2, 6, 4, requires_grad=True), 
        torch.rand(2, 6, 4, requires_grad=True),
        torch.rand(2, 6, requires_grad=True), 
    )
    print('gmm batch shape', mixture.batch_shape)
    print('gmm event shape', mixture.event_shape)
    print('gmm sample', mixture.sample())
    print('gmm rsample', mixture.rsample())
    print('gmm log prob', mixture.sample_with_logprob()[1])
    print('gmm mode', mixture.mode)
    print('gmm std', mixture.std)
    print('gmm entropy', mixture.entropy())
    _mixture = GaussianMixture(
        torch.rand(2, 6, 4, requires_grad=True), 
        torch.rand(2, 6, 4, requires_grad=True),
        torch.rand(2, 6, requires_grad=True), 
    )
    print('gmm kl', kl_divergence(mixture, _mixture))
    print('-' * 50)
    normal = DiagnalNormal(
        torch.rand(2, 5, requires_grad=True), 
        torch.rand(2, 5, requires_grad=True)
    )
    print('normal batch shape', normal.batch_shape)
    print('normal event shape', normal.event_shape)
    print('normal sample', normal.sample())
    print('normal rsample', normal.rsample())
    print('normal log prob', normal.sample_with_logprob()[1])
    print('normal mode', normal.mode)
    print('normal std', normal.std)
    print('normal entropy', normal.entropy())
    _normal = DiagnalNormal(
        torch.rand(2, 5, requires_grad=True), 
        torch.rand(2, 5, requires_grad=True)
    )
    print('normal kl', kl_divergence(normal, _normal))
    print('-' * 50)
    discrete_logic = DiscreteLogistic(
        torch.rand(2, 5, requires_grad=True) * 2 - 1, 
        torch.rand(2, 5, requires_grad=True),
        [5, 9, 17, 33, 65],
    )
    print('discrete logistic batch shape', discrete_logic.batch_shape)
    print('discrete logistic event shape', discrete_logic.event_shape)
    print('discrete logistic sample', discrete_logic.sample())
    print('discrete logistic rsample', discrete_logic.rsample())
    print('discrete logistic log prob', discrete_logic.sample_with_logprob()[1])
    print('discrete logistic mode', discrete_logic.mode)
    print('discrete logistic std', discrete_logic.std)
    print('discrete logistic entropy', discrete_logic.entropy())
    _discrete_logic = DiscreteLogistic(
        torch.rand(2, 5, requires_grad=True) * 2 - 1, 
        torch.rand(2, 5, requires_grad=True),
        [5, 9, 17, 33, 65],
    )
    print('discrete logistic kl', kl_divergence(discrete_logic, _discrete_logic))
    print('-' * 50)
    mix = MixDistribution([onehot, mixture, normal, discrete_logic])
    print('mix batch shape', mix.batch_shape)
    print('mix event shape', mix.event_shape)
    print('mix sample', mix.sample())
    print('mix rsample', mix.rsample())
    print('mix log prob', mix.sample_with_logprob()[1])
    print('mix mode', mix.mode)
    print('mix std', mix.std)
    print('mix entropy', mix.entropy())
    _mix = MixDistribution([_onehot, _mixture, _normal, _discrete_logic])
    print('mix kl', kl_divergence(mix, _mix))