Comments (9)

gpleiss commented on June 12, 2024

@Balandat looking through our other RandomVariables -- it looks like most of them don't use LazyTensors or any additional functionality. So we should probably only write custom MVN and MultitaskMVN distributions.

Balandat commented on June 12, 2024

I'm working on some early draft code to include priors. At least for the basic purpose of automatically adding the log_prob term of each parameter's prior (if specified) to the marginal log likelihood, the API implemented by the basic Distribution class seems adequate.
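
For concreteness, a minimal sketch of that mechanism; ScaledKernel and the "_prior" attribute-naming convention here are made up for illustration and aren't settled API:

import torch
from torch.distributions import Normal


class ScaledKernel(torch.nn.Module):
    # hypothetical module with a prior attached to one of its parameters
    def __init__(self):
        super().__init__()
        self.log_lengthscale = torch.nn.Parameter(torch.zeros(1))
        self.log_lengthscale_prior = Normal(0.0, 1.0)  # any Distribution works


def mll_with_priors(mll, module):
    # add the log_prob of each parameter's prior (if specified) to the MLL
    for name, param in module.named_parameters():
        prior = getattr(module, name + "_prior", None)
        if prior is not None:
            mll = mll + prior.log_prob(param).sum()
    return mll


mll = torch.tensor(-1.23)  # stand-in for a computed marginal log likelihood
print(mll_with_priors(mll, ScaledKernel()))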

gpleiss commented on June 12, 2024

Yeah, I think we'll mostly be able to use stock PyTorch distributions. I think the only custom distribution we need is a MultivariateNormal (no surprises here...)

Balandat commented on June 12, 2024

pytorch has a MultivariateNormal distribution: https://pytorch.org/docs/stable/distributions.html?highlight=multivariate#torch.distributions.multivariate_normal.MultivariateNormal

gpleiss commented on June 12, 2024

Our MultivariateNormal needs some custom functionality. Mostly, it needs to be able to handle LazyVariable covariance matrices without explicitly constructing the covariance matrix.
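
To make that concrete, here is a toy stand-in (not gpytorch's actual LazyVariable/LazyTensor) showing why the distribution has to be written against operations like matmul rather than against a dense covariance attribute:

import torch


class LazyOuterProduct:
    # toy stand-in for a lazy covariance: represents root @ root.T
    # without ever building the full n x n matrix
    def __init__(self, root):
        self.root = root  # n x k, with k much smaller than n

    def matmul(self, rhs):
        # (R @ R.T) @ rhs computed as R @ (R.T @ rhs): O(nk) instead of O(n^2)
        return self.root @ (self.root.t() @ rhs)


covar = LazyOuterProduct(torch.randn(10000, 5))
out = covar.matmul(torch.randn(10000, 1))  # never forms the 10000 x 10000 matrix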

Balandat commented on June 12, 2024

So I was toying around with this a little, and I'm not sure how to best handle this.

Regarding the MVN:
When no LazyTensors are involved, we can basically use the torch.distributions implementation (resulting in some changes to the API if we don't want to wrap the torch.distributions one).
If LazyTensors are passed as args, we have to re-implement some things, but we'd want to re-use a lot of the basic functionality of torch.distributions.

Not sure what the implications of this are for autograd etc. -- we'll have to make sure that setting things up based on torch.distributions doesn't create issues in this regard.
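
For what it's worth, a quick check suggests gradients do flow through the stock implementation:

import torch
from torch.distributions import MultivariateNormal

mean = torch.zeros(2, requires_grad=True)
mvn = MultivariateNormal(mean, torch.eye(2))
mvn.log_prob(torch.ones(2)).backward()
print(mean.grad)  # tensor([1., 1.]) -- gradients propagate through log_prob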

What do you think about doing something like the following (we can also have a class for Distributions specifying the generic API if we want):

import torch
from torch.distributions import Distribution
from torch.distributions import MultivariateNormal as TMultivariateNormal
from torch.distributions.multivariate_normal import _batch_mv  # referenced in the rsample sketch below

from ..lazy import LazyTensor


class MultivariateNormal(TMultivariateNormal):
    """
    Constructs a multivariate normal random variable, based on mean and covariance.
    Can be multivariate, or a batch of multivariate normals.

    Passing a vector mean corresponds to a multivariate normal.
    Passing a matrix mean corresponds to a batch of multivariate normals.

    Args:
        mean (Tensor): n mean vector, or b x n batch of mean vectors, of the
            Gaussian distribution
        covariance_matrix (Tensor or LazyTensor): n x n covariance matrix, or
            b x n x n batch of covariance matrices, of the Gaussian distribution
    """

    def __new__(cls, mean, covariance_matrix, validate_args=False):
        # __init__ can't return a different object, so dispatch to the lazy
        # subclass has to happen here in __new__
        if cls is MultivariateNormal and (
            isinstance(mean, LazyTensor) or isinstance(covariance_matrix, LazyTensor)
        ):
            return super().__new__(_LazyMultivariateNormal)
        return super().__new__(cls)

    def __init__(self, mean, covariance_matrix, validate_args=False):
        super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)

    def __add__(self, other):
        if isinstance(other, MultivariateNormal):
            return MultivariateNormal(
                mean=self.mean + other.mean,
                covariance_matrix=self.covariance_matrix + other.covariance_matrix,
            )
        elif isinstance(other, (int, float)):
            return MultivariateNormal(self.mean + other, self.covariance_matrix)
        else:
            raise RuntimeError("Unsupported type for addition w/ MultivariateNormal")

    def __radd__(self, other):
        if other == 0:
            return self
        return self.__add__(other)

    def __truediv__(self, other):
        return self.__mul__(1.0 / other)

    def __mul__(self, other):
        if not isinstance(other, (int, float)):
            raise RuntimeError("Can only multiply by scalars")
        return self.__class__(
            mean=self.mean * other,
            covariance_matrix=self.covariance_matrix * (other ** 2),
        )


class _LazyMultivariateNormal(MultivariateNormal):
    def __init__(self, mean, covariance_matrix, validate_args=False):
        # add validation here...
        self.loc = mean
        self._covar = covariance_matrix
        # bypass TMultivariateNormal.__init__ (it would densify the covariance)
        # and initialize the Distribution base class directly
        Distribution.__init__(
            self, batch_shape=mean.shape[:-1], event_shape=mean.shape[-1:],
            validate_args=validate_args,
        )

    @property
    def scale_tril(self):
        raise NotImplementedError()

    @property
    def covariance_matrix(self):
        return self._covar

    @property
    def precision_matrix(self):
        raise NotImplementedError()

    def rsample(self, sample_shape=torch.Size()):
        # placeholder -- rewrite using LazyTensor code; something like:
        #   shape = self._extended_shape(sample_shape)
        #   eps = self.loc.new_empty(shape).normal_()
        #   return self.loc + _batch_mv(self._covar.root_decomposition(), eps)
        raise NotImplementedError("rewrite rsample using LazyTensor code")

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # placeholder -- re-write the dense reference implementation for LazyTensors:
        #   diff = value - self.loc
        #   M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
        #   half_log_det = _batch_diag(self._unbroadcasted_scale_tril).log().sum(-1)
        #   return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
        raise NotImplementedError("re-write log_prob using LazyTensor code")

    def entropy(self):
        # placeholder -- re-write the dense reference implementation for LazyTensors:
        #   half_log_det = _batch_diag(self._unbroadcasted_scale_tril).log().sum(-1)
        #   H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
        #   return H if len(self._batch_shape) == 0 else H.expand(self._batch_shape)
        raise NotImplementedError("re-write entropy using LazyTensor code")
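
Assuming the draft above, the dense path should behave just like the stock distribution:

mean = torch.zeros(3)
mvn = MultivariateNormal(mean, torch.eye(3))  # no LazyTensors: stock behavior
print(mvn.log_prob(torch.zeros(3)))
# a LazyTensor mean or covariance would instead dispatch to
# _LazyMultivariateNormal, whose placeholder methods raise for now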

gpleiss commented on June 12, 2024

@Balandat that looks great! In theory I imagine it should all work right away with autograd (but there are probably some bugs we'll find along the way).

I'm imagining we probably won't have to do anything special for the other gpytorch distributions, but it probably wouldn't hurt to create a thin gpytorch wrapper around them (just in case we want to add any gpytorch-specific functionality)? Or should we only have a subclassed MVN and then just use pytorch distributions for everything else?
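
Purely as an illustration of the thin-wrapper option (Bernoulli chosen arbitrarily), it could be as lightweight as re-exported subclasses:

from torch.distributions import Bernoulli as TBernoulli


class Bernoulli(TBernoulli):
    # no behavior changes yet; just a gpytorch-level hook for adding
    # functionality later without touching call sites
    pass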

Balandat commented on June 12, 2024

@gpleiss I haven't worked much with the other random variables -- are they supposed to work with LazyTensors? If so, we'll need to wrap everything in a similar fashion, in which case a gpytorch Distribution class would make sense.

Using the distributions more or less directly will simplify things / reduce code bloat significantly, but it does entail a bunch of breaking changes. Still, better to get this in (together with the LazyTensor rename) before the first beta release.

gpleiss commented on June 12, 2024

Closed by #288
