# Source code for pyro.distributions.lkj

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import warnings

import torch

from pyro.distributions.torch import LKJCholesky, TransformedDistribution
from pyro.distributions.transforms.cholesky import CorrMatrixCholeskyTransform

from . import constraints

class LKJCorrCholesky(LKJCholesky):  # DEPRECATED
    """Deprecated alias of :class:`LKJCholesky` that accepts the old ``(d, eta)`` names.

    Constructing it emits a :class:`FutureWarning` and then builds a plain
    :class:`LKJCholesky`.

    :param int d: dimension of the matrices; passed as the first positional
        argument (``dim``) of :class:`LKJCholesky`.
    :param eta: passed to :class:`LKJCholesky` as ``concentration``.
    :param validate_args: passed through unchanged.
    """

    def __init__(self, d, eta, validate_args=None):
        # stacklevel=2 attributes the warning to the code that constructs this
        # class instead of to this module, so users can find the call to migrate.
        warnings.warn(
            "class LKJCorrCholesky(d, eta, validate_args=None) is deprecated "
            "in favor of LKJCholesky(dim, concentration, validate_args=None).",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(d, concentration=eta, validate_args=validate_args)

[docs]class LKJ(TransformedDistribution):
r"""
LKJ distribution for correlation matrices. The distribution is controlled by the concentration
parameter :math:`\eta`, which makes the probability of the correlation matrix :math:`M` proportional
to :math:`\det(M)^{\eta - 1}`. Because of that, when ``concentration == 1``, we have a
uniform distribution over correlation matrices.

When ``concentration > 1``, the distribution favors samples with a large determinant. This
is useful when we know a priori that the underlying variables are not correlated.
When ``concentration < 1``, the distribution favors samples with a small determinant. This is
useful when we know a priori that some underlying variables are correlated.

:param int dim: dimension of the matrices
:param concentration: concentration/shape parameter of the
distribution (often referred to as eta); a float or a ``torch.Tensor``

**References**

Generating random correlation matrices based on vines and extended onion method,
Daniel Lewandowski, Dorota Kurowicka, Harry Joe
"""
arg_constraints = {"concentration": constraints.positive}
support = constraints.corr_matrix

def __init__(self, dim, concentration=1.0, validate_args=None):
    """Build the distribution from an :class:`LKJCholesky` base.

    The base distribution is pushed through the inverse of
    :class:`CorrMatrixCholeskyTransform`.

    :param int dim: dimension of the matrices.
    :param concentration: concentration (eta) parameter, default ``1.0``.
    :param validate_args: forwarded to the parent ``TransformedDistribution``
        constructor (it is not passed to the base ``LKJCholesky``).
    """
    cholesky_dist = LKJCholesky(dim, concentration)
    # Re-expose the base distribution's parameters on this instance, read back
    # from the base rather than the raw arguments (the base may normalize them).
    # They are assigned before the parent constructor runs, presumably so that
    # argument validation against ``arg_constraints`` can see them.
    self.dim = cholesky_dist.dim
    self.concentration = cholesky_dist.concentration
    cholesky_to_corr = CorrMatrixCholeskyTransform().inv
    super(LKJ, self).__init__(
        cholesky_dist, cholesky_to_corr, validate_args=validate_args
    )

def expand(self, batch_shape, _instance=None):
    """Return a new instance of this distribution expanded to ``batch_shape``.

    :param batch_shape: the desired batch shape.
    :param _instance: optional pre-allocated instance, used by subclasses.
    :returns: the expanded distribution, carrying ``dim`` and ``concentration``.
    """
    # Check against LKJ itself. The previous code passed LKJCholesky, a class
    # LKJ does not inherit from, so the __init__ comparison in
    # _get_checked_instance could never match, and super(LKJCholesky, self)
    # raised TypeError because self is not an LKJCholesky instance.
    new = self._get_checked_instance(LKJ, _instance)
    new = super(LKJ, self).expand(batch_shape, _instance=new)
    # The parent expand does not know about the attributes that LKJ.__init__
    # sets directly, so copy them onto the expanded instance.
    new.dim = self.dim
    new.concentration = self.concentration.expand(new.batch_shape)
    return new

@property
def mean(self):