Simple `torch.distributions` wrappers for deep learning.

This module provides the following features:
- Easy Instantiation

  ```python
  import torch
  from distribution_extensions import NormalFactory

  tensor = torch.rand([256, 100, 16])
  distribution = NormalFactory()(tensor)
  distribution.sample()  # -> Tensor[256, 100, 8]
  ```
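The factory call above builds a `Normal` from a single tensor, and the sample shape suggests the last dimension is split in half. A minimal sketch of how such a factory could work, assuming it chunks the last dimension into `loc` and `scale` halves (this `NormalFactory` body is an assumption, not the library's code):

```python
import torch
from torch.distributions import Normal


class NormalFactory:
    """Hypothetical sketch: build a Normal by splitting the last
    dimension into loc and scale halves (an assumption about the
    library's behavior, inferred from the sample shape)."""

    def __call__(self, tensor: torch.Tensor) -> Normal:
        loc, raw_scale = tensor.chunk(2, dim=-1)
        # Softplus keeps the scale strictly positive.
        scale = torch.nn.functional.softplus(raw_scale) + 1e-6
        return Normal(loc=loc, scale=scale)


tensor = torch.rand([256, 100, 16])
distribution = NormalFactory()(tensor)
print(distribution.sample().shape)  # torch.Size([256, 100, 8])
```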
- Easy Independence

  ```python
  distribution = Normal(loc=loc, scale=scale)
  independent = distribution.independent(dim=1)
  ```
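In plain `torch.distributions` terms, `independent(dim=1)` likely corresponds to wrapping the base distribution in `torch.distributions.Independent`; that `dim=1` maps to `reinterpreted_batch_ndims=2` for a 3-d batch (treating every dimension from index 1 onward as event dims) is an assumption here:

```python
import torch
from torch.distributions import Independent, Normal

loc = torch.zeros(256, 100, 8)
scale = torch.ones(256, 100, 8)
base = Normal(loc=loc, scale=scale)

# Assumed equivalent of independent(dim=1): reinterpret every
# batch dim from index 1 onward as part of the event shape.
independent = Independent(base, reinterpreted_batch_ndims=2)
print(independent.batch_shape)  # torch.Size([256])
print(independent.event_shape)  # torch.Size([100, 8])
```

With the trailing dims folded into the event shape, `log_prob` sums over them and returns one value per remaining batch element.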
- Device Conversion

  ```python
  device = torch.device("cuda:0")
  distribution = Normal(loc=loc, scale=scale)
  distribution = distribution.to(device=device)
  ```
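Stock `torch.distributions` objects have no `.to()` method, so the wrapper presumably moves each parameter and rebuilds the distribution. A sketch of that idea (the rebuild-from-moved-parameters approach is an assumption about the wrapper's internals):

```python
import torch
from torch.distributions import Normal

# Fall back to CPU so the sketch also runs without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

loc = torch.zeros(8)
scale = torch.ones(8)
distribution = Normal(loc=loc, scale=scale)

# Assumed equivalent of distribution.to(device=device):
# move every parameter, then construct a fresh distribution.
moved = Normal(loc=loc.to(device), scale=scale.to(device))
print(moved.loc.device)
```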
- Slicing

  ```python
  distribution = Normal(loc=loc, scale=scale)[:, 0, :]
  distribution.sample()  # -> Tensor[256, 8]
  ```
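Plain distributions are not subscriptable either, so slicing most likely works by applying the same index to every parameter and rebuilding; a sketch under that assumption, using the `[256, 100, 8]` shapes from the examples above:

```python
import torch
from torch.distributions import Normal

loc = torch.rand(256, 100, 8)
scale = torch.ones(256, 100, 8)
distribution = Normal(loc=loc, scale=scale)

# Assumed equivalent of distribution[:, 0, :]:
# index each parameter identically, then rebuild.
sliced = Normal(loc=loc[:, 0, :], scale=scale[:, 0, :])
print(sliced.sample().shape)  # torch.Size([256, 8])
```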
- Stop Gradient

  ```python
  distribution = Normal(loc=loc, scale=scale)
  distribution.detach()
  ```
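Likewise, `detach()` can be understood as rebuilding the distribution from parameters cut out of the autograd graph, so reparameterized samples no longer carry gradients; a sketch of that assumed behavior:

```python
import torch
from torch.distributions import Normal

loc = torch.zeros(8, requires_grad=True)
scale = torch.ones(8, requires_grad=True)
distribution = Normal(loc=loc, scale=scale)

# Assumed equivalent of distribution.detach():
# rebuild from detached parameters.
detached = Normal(loc=loc.detach(), scale=scale.detach())
print(distribution.loc.requires_grad)  # True
print(detached.loc.requires_grad)      # False
```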