Comments (9)
@Balandat looking through our other RandomVariables - it looks like most of them don't use LazyTensors or any additional functionality. So we should probably only write custom MVN and MultitaskMVN distributions
I'm working on some early draft code to include priors. At least for the basic purpose of automatically adding the log_prob term of each parameter's prior (if specified) to the marginal log likelihood, the API implemented by the basic Distribution class seems adequate.
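As a rough sketch of what I have in mind (note that `named_priors()` is a hypothetical helper here, just for illustration -- it would yield a `(prior, parameter)` pair for every parameter that has a prior registered):

```python
def add_prior_terms(mll_value, model):
    # Hypothetical sketch: fold each registered prior's log density into the
    # marginal log likelihood. `model.named_priors()` is assumed to yield
    # (prior, parameter) pairs for parameters that have a prior specified.
    for prior, param in model.named_priors():
        mll_value = mll_value + prior.log_prob(param).sum()
    return mll_value
```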
Yeah, I think for the most part we'll be able to use PyTorch distributions as-is. The only custom distribution we need is a MultivariateNormal distribution (no surprises here...)
PyTorch has a MultivariateNormal distribution: https://pytorch.org/docs/stable/distributions.html?highlight=multivariate#torch.distributions.multivariate_normal.MultivariateNormal
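For reference, the basic API there looks like this:

```python
import torch
from torch.distributions import MultivariateNormal

mvn = MultivariateNormal(loc=torch.zeros(3), covariance_matrix=torch.eye(3))
x = mvn.rsample()     # reparameterized (differentiable) sample
lp = mvn.log_prob(x)  # log density at x
```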
Our MultivariateNormal needs some custom functionality. Mostly, it needs to be able to handle LazyTensor covariance matrices without explicitly constructing the covariance matrix.
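To make that concrete, here is a dense-tensor illustration of the kind of operation the lazy representation needs to support -- sampling through a root decomposition K = R Rᵀ rather than through K itself (the idea being that with a LazyTensor, R is never materialized as a full matrix):

```python
import torch

K = torch.tensor([[2.0, 0.5], [0.5, 1.0]])  # dense stand-in for a lazy covariance
mu = torch.zeros(2)
R = torch.linalg.cholesky(K)  # one possible root R with K = R @ R.T
eps = torch.randn(2)
x = mu + R @ eps  # sample from N(mu, K); only a matvec with R is required
```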
So I was toying around with this a little, and I'm not sure how to best handle this.
Regarding the MVN:
- When no LazyTensors are involved, we can basically use the torch.distributions implementation (resulting in some changes to the API if we don't want to wrap the torch.distributions one).
- If LazyTensors are passed as args, we have to re-implement some things, but we'd want to re-use a lot of the basic functionality of torch.distributions.
- I'm not sure what the implications of this are for autograd etc. -- we'll have to make sure that setting things up based on torch.distributions doesn't create issues in this regard.
What do you think about doing something like the following? (We could also have a base Distribution class specifying the generic API if we want.)
```python
import torch
from torch.distributions import Distribution, MultivariateNormal as TMultivariateNormal
from torch.distributions.multivariate_normal import _batch_mv

from ..lazy import LazyTensor


class MultivariateNormal(TMultivariateNormal):
    """
    Constructs a multivariate normal random variable, based on mean and covariance.
    Can be multivariate, or a batch of multivariate normals.

    Passing a vector mean corresponds to a multivariate normal.
    Passing a matrix mean corresponds to a batch of multivariate normals.

    Args:
        mean (Tensor): vector n or matrix b x n mean of the distribution
        covariance_matrix (Tensor): matrix n x n or batch matrix b x n x n
            covariance of the distribution
    """

    def __new__(cls, mean, covariance_matrix, validate_args=False):
        # __init__ cannot return a different object, so dispatch on argument
        # types in __new__: LazyTensor args get the lazy implementation.
        if isinstance(mean, LazyTensor) or isinstance(covariance_matrix, LazyTensor):
            return super().__new__(_LazyMultivariateNormal)
        return super().__new__(cls)

    def __init__(self, mean, covariance_matrix, validate_args=False):
        super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)

    def __add__(self, other):
        if isinstance(other, MultivariateNormal):
            return MultivariateNormal(
                mean=self.mean + other.mean,
                covariance_matrix=self.covariance_matrix + other.covariance_matrix,
            )
        elif isinstance(other, (int, float)):
            return MultivariateNormal(self.mean + other, self.covariance_matrix)
        else:
            raise RuntimeError("Unsupported type for addition w/ MultivariateNormal")

    def __radd__(self, other):
        if other == 0:
            return self
        return self.__add__(other)

    def __truediv__(self, other):
        return self.__mul__(1.0 / other)

    def __mul__(self, other):
        if not isinstance(other, (int, float)):
            raise RuntimeError("Can only multiply by scalars")
        return self.__class__(
            mean=self.mean * other,
            covariance_matrix=self.covariance_matrix * (other ** 2),
        )


class _LazyMultivariateNormal(MultivariateNormal):
    def __init__(self, mean, covariance_matrix, validate_args=False):
        # Bypass TMultivariateNormal.__init__, which would try to factorize
        # the (lazy) covariance matrix. TODO: add validation here...
        self.loc = mean
        self._covar = covariance_matrix
        batch_shape, event_shape = mean.shape[:-1], mean.shape[-1:]
        Distribution.__init__(self, batch_shape, event_shape, validate_args=validate_args)

    @property
    def scale_tril(self):
        raise NotImplementedError()

    @property
    def covariance_matrix(self):
        return self._covar

    @property
    def precision_matrix(self):
        raise NotImplementedError()

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        eps = self.loc.new_empty(shape).normal_()
        # TODO: this will fail as-is -- rewrite using LazyTensor code
        return self.loc + _batch_mv(self._covar.root_decomposition(), eps)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        # TODO: re-write for LazyTensors; the dense version would be:
        # M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
        # half_log_det = _batch_diag(self._unbroadcasted_scale_tril).log().sum(-1)
        # return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
        raise NotImplementedError()

    def entropy(self):
        # TODO: re-write for LazyTensors; the dense version would be:
        # half_log_det = _batch_diag(self._unbroadcasted_scale_tril).log().sum(-1)
        # H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
        # return H if len(self._batch_shape) == 0 else H.expand(self._batch_shape)
        raise NotImplementedError()
```
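For the dense case, usage would then look just like torch.distributions:

```python
mvn = MultivariateNormal(torch.zeros(3), torch.eye(3))  # dispatches to the dense path
x = mvn.rsample()
print(mvn.log_prob(x))
```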
@Balandat that looks great! In theory I imagine it should all work right away with autograd (but there are probably some bugs we'll find along the way).
I'm imagining we probably won't have to do anything for the other gpytorch distributions, but it probably wouldn't hurt to create a thin gpytorch wrapper around them (just in case we want to add any gpytorch-specific functionality)? Or should we only have a sub-classed MVN and then just use PyTorch distributions for everything else?
@gpleiss I haven't worked much with the other random variables -- are they supposed to work with LazyTensors? If so, we'll need to wrap everything in a similar fashion, in which case a gpytorch Distribution class would make sense.
Using the distributions more or less directly will simplify things and reduce code bloat significantly, but it does entail a bunch of breaking changes. Better to get this in (together with the LazyTensor rename) before the first beta release, though.
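Something as thin as this would probably do to start (purely a sketch -- the base class just gives us one place to hang gpytorch-specific functionality later):

```python
from torch.distributions import Distribution as TDistribution


class Distribution(TDistribution):
    """Thin gpytorch base class; a hook for shared functionality
    (e.g. LazyTensor handling, arithmetic operators) down the line."""
    pass
```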
Closed by #288